K-近鄰演算法kNN

  K-近鄰演算法(k-Nearest Neighbor,簡稱kNN)採用測量不同特徵值之間的距離方法進行分類,是一種常用的監督學習方法,其工作機制很簡單:給定測試樣本,基於某種距離亮度找出訓練集中與其靠近的k個訓練樣本,然後基於這k個「鄰居」的資訊進行預測。kNN演算法屬於懶惰學習,此類學習技術在訓練階段僅僅是把樣本保存起來,訓練時間靠小為零,在收到測試樣本後在進行處理,所以可知kNN演算法的缺點是計算複雜度高、空間複雜度高。但其也有優點,精度高、對異常值不敏感、無數據輸入設定。

  借張圖來說:

當k = 1時目標點有一個class2鄰居,根據kNN演算法的原理,目標點也為class2。

當k = 5時目標點有兩個class2鄰居,有三個class1的鄰居,根據其原理,目標點的類別為class2。

演算法流程

總體來說,KNN分類演算法包括以下4個步驟:

①準備數據,對數據進行預處理 。

②計算測試樣本點(也就是待分類點)到其他每個樣本點的距離。

③對每個距離進行排序,然後選擇出距離最小的K個點 。

④對K個點所屬的類別進行比較,根據少數服從多數的原則,將測試樣本點歸入在K個點中佔比最高的那一類 。

演算法程式碼

package com.top.knn;

import com.top.constants.OrderEnum;
import com.top.matrix.Matrix;
import com.top.utils.MatrixUtil;

import java.util.*;


/**
 * @program: top-algorithm-set
 * @description: KNN k-臨近演算法進行分類
 * @author: Mr.Zhao
 * @create: 2020-10-13 22:03
 **/
public class KNN {
    public static Matrix classify(Matrix input, Matrix dataSet, Matrix labels, int k) throws Exception {
        if (dataSet.getMatrixRowCount() != labels.getMatrixRowCount()) {
            throw new IllegalArgumentException("矩陣訓練集與標籤維度不一致");
        }
        if (input.getMatrixColCount() != dataSet.getMatrixColCount()) {
            throw new IllegalArgumentException("待分類矩陣列數與訓練集列數不一致");
        }
        if (dataSet.getMatrixRowCount() < k) {
            throw new IllegalArgumentException("訓練集樣本數小於k");
        }
        // 歸一化
        int trainCount = dataSet.getMatrixRowCount();
        int testCount = input.getMatrixRowCount();
        Matrix trainAndTest = dataSet.splice(2, input);
        Map<String, Object> normalize = MatrixUtil.normalize(trainAndTest, 0, 1);
        trainAndTest = (Matrix) normalize.get("res");
        dataSet = trainAndTest.subMatrix(0, trainCount, 0, trainAndTest.getMatrixColCount());
        input = trainAndTest.subMatrix(0, testCount, 0, trainAndTest.getMatrixColCount());

        // 獲取標籤資訊
        List<Double> labelList = new ArrayList<>();
        for (int i = 0; i < labels.getMatrixRowCount(); i++) {
            if (!labelList.contains(labels.getValOfIdx(i, 0))) {
                labelList.add(labels.getValOfIdx(i, 0));
            }
        }

        Matrix result = new Matrix(new double[input.getMatrixRowCount()][1]);
        for (int i = 0; i < input.getMatrixRowCount(); i++) {
            // 求向量間的歐式距離
            Matrix var1 = input.getRowOfIdx(i).extend(2, dataSet.getMatrixRowCount());
            Matrix var2 = dataSet.subtract(var1);
            Matrix var3 = var2.square();
            Matrix var4 = var3.sumRow();
            Matrix var5 = var4.pow(0.5);
            // 距離矩陣合併上labels矩陣
            Matrix var6 = var5.splice(1, labels);
            // 將計算出的距離矩陣按照距離升序排序
            var6.sort(0, OrderEnum.ASC);
            // 遍歷最近的k個變數
            Map<Double, Integer> map = new HashMap<>();
            for (int j = 0; j < k; j++) {
                // 遍歷標籤種類數
                for (Double label : labelList) {
                    if (var6.getValOfIdx(j, 1) == label) {
                        map.put(label, map.getOrDefault(label, 0) + 1);
                    }
                }
            }
            result.setValue(i, 0, getKeyOfMaxValue(map));
        }
        return result;
    }

    /**
     * 取map中值最大的key
     *
     * @param map
     * @return
     */
    private static Double getKeyOfMaxValue(Map<Double, Integer> map) {
        if (map == null)
            return null;
        Double keyOfMaxValue = 0.0;
        Integer maxValue = 0;
        for (Double key : map.keySet()) {
            if (map.get(key) > maxValue) {
                keyOfMaxValue = key;
                maxValue = map.get(key);
            }
        }
        return keyOfMaxValue;
    }

}

KNN

註:其中的矩陣方法請參考//github.com/ineedahouse/top-algorithm-set/blob/dev/src/main/java/com/top/matrix/Matrix.java

  升降序枚舉類參考//github.com/ineedahouse/top-algorithm-set/blob/dev/src/main/java/com/top/constants/OrderEnum.java

該演算法為本人github項目中的一部分,地址為//github.com/ineedahouse/top-algorithm-set

如果對你有幫助可以點個star~

參考

《機器學習》-周志華

《機器學習實戰》-Peter Harrington