K-近鄰演算法kNN
K-近鄰演算法(k-Nearest Neighbor,簡稱kNN)採用測量不同特徵值之間的距離方法進行分類,是一種常用的監督學習方法,其工作機制很簡單:給定測試樣本,基於某種距離亮度找出訓練集中與其靠近的k個訓練樣本,然後基於這k個「鄰居」的資訊進行預測。kNN演算法屬於懶惰學習,此類學習技術在訓練階段僅僅是把樣本保存起來,訓練時間靠小為零,在收到測試樣本後在進行處理,所以可知kNN演算法的缺點是計算複雜度高、空間複雜度高。但其也有優點,精度高、對異常值不敏感、無數據輸入設定。
借張圖來說:
當k = 1時目標點有一個class2鄰居,根據kNN演算法的原理,目標點也為class2。
當k = 5時目標點有兩個class2鄰居,有三個class1的鄰居,根據其原理,目標點的類別為class2。
演算法流程
總體來說,KNN分類演算法包括以下4個步驟:
①準備數據,對數據進行預處理 。
②計算測試樣本點(也就是待分類點)到其他每個樣本點的距離。
③對每個距離進行排序,然後選擇出距離最小的K個點 。
④對K個點所屬的類別進行比較,根據少數服從多數的原則,將測試樣本點歸入在K個點中佔比最高的那一類 。
演算法程式碼


package com.top.knn; import com.top.constants.OrderEnum; import com.top.matrix.Matrix; import com.top.utils.MatrixUtil; import java.util.*; /** * @program: top-algorithm-set * @description: KNN k-臨近演算法進行分類 * @author: Mr.Zhao * @create: 2020-10-13 22:03 **/ public class KNN { public static Matrix classify(Matrix input, Matrix dataSet, Matrix labels, int k) throws Exception { if (dataSet.getMatrixRowCount() != labels.getMatrixRowCount()) { throw new IllegalArgumentException("矩陣訓練集與標籤維度不一致"); } if (input.getMatrixColCount() != dataSet.getMatrixColCount()) { throw new IllegalArgumentException("待分類矩陣列數與訓練集列數不一致"); } if (dataSet.getMatrixRowCount() < k) { throw new IllegalArgumentException("訓練集樣本數小於k"); } // 歸一化 int trainCount = dataSet.getMatrixRowCount(); int testCount = input.getMatrixRowCount(); Matrix trainAndTest = dataSet.splice(2, input); Map<String, Object> normalize = MatrixUtil.normalize(trainAndTest, 0, 1); trainAndTest = (Matrix) normalize.get("res"); dataSet = trainAndTest.subMatrix(0, trainCount, 0, trainAndTest.getMatrixColCount()); input = trainAndTest.subMatrix(0, testCount, 0, trainAndTest.getMatrixColCount()); // 獲取標籤資訊 List<Double> labelList = new ArrayList<>(); for (int i = 0; i < labels.getMatrixRowCount(); i++) { if (!labelList.contains(labels.getValOfIdx(i, 0))) { labelList.add(labels.getValOfIdx(i, 0)); } } Matrix result = new Matrix(new double[input.getMatrixRowCount()][1]); for (int i = 0; i < input.getMatrixRowCount(); i++) { // 求向量間的歐式距離 Matrix var1 = input.getRowOfIdx(i).extend(2, dataSet.getMatrixRowCount()); Matrix var2 = dataSet.subtract(var1); Matrix var3 = var2.square(); Matrix var4 = var3.sumRow(); Matrix var5 = var4.pow(0.5); // 距離矩陣合併上labels矩陣 Matrix var6 = var5.splice(1, labels); // 將計算出的距離矩陣按照距離升序排序 var6.sort(0, OrderEnum.ASC); // 遍歷最近的k個變數 Map<Double, Integer> map = new HashMap<>(); for (int j = 0; j < k; j++) { // 遍歷標籤種類數 for (Double label : labelList) { if (var6.getValOfIdx(j, 1) == label) { map.put(label, map.getOrDefault(label, 0) + 1); } } } result.setValue(i, 0, getKeyOfMaxValue(map)); } return result; } /** * 取map中值最大的key * * @param map * @return */ private static Double getKeyOfMaxValue(Map<Double, Integer> map) { if (map == null) return null; Double keyOfMaxValue = 0.0; Integer maxValue = 0; for (Double key : map.keySet()) { if (map.get(key) > maxValue) { keyOfMaxValue = key; maxValue = map.get(key); } } return keyOfMaxValue; } }
KNN
註:其中的矩陣方法請參考//github.com/ineedahouse/top-algorithm-set/blob/dev/src/main/java/com/top/matrix/Matrix.java
升降序枚舉類參考//github.com/ineedahouse/top-algorithm-set/blob/dev/src/main/java/com/top/constants/OrderEnum.java
該演算法為本人github項目中的一部分,地址為//github.com/ineedahouse/top-algorithm-set
如果對你有幫助可以點個star~
參考
《機器學習》-周志華
《機器學習實戰》-Peter Harrington