python衡量數據分布的相似度/距離(KL/JS散度)
- 2019 年 10 月 30 日
- 筆記
背景
很多場景需要考慮數據分布的相似度/距離:比如確定一個正態分布是否能夠很好的描述一個群體的身高(正態分布生成的樣本分布應當與實際的抽樣分布接近),或者一個分類演算法是否能夠很好地區分樣本的特徵(在兩個分類下的數據分布的差異應當比較大)。
(例子:上圖來自 OpenAI的 Radford A , Jozefowicz R , Sutskever I . Learning to Generate Reviews and Discovering Sentiment[J]. 2017. 他們發現他們訓練的深度神經網路中有一個單獨的神經元就對正負情感的區分度相當良好。)
上圖可以直接看出明顯的分布區別,但是能夠衡量這種分布的距離更便於多種方法間的比較。KL/JS散度就是常用的衡量數據概率分布的數值指標,可以看成是數據分布的一種「距離」,關於它們的理論基礎可以在網上找到很多參考,這裡只簡要給出公式和性質,還有程式碼實現:
KL散度
有時也稱為相對熵,KL距離。對於兩個概率分布P、Q,二者越相似,KL散度越小。
- KL散度滿足非負性
- KL散度是不對稱的,交換P、Q的位置將得到不同結果。
python3程式碼:
import numpy as np import scipy.stats p=np.asarray([0.65,0.25,0.07,0.03]) q=np.array([0.6,0.25,0.1,0.05]) def KL_divergence(p,q): return scipy.stats.entropy(p, q) print(KL_divergence(p,q)) # 0.011735745199107783 print(KL_divergence(q,p)) # 0.013183150978050884
JS散度
JS散度基於KL散度,同樣是二者越相似,JS散度越小。
- JS散度的取值範圍在0-1之間,完全相同時為0
- JS散度是對稱的
python3程式碼:
import numpy as np import scipy.stats p=np.asarray([0.65,0.25,0.07,0.03]) q=np.array([0.6,0.25,0.1,0.05]) q2=np.array([0.1,0.2,0.3,0.4]) def JS_divergence(p,q): M=(p+q)/2 return 0.5*scipy.stats.entropy(p, M)+0.5*scipy.stats.entropy(q, M) print(JS_divergence(p,q)) # 0.003093977084273652 print(JS_divergence(p,q2)) # 0.24719159952098618 print(JS_divergence(p,p)) # 0.0
實例:身高分布預測比較
在實際運用中,我們往往並不是一開始就能得到概率分布的,我們手上的更多是像每個人的身高這樣的具體數據,那麼怎麼在python把它們轉化為概率分布然後衡量距離呢? 我們需要把數據等間隔地切分成一些區間(也叫作桶bin),然後就可以把樣本落在每個區間的概率作為分布。pandas提供了cut
這個方便的函數可以完成這一點。 下面我將演示一個身高分布預測比較的例子,用scipy的正態分布函數隨機生成了真實的身高分布和兩個預測,讓我們用散度來評判哪個是更好的預測: 上程式碼:
from scipy.stats import norm import pandas as pd #1000個均值170,標準差10的正態分布身高樣本 h_real = norm.rvs(loc=170, scale=10, size=1000) h_predict1 = norm.rvs(loc=168, scale=9, size=1000) h_predict2 = norm.rvs(loc=160, scale=20, size=1000) def JS_div(arr1,arr2,num_bins): max0 = max(np.max(arr1),np.max(arr2)) min0 = min(np.min(arr1),np.min(arr2)) bins = np.linspace(min0-1e-4, max0-1e-4, num=num_bins) PDF1 = pd.cut(arr1,bins).value_counts() / len(arr1) PDF2 = pd.cut(arr2,bins).value_counts() / len(arr2) return JS_divergence(PDF1.values,PDF2.values) print(JS_div(h_real,h_predict1,num_bins=20)) # 0.0098 print(JS_div(h_real,h_predict2,num_bins=20)) # 0.135
我為預測1設置的參數和真實值更加接近。而預測1的散度的確更低,說明它是更好的預測。