上QQ阅读APP看书,第一时间看更新
AI源码解读.数字图像处理案例:Python版
2.3.3 模型训练及保存
定义卷积神经网络模型架构和编译之后,使用训练集训练模型,使模型获得任意内容图片的风格迁移结果。这里,每次使用4张图片训练模型,训练后输出其中3张效果图保存到TensorBoard中。
1.模型训练
模型训练相关代码如下:
其中,一个batch就是在一次前向/后向传播过程用到的训练样例数量,也就是一次用4张图片进行训练,共训练30000张图片,如图2-14所示。
图2-14 训练结果
通过TensorBoard观察当前训练的情况,如图2-15和图2-16所示,可以查看当前的内容损失和风格损失情况呈梯度下降的状态。
图2-15 TensorBoard参数(1)
图2-16 TensorBoard参数(2)
2.模型保存
为直接使用模型,需要将模型保存,使用TensorFlow中的train模块实现。
saver = tf.train.Saver() saver.save(sess, os.path.join(OUTPUT_DIR, 'fast_style_transfer'))
模型保存后,可以在其他项目中直接使用。