AI源码解读:数字图像处理案例(Python版)
上QQ阅读APP看书,第一时间看更新
 AI源码解读.数字图像处理案例:Python版

2.3.3 模型训练及保存

定义卷积神经网络模型架构和编译之后,使用训练集训练模型,使模型获得任意内容图片的风格迁移结果。这里,每次使用4张图片训练模型,训练后输出其中3张效果图保存到TensorBoard中。

1.模型训练

模型训练相关代码如下:

其中,一个batch就是在一次前向/后向传播过程用到的训练样例数量,也就是一次用4张图片进行训练,共训练30000张图片,如图2-14所示。

图2-14 训练结果

通过TensorBoard观察当前训练的情况,如图2-15和图2-16所示,可以查看当前的内容损失和风格损失情况呈梯度下降的状态。

图2-15 TensorBoard参数(1)

图2-16 TensorBoard参数(2)

2.模型保存

为直接使用模型,需要将模型保存,使用TensorFlow中的train模块实现。

     saver = tf.train.Saver()
     saver.save(sess, os.path.join(OUTPUT_DIR, 'fast_style_transfer'))

模型保存后,可以在其他项目中直接使用。