K均值演算法
一、概念
K-means中心思想:事先確定常數K,常數K意味著最終的聚類類別數,首先隨機選定初始點為質心,並通過計算每一個樣本與質心之間的相似度(這裡為歐式距離),將樣本點歸到最相似的類中,接著,重新計算每個類的質心(即為類中心),重複這樣的過程,直到質心不再改變,最終就確定了每個樣本所屬的類別以及每個類的質心。由於每次都要計算所有的樣本與每一個質心之間的相似度,故在大規模的數據集上,K-Means演算法的收斂速度比較慢。
二、特點:
常用距離
a.歐式距離
b.曼哈頓距離
三、演算法流程
K-means是一個反覆迭代的過程,演算法分為四個步驟:
(x,k,y)
(1) 選取數據空間中的K個對象作為初始中心,每個對象代表一個聚類中心;
def initcenter(x, k): kc
(2) 對於樣本中的數據對象,根據它們與這些聚類中心的歐氏距離,按距離最近的準則將它們分到距離它們最近的聚類中心(最相似)所對應的類;
def nearest(kc, x[i]): j
def xclassify(x, y, kc):y[i]=j
(3) 更新聚類中心:將每個類別中所有對象所對應的均值作為該類別的聚類中心,計算目標函數的值;
def kcmean(x, y, kc, k):
(4) 判斷聚類中心和目標函數的值是否發生改變,若不變,則輸出結果,若改變,則返回2)。
while flag:
y = xclassify(x, y, kc)
kc, flag = kcmean(x, y, kc, k)
四、實踐
(1).撲克牌手動演練k均值聚類過程:>30張牌,3類
①本次模擬k均值用到的撲克牌,初始中心為(2,9,12)
②經過一輪計算(選出中心:3,8,12)
③一直算到最後
(2).*自主編寫K-means演算法 ,以鳶尾花花瓣長度數據做聚類,並用散點圖顯示。
### 1、導入鳶尾花數據 from sklearn.datasets import load_iris import numpy as np ### 2、鳶尾花數據 iris = load_iris() data=iris['data'] #樣本屬性個數 m=data.shape[1] #樣本個數 n=len(data) #類中心個數,即最終分類 k=3 ### 3、數據初始化 #距離矩陣 dist=np.zeros([n,k+1]) #初始類中心 center=np.zeros([k,m]) #新的類中心 new_center=np.zeros([k,m]) ### 4、選中心 #選擇前三個樣本作為初始類中心 center=data[:k, :] while True: #求距離 for i in range(n): for j in range(k): dist[i,j]=np.sqrt(sum((data[i,:]-center[j,:])**2)) #歸類 dist[i,k]=np.argmin(dist[i,:k]) #求新類中心 for i in range(k): index=dist[:,k]==i new_center[i,:]=np.mean(data[index, :]) #判斷結束 if(np.all(center==new_center)): break else: center=new_center print('聚類結果:',dist[:,k])
(3)用sklearn.cluster.KMeans,鳶尾花花瓣長度數據做聚類,並用散點圖顯示。
from sklearn.datasets import load_iris from sklearn.cluster import KMeans import matplotlib.pyplot as plt # 獲取鳶尾花數據集 iris = load_iris() data = iris.data[:, 1] # 鳶尾特徵值 x = data.reshape(-1, 1) # 構建模型 model = KMeans(n_clusters=3) # 訓練 model.fit(x) # 預測樣本的聚類索引 y = model.predict(x) print("預測結果:", y) #畫圖 plt.scatter(x[:, 0], x[:, 0], c=y, s=50, cmap='rainbow') plt.show()
預測結果:
散點圖可視化:
(4)鳶尾花完整數據做聚類並用散點圖顯示。
from sklearn.datasets import load_iris from sklearn.cluster import KMeans import matplotlib.pyplot as plt # 導入鳶尾花數據集 iris = load_iris() # 鳶尾花花瓣長度數據 x = iris.data # 構建模型 model = KMeans(n_clusters=3) # 訓練 model.fit(x) # 預測 y = model.predict(x) print("預測結果:", y) #畫圖 plt.scatter(x[:, 2], x[:, 3], c=y, s=50, cmap='rainbow') plt.show()
預測結果:
散點圖可視化:
(5)想想k均值演算法中以用來做什麼?
- 文本分析和歸類
- K均值演算法實現影像壓縮
- 像素處理
- K均值演算法處理影像