1.9 EfficientNet-B7建模
本节采用EfficientNet-B7模型,基于1.3节介绍的包含104种花朵的数据集,完成模型训练、验证和测试工作。1.10节将训练好的模型部署到Web服务器上,供客户端访问。
EfficientNet-B7参数量较大,输入图像采用的分辨率为512×512像素,对计算力需求较大,建模过程借助Kaggle平台提供的免费TPUv3-8完成。本节实现的完整建模逻辑与模型下载参见网址:https://www.kaggle.com/upsunny/flower-efficientnetb7。
学习率是模型训练过程中应该引起足够重视的一个超参数,学习率应该随着训练过程做动态调整,程序源码P1.5给出的调度策略非常经典,学习率先上升后下降,这种策略常见于项目实战中。
程序源码P1.5给出的学习率调度函数,参照EPOCH这个变量,将学习率的变化限制在一个范围内,即划定一个从最小值到最大值的变化区间,在这个区间内包含三段变化过程。
(1)第一阶段学习率从最小值开始不断上升,由LR_RAMPUP_EPOCHS控制上升的代数。
(2)第二阶段学习率保持不变,由LR_SUSTAIN_EPOCHS控制保持不变的代数。
(3)第三阶段学习率衰减,由LR_EXP_DECAY控制衰减策略。
学习率跟随EPOCH动态变化的规律如图1.30所示,学习率从最小值0.000 01起步增长到0.000 40,没有超过最大值;因LR_SUSTAIN_EPOCHS=0,中间跳过了学习率维持阶段,直接进入衰减阶段,在第13个EPOCH时,学习率衰减到0.000 075。
程序中定义了用于数据集加载的预处理函数与显示图像函数、显示模型训练曲线函数和显示混淆矩阵的函数,分别用于观察数据集、观察模型训练效果和预测结果,此处不再赘述,参见视频解析。
图1.30 学习率随EPOCH动态变化的规律
EfficientNet-B7迁移模型的编程逻辑如程序源码P1.6所示。注意,第5行语句表明模型将采用微调模式进行训练。
观察模型结构摘要,EfficientNet-B7迁移模型结构参数如表1.12所示。输出向量的维度为(None,104),可训练参数总量超过6400万个。
表1.12 EfficientNet-B7迁移模型结构参数
模型训练13代在TPUv3-8上用时25min左右,损失函数与准确率对照曲线如图1.31所示。
图1.31 损失函数与准确率对照曲线
从图1.31可以观察到EfficientNet-B7的下列表现。
(1)损失函数与准确率曲线走势高度趋同,证明模型结构合理并且稳定。
(2)训练集与验证集走势高度趋同,曲线平滑,无显著背离和波动,证明模型稳定性好,泛化能力强。
(3)包含104种花朵类别,训练集和验证集样本总量不到17 000,在平均单个类别不足170幅图片的前提下,得到的模型损失值趋于0,准确率趋于1,证明模型应用价值高。
本节在Kaggle上展示的教学演示,给出了104种花朵的混淆矩阵,可以更加直观地观察到EfficientNet-B7的误差所在。
同时,模型的F1-Score为0.952,精准率(Precision)为0.952,召回率(Recall)为0.955,这三个指标高度一致,而且超过95%,再次证明EfficientNet-B7的可靠性。
模型在验证集上的抽样测试结果如图1.32所示,标签旁边的[OK]表示预测正确。一般而言,程序员这个时候最为兴奋,好的结果证明了模型的价值。
测试集包含7000多幅图片,没有参与模型训练,图1.33是其中的一组随机抽样推断结果,图片上方给出的标签是模型的预测结果。
由于测试集没有给出标签,因此需要人工观察结果的正确性。事实上,找出一幅预测结果错误的图片,还是有难度的,可以从混淆矩阵给出的错误报告中做针对性测试。
图1.32 模型在验证集上的抽样测试结果
图1.33 测试集上的随机抽样推断结果
更多测试及建模解析参见视频教程。下载训练好的模型,该模型将在1.10节部署到服务器上。