上QQ阅读APP看书,第一时间看更新
AI源码解读.数字图像处理案例:Python版
3.3.3 模型训练及保存
在定义模型架构和编译之后,通过训练集训练模型,识别花卉。这里,将使用训练集和测试集拟合、改进,并保存模型。
1.模型训练
本部分包括CNN模型训练和Inception-V3模型训练。
1)CNN模型训练
CNN模型是本项目对于花卉分类的基本模型,相关代码如下:
训练输出结果如图3-8所示。
通过观察训练集和测试集的损失函数、准确率的大小评估模型的训练程度,进行模型训练的进一步决策。
2)Inception-V3模型训练
图3-8 训练输出结果
CNN模型对于花卉的分类准确率大概在70%左右。因此,得出的改进方法为,采用迁移学习调用Inception-v3模型实现对本文中的花卉数据集分类。
Inception系列解决CNN分类模型的两个问题:①如何使网络深度增加的同时让模型的分类性能随之增加,而非像简单的VGG网络达到一定深度后就陷入了性能饱和的困惑。②如何在保证分类网络准确率提升或保持不降的同时,使模型的计算开销与内存开销降低。在这个模型中最后一层全连接层之前统称为瓶颈层。Inception-v3模型下载地址为https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip,相关代码如下:
训练输出结果如图3-9所示。
使用Inception-V3模型的分类准确率在95%左右,准确率得到了较好的改善。经过对比,选择准确率更高的Inception-V3模型进行分类。
2.模型保存
为能够被Android程序读取,需要将模型文件保存为.pb格式,利用TensorFlow中的graph_util模块进行模型保存。
图3-9 训练输出结果
模型被保存后,可以被重用,也可以移植到其他环境中使用。