Implementing a Neural Network on the Fashion-MNIST Dataset

  • August 12, 2020
  • Notes
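import tensorflow as tf
import numpy as np  # imported in the original but not used below

# Fashion-MNIST: 28x28 grayscale clothing images with integer labels 0-9
fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()

# Scale pixel values from 0-255 down to 0-1 so training converges more easily
x_train, x_test = x_train / 255.0, x_test / 255.0

# Flatten each 28x28 image into a 784-vector, pass it through one hidden layer
# of 128 ReLU units, then a 10-way softmax (one probability per class)
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# Labels are plain integers, so use the "sparse" loss and metric; the output
# layer already applies softmax, hence from_logits=False
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

# Train for 3 epochs and evaluate on the test set after every epoch
model.fit(x_train, y_train, batch_size=32, epochs=3,
          validation_data=(x_test, y_test), validation_freq=1)

# Print layer output shapes and parameter counts (works only once the model
# has been built, which here happens during fit)
model.summary()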
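To make explicit what the three layers and the loss compute, here is a minimal NumPy sketch of the same forward pass. It uses randomly initialised weights (a hypothetical stand-in for trained ones) and assumes 28×28 inputs already scaled to [0, 1]. It also reproduces the parameter count that `model.summary()` reports for this architecture. The hidden layer has 784·128 + 128 = 100,480 parameters and the output layer has 128·10 + 10 = 1,290, for 101,770 in total. The helper names (`init_params`, `forward`, and so on) are illustrative only, not Keras APIs.

```python
import numpy as np

def init_params(rng, n_in=784, n_hidden=128, n_out=10):
    """Hypothetical random initialisation (Keras Dense defaults to Glorot-uniform)."""
    return {
        "W1": rng.normal(0.0, 0.05, size=(n_in, n_hidden)),
        "b1": np.zeros(n_hidden),
        "W2": rng.normal(0.0, 0.05, size=(n_hidden, n_out)),
        "b2": np.zeros(n_out),
    }

def count_params(params):
    # Total number of weights and biases across both Dense layers
    return sum(p.size for p in params.values())

def forward(params, images):
    # Flatten: (batch, 28, 28) -> (batch, 784)
    x = images.reshape(images.shape[0], -1)
    # Dense(128, activation='relu')
    h = np.maximum(0.0, x @ params["W1"] + params["b1"])
    # Dense(10, activation='softmax'); subtracting the row max keeps exp() stable
    z = h @ params["W2"] + params["b2"]
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def sparse_categorical_crossentropy(probs, labels, eps=1e-7):
    # from_logits=False: inputs are probabilities, so loss = mean of -log p[true class]
    p = probs[np.arange(len(labels)), labels]
    return float(np.mean(-np.log(np.clip(p, eps, 1.0))))  # clip avoids log(0)

def sparse_categorical_accuracy(probs, labels):
    # Fraction of samples whose highest-probability class equals the integer label
    return float(np.mean(np.argmax(probs, axis=1) == labels))

rng = np.random.default_rng(0)
params = init_params(rng)
images = rng.random((4, 28, 28))   # fake batch of 4 images in [0, 1]
labels = np.array([0, 3, 9, 5])    # fake integer labels
probs = forward(params, images)
print(count_params(params))        # 101770
print(probs.shape)                 # (4, 10)
print(sparse_categorical_crossentropy(probs, labels))
print(sparse_categorical_accuracy(probs, labels))
```

Keras performs the same computation with learned weights and a batched, autodiff-driven Adam update. The sketch only covers the forward pass and the metric, not training.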

Note: the dataset failed to download because I was not connected to a VPN; after connecting through a VPN, the download succeeded.
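If a VPN is not an option, one possible workaround is to obtain the four standard Fashion-MNIST archives some other way and parse them locally. This is an assumption, not something the note above tested. The archives are `train-images-idx3-ubyte.gz`, `train-labels-idx1-ubyte.gz`, `t10k-images-idx3-ubyte.gz` and `t10k-labels-idx1-ubyte.gz`, all in the gzip-compressed IDX format. `read_idx_gz` and `load_fashion_local` are hypothetical helpers that return the same `(x_train, y_train), (x_test, y_test)` layout as `fashion.load_data()`.

```python
import gzip
import os
import numpy as np

def read_idx_gz(path):
    """Parse a gzip-compressed IDX file containing unsigned-byte data."""
    with gzip.open(path, "rb") as f:
        data = f.read()
    # Header: two zero bytes, a type byte (0x08 = unsigned byte), then the number of dims
    if data[0] != 0 or data[1] != 0 or data[2] != 0x08:
        raise ValueError(f"{path}: not an unsigned-byte IDX file")
    ndim = data[3]
    # Each dimension size is stored as a big-endian 32-bit unsigned integer
    dims = np.frombuffer(data, dtype=">u4", count=ndim, offset=4)
    shape = tuple(int(d) for d in dims)
    # The raw uint8 payload starts right after the header
    return np.frombuffer(data, dtype=np.uint8, offset=4 + 4 * ndim).reshape(shape)

def load_fashion_local(folder):
    """Hypothetical offline stand-in for fashion_mnist.load_data()."""
    def load(prefix):
        x = read_idx_gz(os.path.join(folder, f"{prefix}-images-idx3-ubyte.gz"))
        y = read_idx_gz(os.path.join(folder, f"{prefix}-labels-idx1-ubyte.gz"))
        return x, y
    return load("train"), load("t10k")
```

With the four files in one folder, `(x_train, y_train), (x_test, y_test) = load_fashion_local("path/to/fashion-mnist")` could replace the `fashion.load_data()` call. The folder path is a placeholder.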