基于PaddlePaddle的强化学习算法DCGAN

原创
2020/04/20 15:59
阅读数 981

简介

生成对抗网络(Generative Adversarial Network[1], 简称GAN) 是一种非监督学习的方式,通过让两个神经网络相互博弈的方法进行学习,该方法由lan Goodfellow等人在2014年提出。生成对抗网络由一个生成网络和一个判别网络组成,生成网络从潜在的空间(latent space)中随机采样作为输入,其输出结果需要尽量模仿训练集中的真实样本。判别网络的输入为真实样本或生成网络的输出,其目的是将生成网络的输出从真实样本中尽可能的分辨出来。而生成网络则尽可能的欺骗判别网络,两个网络相互对抗,不断调整参数。 生成对抗网络常用于生成以假乱真的图片。此外,该方法还被用于生成影片,三维物体模型等。

DCGAN是将CNN与GAN的一种结合。 其将卷积网络引入到生成式模型当中来做无监督的训练,利用卷积网络强大的特征提取能力来提高生成网络的学习效果,将GAN和卷积网络结合起来,可以解决GAN训练不稳定的问题,利用卷积神经网络作为网络结构进行图像生成,可以得到更加丰富的层次表达。DCGAN的贡献就在于:为CNN的网络拓扑结构设置了一系列的限制来使得它可以稳定的训练; 使用得到的特征表示来进行图像分类,得到比较好的效果来验证生成的图像特征表示的表达能力; 对GAN学习到的filter进行了定性的分析; 展示了生成的特征表示的向量计算特性。

DCGAN的原理和GAN对抗生成是一样的。它只是把GAN的G和D换成了两个卷积神经网络(CNN)。但不是直接换就可以了,DCGAN对卷积神经网络的结构做了一些改变,以提高样本的质量和收敛的速度,这些改变有:

1.取消所有pooling层。G网络中使用转置卷积(transposed convolutional layer)进行上采样,D网络中用加入stride的卷积代替pooling。

  1. 除了生成器模型的输出层和判别器模型的输入层,在网络其它层上都使用了Batch Normalization,使用BN可以稳定学习,有助于处理初始化不良导致的训练问题。

  2. 去掉全连接层,使网络变为全卷积网络

  3. G网络中使用ReLU作为激活函数,最后一层使用tanh

  4. D网络中使用LeakyReLU作为激活函数

DCGAN的generator网络结构:

DCGAN训练10轮的模型预测效果如图3所示:

阅读本项目之前建议阅读原版论文https://arxiv.org/abs/1511.06434 ,优秀解读博客https://blog.csdn.net/stdcoutzyx/article/details/53872121

下载安装命令

## CPU版本安装命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/cpu paddlepaddle

## GPU版本安装命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/gpu paddlepaddle-gpu

本项目采用mnist数据集

In[3]
# 代码结构
# ├── network.py   # 定义基础生成网络和判别网络。
# ├── utility.py   # 定义通用工具方法。
# ├── dc_gan.py    # DCGAN训练脚本。
# ├── infer.py    # 预测脚本。
# ├── reader.py    # 数据读取脚本。
#freeze_model中保存固化的模型。
In[1]
# 训练过程中,每隔固定的训练轮数,会取一个batch的数据进行测试,测试结果以图片的形式保存至--output选项指定的路径。
# 执行python dc_gan.py --help可查看更多使用方式和参数详细说明。
#在GPU上训练CGAN,测试结果以图片的形式保存至--output选项指定的路径。
!python dc_gan/dc_gan.py --epoch=3 --output="./DC_result" --use_gpu=True
2020-03-06 15:58:50,529-INFO: font search path ['/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/afm', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/pdfcorefonts']
2020-03-06 15:58:50,840-INFO: generated new fontManager
-----------  Configuration Arguments -----------
batch_size: 128
epoch: 3
output: ./DC_result
run_ce: False
use_gpu: 1
------------------------------------------------
W0306 15:58:52.030972    91 device_context.cc:237] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 10.1, Runtime API Version: 9.0
W0306 15:58:52.039597    91 device_context.cc:245] device: 0, cuDNN Version: 7.3.
Epoch ID=0 Batch ID=0 D-Loss=1.439273476600647 DG-Loss=0.5234684944152832
 gen=[0.0035710295, -0.055501763, 0.08866583, -0.0062621743, 0.012583672]
