文档章节

PyTorch快速入门教程十(GANs以及对抗网络)

earnpls
 earnpls
发布于 2017/07/04 22:44
字数 2678
阅读 104
收藏 0


GANs

GANs的全称叫做生成对抗网络,根据这个名字,你就可以猜测这个网络是由两部分组成的,第一部分是生成,第二部分是对抗。那么你已经基本猜对了,这个网络第一部分是生成网络,第二部分对抗模型严格来讲是一个判别器,简单来说呢,就是让两个网络相互竞争,生成网络来生成假的数据,对抗网络通过判别器去判别真伪,最后希望生成器生成的数据能够以假乱真。

可以用这个图来简单的看一看这两个过程。

下面我们就来依次介绍。

Discriminator Network

首先我们来讲一下对抗过程,因为这个过程更加简单。

对抗过程简单来说就是一个判断真假的判别器,相当于一个二分类问题,我们输入一张真的图片希望判别器输出的结果是1,输入一张假的图片希望判别器输出的结果是0。这其实已经和原图片的label没有关系了,不管原图片到底是一个多少类别的图片,他们都统一称为真的图片,label是1表示真实的;而生成的假的图片的label是0表示假的。

我们训练的过程就是希望这个判别器能够正确的判出真的图片和假的图片,这其实就是一个简单的二分类问题,对于这个问题可以用我们前面讲过的很多方法去处理,比如logistic回归,深层网络,卷积神经网络,循环神经网络都可以。

Generative Network

接着我们要看看如何生成一张假的图片。首先给出一个简单的高维的正态分布的噪声向量,如上图所示的D-dimensional noise vector,这个时候我们可以通过仿射变换,也就是xw+b将其映射到一个更高的维度,然后将他重新排列成一个矩形,这样看着更像一张图片,接着进行一些卷积、池化、激活函数处理,最后得到了一个与我们输入图片大小一模一样的噪音矩阵,这就是我们所说的假的图片,这个时候我们如何去训练这个生成器呢?就是通过判别器来得到结果,然后希望增大判别器判别这个结果为真的概率,在这一步我们不会更新判别器的参数,只会更新生成器的参数。

如下图所示 Generative Network

以上的过程已经简单的阐述了生成对抗网络的学习过程,如果仍然不太清楚这个过程,下面我们会通过代码来更清晰地展示整个过程。

代码

我们会使用mnist手写数字来做数据集,通过生成对抗网络我们希望生成一些“以假乱真”的手写字体。为了加快训练过程,我们不使用卷积网络来做判别器,我们使用简单的多层网络来进行判别。

Discriminator Network

class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.dis(x)
        return x

以上这个网络是一个简单的多层神经网络,将图片28x28展开成784,然后通过多层感知器,中间经过斜率设置为0.2的LeakyReLU激活函数,最后接sigmoid激活函数得到一个0到1之间的概率进行二分类。之所以使用LeakyRelu而不是用ReLU激活函数是因为经过实验LeakyReLU的表现更好。

Generative Network

class generator(nn.Module):
    def __init__(self, input_size):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 784),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.gen(x)
        return x

输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维,然后通过LeakyReLU激活函数,接着进行一个线性变换,再经过一个LeakyReLU激活函数,然后经过线性变换将其变成784维,最后经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间。

Discriminator Train

判别器的训练由两部分组成,第一部分是真的图像判别为真,第二部分是假的图片判别为假,在这两个过程中,生成器的参数不参与更新。

首先我们需要定义loss的度量方式和优化函数,loss度量使用二分类的交叉熵,油画函数注意使用的学习率是0.0003

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

接着进入训练

img = img.view(num_img, -1)  # 将图片展开乘28x28=784
real_img = Variable(img).cuda()  # 将tensor变成Variable放入计算图中
real_label = Variable(torch.ones(num_img)).cuda()  # 定义真实label为1
fake_label = Variable(torch.zeros(num_img)).cuda()  # 定义假的label为0

