Keras之對鳶尾花識別

Keras之隊鳶尾花識別

任務目標

  • 對鳶尾花數據集分析
  • 建立鳶尾花的模型
  • 利用模型預測鳶尾花的類別

環境搭建

pycharm編輯器搭建python3.*
第三方庫

  • numpy
  • pandas
  • sklearn
  • keras

處理鳶尾花數據集

了解數據集

鳶尾花數據集是一個經典的機器學習數據集,非常適合用來入門。
鳶尾花數據集鏈接:下載鳶尾花數據集
鳶尾花數據集包含四個特徵和一個標籤。這四個特徵確定了單株鳶尾花的下列植物學特徵:

  • 花萼長度
  • 花萼寬度
  • 花瓣長度
  • 花瓣寬度

該表確定了鳶尾花品種,品種必須是下列任意一種:

  • 山鳶尾 Iris-Setosa(0)
  • 雜色鳶尾 Iris-versicolor(1)
  • 維吉尼亞鳶尾 Iris-virginica(2)

數據集中三類鳶尾花各含有50個樣本,共150各樣本

下面顯示了數據集中的樣本:
樣本
機器學習中,為了保證測試結果的準確性,一般會從數據集中抽取一部分數據專門留作測試,其餘數據用於訓練。所以我將數據集按7:3(訓練集:測試集)的比例進行劃分。

數據集處理具體程式碼

# 讀取數據集
iris = pd.read_csv("iris.data", header=None)

# 數據集轉化成數組
iris = np.array(iris)
# 提取特徵集
X = iris[:, 0:4]
# 提取標籤集
Y = iris[:, 4]

# One-Hot編碼
encoder = LabelEncoder()
Y = encoder.fit_transform(Y)
Y = np_utils.to_categorical(Y)

x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.3)
return x_train,x_test,y_train,y_test

什麼是one-hot編碼?

  One-Hot編碼,又稱為一位有效編碼,主要是採用N位狀態暫存器來對N個狀態進行編碼,每個狀態都由他獨立的暫存器位,並且在任意時候只有一位有效。
  One-Hot編碼是分類變數作為二進位向量的表示。這首先要求將分類值映射到整數值。然後,每個整數值被表示為二進位向量,除了整數的索引之外,它都是零值,它被標記為1。
  One-Hot編碼是將類別變數轉換為機器學習演算法易於利用的一種形式的過程。
  比如:[“山鳶尾”,”雜色鳶尾”,”維吉尼亞鳶尾”]—->[[1,0,0][0,1,0][0,0,1]]


建立模型和預測

設置超參數

# 超參數
epochs = 500  # 循環次數
validation_split = 0.05  # 學習率
test_size = 0.25  # 拆分數據集大小
dense1_neurons = 512 # 第一層神經元的數量
dense2_neurons = 256 # 第二層神經元的數量
dense3_neurons = 128 # 第三層神經元的數量

搭建模型

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(units=dense1_neurons,input_dim = 4,activation = 'relu'))
model.add(tf.keras.layers.Dense(units=dense2_neurons,activation='relu'))
model.add(tf.keras.layers.Dense(units=dense3_neurons,activation='relu'))
model.add(tf.keras.layers.Dense(units=3,activation="softmax"))
model.summary()   # 查看模型結構

編譯模型

model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])

訓練模型

history = model.fit(x_train,y_train,validation_split=validation_split,epochs=epochs)

使用測試集進行評估

model.evaluate(x_test,y_test)

預測

target = model.predict(np.array([[7, 5.5, 6.5, 3.9]])).argmax()
print(target)
if target == 0:
    print("Iris-setosa")
elif target == 1:
    print("Iris-versicolor")
else:
    print("Iris-virginica")

結果圖片