Epoch ID=0 Batch ID=50 D-Loss=1.2131903171539307 DG-Loss=0.6392121315002441
 gen=[0.074199796, -0.98728657, 0.96557516, -0.18577422, 0.38847187]
Epoch ID=0 Batch ID=100 D-Loss=1.199127435684204 DG-Loss=0.6538589596748352
 gen=[-0.15864794, -0.9990916, 0.97238547, -0.49475494, 0.13215934]
Epoch ID=0 Batch ID=150 D-Loss=1.1499663591384888 DG-Loss=0.6774157285690308
 gen=[-0.41129252, -0.99988854, 0.99251664, -0.74249625, -0.13403124]
Epoch ID=0 Batch ID=200 D-Loss=1.1407439708709717 DG-Loss=0.6751459240913391
 gen=[-0.41557506, -0.9999451, 0.9985463, -0.71916705, -0.20394036]
Epoch ID=0 Batch ID=250 D-Loss=1.1618539094924927 DG-Loss=0.6718940138816833
 gen=[-0.51218396, -0.99997926, 0.9997986, -0.85931814, -0.3323178]
Epoch ID=0 Batch ID=300 D-Loss=1.1419421434402466 DG-Loss=0.6702679395675659
 gen=[-0.60985285, -1.0, 0.9999937, -0.9458067, -0.52866304]
Epoch ID=0 Batch ID=350 D-Loss=1.1823350191116333 DG-Loss=0.6736178398132324
 gen=[-0.69451416, -1.0, 0.99999994, -0.98746, -0.74528503]
Epoch ID=0 Batch ID=400 D-Loss=1.1754354238510132 DG-Loss=0.6745465397834778
 gen=[-0.6880498, -1.0, 1.0, -0.99424547, -0.8043335]
Epoch ID=0 Batch ID=450 D-Loss=1.176548957824707 DG-Loss=0.6788198947906494
 gen=[-0.74726224, -1.0, 1.0, -0.9987327, -0.9204884]
Epoch ID=1 Batch ID=0 D-Loss=1.1447052955627441 DG-Loss=0.6708738803863525
 gen=[-0.7079931, -1.0, 1.0, -0.9981714, -0.88169056]
Epoch ID=1 Batch ID=50 D-Loss=1.1636065244674683 DG-Loss=0.6714401245117188
 gen=[-0.70650625, -1.0, 1.0, -0.9992901, -0.8929729]
Epoch ID=1 Batch ID=100 D-Loss=1.1782646179199219 DG-Loss=0.6642638444900513
 gen=[-0.6976415, -1.0, 1.0, -0.99960464, -0.8931533]
Epoch ID=1 Batch ID=150 D-Loss=1.1710419654846191 DG-Loss=0.6807173490524292
 gen=[-0.75162506, -1.0, 1.0, -0.9999121, -0.94781107]
Epoch ID=1 Batch ID=200 D-Loss=1.1481549739837646 DG-Loss=0.6630681753158569
 gen=[-0.69867957, -1.0, 1.0, -0.99988514, -0.9130186]
Epoch ID=1 Batch ID=250 D-Loss=1.1682077646255493 DG-Loss=0.6757701635360718
 gen=[-0.8018313, -1.0, 0.9999999, -0.9999951, -0.9871076]
Epoch ID=1 Batch ID=300 D-Loss=1.148134469985962 DG-Loss=0.663914680480957
 gen=[-0.7266686, -1.0, 0.99999964, -0.9999672, -0.9503533]