# compute loss of real_img
real_out = D(real_img)  # 将真实的图片放入判别器中
d_loss_real = criterion(real_out, real_label)  # 得到真实图片的loss  
real_scores = real_out  # 真实图片放入判别器输出越接近1越好

# compute loss of fake_img
z = Variable(torch.randn(num_img, z_dimension)).cuda()  # 随机生成一些噪声
fake_img = G(z)  # 放入生成网络生成一张假的图片
fake_out = D(fake_img)  # 判别器判断假的图片
d_loss_fake = criterion(fake_out, fake_label)  # 得到假的图片的loss
fake_scores = fake_out  # 假的图片放入判别器越接近0越好

# bp and optimize
d_loss = d_loss_real + d_loss_fake  # 将真假图片的loss加起来
d_optimizer.zero_grad()  # 归0梯度
d_loss.backward()  # 反向传播
d_optimizer.step()  # 更新参数

我已经把每一步都注释在了代码上,这样更加便于大家阅读,这是一个判别器的训练过程,我们希望判别器能够正确辨别出真假图片。

Generative Train

在生成网络的训练中,我们希望生成一张假的图片,然后经过判别器之后希望他能够判断为真的图片,在这个过程中,我们将判别器固定,将假的图片传入判别器的结果与真实label对应,反向传播更新的参数是生成网络里面的参数,这样我们就可以通过跟新生成网络里面的参数来使得判别器判断生成的假的图片为真,这样就达到了生成对抗的作用。

# compute loss of fake_img
z = Variable(torch.randn(num_img, z_dimension)).cuda()  # 得到随机噪声
fake_img = G(z)  # 生成假的图片
output = D(fake_img)  # 经过判别器得到结果
g_loss = criterion(output, real_label)  # 得到假的图片与真实图片label的loss

# bp and optimize
g_optimizer.zero_grad()  # 归0梯度
g_loss.backward()  # 反向传播
g_optimizer.step()  # 更新生成网络的参数

这样我们就写好了一个简单的生成网络,通过不断地训练我们希望能够生成很真的图片。

训练结果

通过不断训练,我们可以得到下面的图片

这是真实图片

Generative Train

第1幅为第一次生成的噪声图片,之后分别是跑完15次生成的图片,跑完30次,跑完50次,跑完70次,最后一个是跑完100次生成的图片

Generative Train

Generative Train

怎么样,是不是特别神奇,我们居然可以生成一副看着很真的图片,这里我们只是用了简单的多层感知器来生成和判别模型,我们可以用更复杂的卷积神经网络来做同样的事情,代码将和本文的代码放在一起,有兴趣的同学可以自己去看看,然后放几张卷积网络生成的图片

Generative Train

可以发现产生的噪声更少了,训练也更加稳定,主要是里面引入了Batchnormalization,另外gan的训练过程是特别困难的,两个对偶网络相互学习,这个时候有一些训练技巧可以使得训练生成更加稳定,详细见一下github

最后我们来说一下为何Gans能够成为最近20年来机器学习以及深度学习界革命性的发现。这是因为不管是深度学习还是机器学习仍然很大一部分是监督学习,但是创建这么多有label的数据集所需要的人力物力是极大的,同时遇到的新的任务时我们很容易得到原始的没有label的数据集,这是我们需要花大量的时间去给其标定label,所以很多人都认为无监督学习才是机器学习的未来,这个时候Gans的出现为无监督学习提供了有力的支持,这当然引起了学界的大量关注,同时基于Gans的应用也越来越多,业界对其也非常狂热。

最后引用Yan Lecun的话:”它(Gans)为创建无监督学习模型提供了强有力的算法框架,有望帮助我们为 AI 加入常识(common sense)。我们认为,沿着这条路走下去,有不小的成功机会能开发出更智慧的 AI 。”

