MNIST竞赛技术详解与代码分析,文末有福利~
- 2019 年 11 月 30 日
- 筆記
Part.1 项目简介
MNIST项目基本上是深度学习初学者的入门项目,本文主要介绍使用keras框架通过构建CNN网络实现在MNIST数据集上99+的准确率。温馨提示,文末有福利哦。
Part.2 数据集来源
MNIST手写数字数据集是深度学习中的经典数据集,该数据集中的数字图片是由250个不同职业的人手写绘制的。
其中,训练集数据一共60000张图片,测试集数据一共10000张图片。
每张手写数字图片大小都是28*28,每张图片代表的是从0到9中的每个数字。
该数据集样例如下所示:
在 FlyAI竞赛平台上 提供了准确率为99.26%的超详细代码实现,同时我们可以通过参加MNIST手写数字识别练习赛进行进一步学习和优化。
下面的代码实现部分主要该代码进行讲解。
Part.3 代码实现
3.1、算法流程及实现
算法流程主要分为以下四个部分进行介绍:
1.数据加载
2.数据增强
3.构建网络
4.模型训练
1.数据加载
在FlyAI的项目中封装了Dataset类,可以实现对数据的一些基本操作。
比如:
加载批量数据next_train_batch()、
校验数据next_validation_batch()、
获取全量数据get_all_data()、
获取训练集数据量get_train_length()
等。
具体使用方法如下:
对单张图片等数据的读取是在processor.py文件中完成。实现如下:
2.数据增强
数据增强的作用通常是为了扩充训练数据量提高模型的泛化能力,同时通过增加了噪声数据提升模型的鲁棒性。
在本项目中我们采用了比较简单的数据增强方法包括旋转、平移。实现如下:
为了展示数据增强的效果,我们对图像进行了可视化,完整代码如下:
可视化结果如图:
3.构建网络
由于手写数字图片大小仅为28*28,图像宽高比较小不太适合较深的网络结构。
因此我们自己搭建了一个卷积神经网络,网络结构如下所示:
运行summary()方法后输出的网络结构如下图:
keras提供了keras.utils.vis_utils模块可以对模型进行可视化操作。
模型结构图如下所示:
4.模型训练
这里我们设置了epoch为5,batch为32,采用adam优化器来训练网络。
通过调用FlyAI提供的train_log方法可以在训练过程中实时的看到训练集和验证集的准确率及损失变化曲线。
训练集和验证集的准确率及损失实时变化曲线如图:
3.2、最终结果
通过使用自定义CNN网络结构以及数据增强的方法,在epoch为5,batch为32使用adam优化器下不断优化模型参数,最终模型在测试集的准确率达到99.26%。
_ END _