Epoch ID=1 Batch ID=350 D-Loss=1.1521973609924316 DG-Loss=0.676605224609375
 gen=[-0.70590776, -1.0, 0.99999994, -0.99997187, -0.9425734]
Epoch ID=1 Batch ID=400 D-Loss=1.1882412433624268 DG-Loss=0.6716219782829285
 gen=[-0.7665794, -1.0, 0.99999964, -0.9999983, -0.98480856]
Epoch ID=1 Batch ID=450 D-Loss=1.177868127822876 DG-Loss=0.6702796220779419
 gen=[-0.76936704, -1.0, 0.9999978, -0.9999989, -0.9852225]
Epoch ID=2 Batch ID=0 D-Loss=1.177512764930725 DG-Loss=0.6751676797866821
 gen=[-0.743871, -1.0, 0.9999992, -0.99999785, -0.97568905]
Epoch ID=2 Batch ID=50 D-Loss=1.1480481624603271 DG-Loss=0.666557788848877
 gen=[-0.7171929, -1.0, 0.99999934, -0.99999666, -0.9606299]
Epoch ID=2 Batch ID=100 D-Loss=1.1683597564697266 DG-Loss=0.6695414781570435
 gen=[-0.7509831, -1.0, 0.99999946, -0.99999964, -0.98408717]
Epoch ID=2 Batch ID=150 D-Loss=1.1684950590133667 DG-Loss=0.673043429851532
 gen=[-0.7485244, -1.0, 0.9999997, -0.99999976, -0.9818194]
Epoch ID=2 Batch ID=200 D-Loss=1.173353910446167 DG-Loss=0.6461285352706909
 gen=[-0.74712914, -1.0, 0.999995, -0.9999998, -0.98147106]
Epoch ID=2 Batch ID=250 D-Loss=1.2059048414230347 DG-Loss=0.6585644483566284
 gen=[-0.7696446, -1.0, 0.99998784, -0.99999994, -0.99173194]
Epoch ID=2 Batch ID=300 D-Loss=1.1470997333526611 DG-Loss=0.6197065114974976
 gen=[-0.74291414, -1.0, 0.99999666, -0.99999994, -0.9858818]
Epoch ID=2 Batch ID=350 D-Loss=1.1451077461242676 DG-Loss=0.6631172895431519
 gen=[-0.75821966, -1.0, 0.99999917, -0.99999994, -0.988426]
Epoch ID=2 Batch ID=400 D-Loss=1.2104921340942383 DG-Loss=0.6530779600143433
 gen=[-0.7679337, -1.0, 0.99999166, -1.0, -0.9923828]
Epoch ID=2 Batch ID=450 D-Loss=1.1789915561676025 DG-Loss=0.6635096073150635
 gen=[-0.7241944, -1.0, 0.9999962, -0.99999994, -0.97801983]
In[2]
#利用固化的模型(训练3轮)进行预测,batch_size为生成图片张数,结果保存在output中
!python dc_gan/infer.py --output="./infer_result" --batch_size=16 --use_gpu=False

# 可视化预测效果
%matplotlib inline
import matplotlib.pyplot as plt  
import numpy as np
import cv2

img= cv2.imread('infer_result/generated_image.png')
plt.imshow(img)
plt.show()
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/executor.py:804: UserWarning: There are no operators in the program to be executed. If you pass Program manually, please use fluid.program_guard to ensure the current Program is being used.
  warnings.warn(error_info)

 点击链接,使用AI Studio一键上手实践项目吧:https://aistudio.baidu.com/aistudio/projectdetail/169449

下载安装命令

## CPU版本安装命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/cpu paddlepaddle

## GPU版本安装命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/gpu paddlepaddle-gpu

>> 访问 PaddlePaddle 官网,了解更多相关内容

展开阅读全文
打赏
0
0 收藏
分享
加载中
更多评论
打赏
0 评论
0 收藏
0
分享
返回顶部
顶部