以上我们简单的介绍了Gans,通过网络实现了手写字体的生成,当然还有更多的变形和应用,有兴趣的同学可以自己阅读相关论文深入了解。

在这里,我整理发布了Pytorch中文文档,方便大家查询使用,同时也准备了中文论坛,欢迎大家学习交流!

Pytorch中文文档

Pytorch中文论坛

Pytorch中文文档已经发布,完美翻译,更加方便大家浏览:

 

Pytorch中文网:https://ptorch.com/

Pytorch中文文档:https://ptorch.com/docs/1/

本文转载自:https://ptorch.com/news/14.html

共有 人打赏支持
earnpls
粉丝 5
博文 26
码字总数 74
作品 0
昌平
程序员
PyTorch 你想知道的都在这里

本文转载地址,并进行了加工。本文适用于深度学习新手的“入门指导系列”,也有适用于老司机的论文代码实现,包括 Attention Based CNN、A3C、WGAN、BERT等等。所有代码均按照所属技术领域分...

readilen
前天
0
0
PyTorch:60分钟入门学习

最近在学习PyTorch这个深度学习框架,在这里做一下整理分享给大家,有什么写的不对或者不好的地方,还请大侠们见谅啦~~~ 写在前面 本文就是主要是对PyTorch的安装,以及入门学习做了记录,...

与阳光共进早餐
01/15
0
0
这些资源你肯定需要!超全的GAN PyTorch+Keras实现集合

  选自GitHub   作者:eriklindernoren   机器之心编译   参与:刘晓坤、思源、李泽南      生成对抗网络一直是非常美妙且高效的方法,自 14 年 Ian Goodfellow 等人提出第一个生...

机器之心
04/24
0
0
Keras vs PyTorch:谁是「第一」深度学习框架?

  选自Deepsense.ai   作者:Rafa Jakubanis、Piotr Migdal   机器之心编译   参与:路、李泽南、李亚洲      「第一个深度学习框架该怎么选」对于初学者而言一直是个头疼的问题...

机器之心
06/30
0
0
教程 | PyTorch经验指南:技巧与陷阱

  选自GitHub   作者:Kaixhin   机器之心编译      PyTorch 的构建者表明,PyTorch 的哲学是解决当务之急,也就是说即时构建和运行计算图。目前,PyTorch 也已经借助这种即时运行...

机器之心
07/30
0
0

没有更多内容

加载失败,请刷新页面

加载更多

驼峰变量名的转换

package com.mmall.test;import java.util.regex.Matcher;import java.util.regex.Pattern;/** * 需求:1. 将字符串 user_name_abc 转换为 userNameAbc * 2. 将字符串 us......

蚂蚁-Declan
29分钟前
5
0
HTTP请求方法

根据HTTP标准,HTTP请求可以使用多种请求方法。 HTTP1.0定义了三种请求方法: GET, POST 和 HEAD方法。 HTTP1.1新增了五种请求方法:OPTIONS, PUT, DELETE, TRACE 和 CONNECT 方法。 序号 方...

踏破铁鞋无觅处
33分钟前
2
0
知识点043-selenium自动化测试网页工具的使用

【摘要】 Selenium是一个主要用于Web应用自动化测试的工具集合。但其作用不仅仅局限于测试领域,还可以用于浏览器行为模拟以及屏幕抓取等,在行业内有着广泛的应用。Selenium支持主流的浏览器...

侠客行之石头
40分钟前
1
0
B250F I219V安装windows server 网卡驱动

https://blog.csdn.net/ryu2003/article/details/50855146

梦想游戏人
40分钟前
1
0
MacOS Install Docker

使用 Homebrew 安装 macOS 我们可以使用 Homebrew 来安装 Docker。 Homebrew 的 Cask 已经支持 Docker for Mac,因此可以很方便的使用 Homebrew Cask 来进行安装: $ brew cask install dock...

Linux就该这么学
40分钟前
1
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部