上QQ阅读APP看书,第一时间看更新
AI源码解读.数字图像处理案例:Python版
2.3.5 Pix2Pix模型构建
Pix2Pix模型构建需要定义生成器和判别器的神经网络,生成器使用U-Net网络结构,提取特征;判别器则根据PatchGAN方法获得不同区域的得分。
1.Pix2Pix数据集处理
当前制作出数据集的一张图片中包含两部分,可将图片左半部分设为X类,右半部分设为Y类。Pix2Pix图像翻译所要实现的功能将Y类样本翻译为X类,同样也可以用X类翻译为Y类样本,此时功能便是将风格图片还原为原图,如图2-21所示。
图2-21 Pix2Pix数据集样式
在导入模型前将图片分别保存到X数组和Y数组中,相关代码如下:
2.定义辅助函数
定义一些常量、网络张量、辅助函数,相关代码如下:
3.定义生成器和判别器
判别器:将X和Y按通道拼接,经过多次卷积后得到30*30*1的判别图,即PatchGAN的思想。生成器:Unet前后两部分各包含8层卷积,且后半部分的前3层卷积使用丢弃,它在训练过程中以一定概率随机去掉一些神经元,起到防止过拟合的作用,相关代码如下:
4.定义损失函数和优化器
根据生成器和判别器的损失定义函数:
判别器D的参数θd,损失关于参数的梯度为:+ln(1-D(G(z(i))))],生成器G的参数θg,损失关于参数的梯度为:,相应的总损失函数为:。