文档章节

学习笔记TF051:生成式对抗网络

利炳根
 利炳根
发布于 2017/08/24 02:48
字数 1763
阅读 31
收藏 0
点赞 0
评论 0

生成式对抗网络(gennerative adversarial network,GAN),谷歌2014年提出网络模型。灵感自二人博弈的零和博弈,目前最火的非监督深度学习。GAN之父,Ian J.Goodfellow,公认人工智能顶级专家。

原理。 生成式对搞网络包含一个生成模型(generative model,G)和一个判别模型(discriminative model,D)。Ian J.Goodfellow、Jean Pouget-Abadie、Mehdi Mirza、Bing Xu、David Warde-Farley、Sherjil Ozair、Aaron Courville、Yoshua Bengio论文,《Generative Adversarial Network》,https://arxiv.org/abs/1406.2661 。 生成式对抗网络结构: 噪声数据->生成模型->假图片---| |->判别模型->真/假 打乱训练数据->训练集->真图片-| 生成式对抗网络主要解决如何从训练样本中学习出新样本。生成模型负责训练出样本的分布,如果训练样本是图片就生成相似的图片,如果训练样本是文章名子就生成相似的文章名子。判别模型是一个二分类器,用来判断输入样本是真实数据还是训练生成的样本。 生成式对抗网络优化,是一个二元极小极大博弈(minimax two-player game)问题。使生成模型输出在输入给判别模型时,判断模型秀难判断是真实数据还是虚似数据。训练好的生成模型,能把一个噪声向量转化成和训练集类似的样本。Argustus Odena、Christopher Olah、Jonathon Shlens论文《Coditional Image Synthesis with Auxiliary Classifier GANs》。 辅助分类器生成式对抗网络(auxiliary classifier GAN,AC-GAN)实现。

生成式对抗网络应用。生成数字,生成人脸图像。

生成式对抗网络实现。https://github.com/fchollet/keras/blob/master/examples/mnist_acgan.py 。 Augustus Odena、Chistopher Olah和Jonathon Shlens 论文《Conditional Image Synthesis With Auxiliary Classifier GANs》。 通过噪声,让生成模型G生成虚假数据,和真实数据一起送到判别模型D,判别模型一方面输出数据真/假,一方面输出图片分类。 首先定义生成模型,目的是生成一对(z,L)数据,z是噪声向量,L是(1,28,28)的图像空间。

def build_generator(latent_size):
    cnn = Sequential()
    cnn.add(Dense(1024, input_dim=latent_size, activation='relu'))
    cnn.add(Dense(128 * 7 * 7, activation='relu'))
    cnn.add(Reshape((128, 7, 7)))
    #上采样,图你尺寸变为 14X14
    cnn.add(UpSampling2D(size=(2,2)))
    cnn.add(Convolution2D(256, 5, 5, border_mode='same', activation='relu', init='glorot_normal'))
    #上采样,图像尺寸变为28X28
    cnn.add(UpSampling2D(size=(2,2)))
    cnn.add(Convolution2D(128, 5, 5, border_mode='same', activation='relu', init='glorot_normal'))
    #规约到1个通道
    cnn.add(Convolution2D(1, 2, 2, border_mode='same', activation='tanh', init='glorot_normal'))
    #生成模型输入层,特征向量
    latent = Input(shape=(latent_size, ))
    #生成模型输入层,标记
    image_class = Input(shape=(1,), dtype='int32')
    cls = Flatten()(Embedding(10, latent_size, init='glorot_normal')(image_class))
    h = merge([latent, cls], mode='mul')
    fake_image = cnn(h) #输出虚假图片
    return Model(input=[latent, image_class], output=fake_image)

定义判别模型,输入(1,28,28)图片,输出两个值,一个是判别模型认为这张图片是否是虚假图片,另一个是判别模型认为这第图片所属分类。

def build_discriminator();
    #采用激活函数Leaky ReLU来替换标准的卷积神经网络中的激活函数
    cnn = Wequential()
    cnn.add(Convolution2D(32, 3, 3, border_mode='same', subsample=(2, 2), input_shape=(1, 28, 28)))
    cnn.add(LeakyReLU())
    cnn.add(Dropout(0.3))
    cnn.add(Convolution2D(64, 3, 3, border_mode='same', subsample=(1, 1)))
    cnn.add(LeakyReLU())
    cnn.add(Dropout(0.3))
    cnn.add(Convolution2D(128, 3, 3, border_mode='same', subsample=(1, 1)))
    cnn.add(LeakyReLU())
    cnn.add(Dropout(0.3))
    cnn.add(Convolution2D(256, 3, 3, border_mode='same', subsample=(1, 1)))
    cnn.add(LeakyReLU())
    cnn.add(Dropout(0.3))
    cnn.add(Flatten())
    image = Input(shape=(1, 28, 28))
    features = cnn(image)
    #有两个输出
    #输出真假值,范围在0~1
    fake = Dense(1, activation='sigmoid',name='generation')(features)
    #辅助分类器,输出图片分类
   aux = Dense(10, activation='softmax', name='auxiliary')(features)
    return Model(input=image, output=[fake, aux])

训练过程,50轮(epoch),把权重保存,每轮把虚假数据生成图处保存,观察虚假数据演化过程。

if __name__ =='__main__':
    #定义超参数
    nb_epochs = 50
    batch_size = 100
    latent_size = 100
    #优化器学习率
    adam_lr = 0.0002
    adam_beta_l = 0.5
    #构建判别网络
    discriminator = build_discriminator()
    discriminator.compile(optimizer=adam(lr=adam_lr, beta_l=adam_beta_l), loss='binary_crossentropy')
    latent = Input(shape=(lastent_size, ))
    image_class = Input(shape-(1, ), dtype='int32')
    #生成组合模型
    discriminator.trainable = False
    fake, aux = discriminator(fake)
    combined = Model(input=[latent, image_class], output=[fake, aux])
    combined.compile(optimizer=Adam(lr=adam_lr, beta_l=adam_beta_1), loss=['binary_crossentropy', 'sparse_categorical_crossentropy'])
    #将mnist数据转化为(...,1,28,28)维度,取值范围为[-1,1]
    (X_train,y_train),(X_test,y_test) = mnist.load_data()
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=1)
    X_test = (X_test.astype(np.float32) - 127.5) / 127.5
    X_test = np.expand_dims(X_test, axis=1)
    num_train, num_test = X_train.shape[0], X_test.shape[0]
    train_history = defaultdict(list)
    test_history = defaultdict(list)
    for epoch in range(epochs):
        print('Epoch {} of {}'.format(epoch + 1, epochs))
        num_batches = int(X_train.shape[0] / batch_size)
        progress_bar = Progbar(target=num_batches)
        epoch_gen_loss = []
        epoch_disc_loss = []
        for index in range(num_batches):
            progress_bar.update(index)
            #产生一个批次的噪声数据
            noise = np.random.uniform(-1, 1, (batch_size, latent_size))
            # 获取一个批次的真实数据
            image_batch = X_train[index * batch_size:(index + 1) * batch_size]
            label_batch = y_train[index * batch_size:(index + 1) * batch_size]
            # 生成一些噪声标记
            sampled_labels = np.random.randint(0, 10, batch_size)
            # 产生一个批次的虚假图片
            generated_images = generator.predict(
            [noise, sampled_labels.reshape((-1, 1))], verbose=0)
            X = np.concatenate((image_batch, generated_images))
            y = np.array([1] * batch_size + [0] * batch_size)
            aux_y = np.concatenate((label_batch, sampled_labels), axis=0)
            epoch_disc_loss.append(discriminator.train_on_batch(X, [y, aux_y]))
            # 产生两个批次噪声和标记
            noise = np.random.uniform(-1, 1, (2 * batch_size, latent_size))
            sampled_labels = np.random.randint(0, 10, 2 * batch_size)
            # 训练生成模型来欺骗判别模型,输出真/假都设为真
            trick = np.ones(2 * batch_size)
            epoch_gen_loss.append(combined.train_on_batch(
               [noise, sampled_labels.reshape((-1, 1))],
                [trick, sampled_labels]))
        print('\nTesting for epoch {}:'.format(epoch + 1))
        # 评估测试集,产生一个新批次噪声数据
        noise = np.random.uniform(-1, 1, (num_test, latent_size))
        sampled_labels = np.random.randint(0, 10, num_test)
        generated_images = generator.predict(
            [noise, sampled_labels.reshape((-1, 1))], verbose=False)
        X = np.concatenate((X_test, generated_images))
        y = np.array([1] * num_test + [0] * num_test)
        aux_y = np.concatenate((y_test, sampled_labels), axis=0)
        # 判别模型是否能判别
        discriminator_test_loss = discriminator.evaluate(
            X, [y, aux_y], verbose=False)
        discriminator_train_loss = np.mean(np.array(epoch_disc_loss), axis=0)
        # 创建两个批次新噪声数据
        noise = np.random.uniform(-1, 1, (2 * num_test, latent_size))
        sampled_labels = np.random.randint(0, 10, 2 * num_test)
        trick = np.ones(2 * num_test)
        generator_test_loss = combined.evaluate(
            [noise, sampled_labels.reshape((-1, 1))],
            [trick, sampled_labels], verbose=False)
        generator_train_loss = np.mean(np.array(epoch_gen_loss), axis=0)
        # 损失值等性能指标记录下来,并输出
        train_history['generator'].append(generator_train_loss)
        train_history['discriminator'].append(discriminator_train_loss)
        test_history['generator'].append(generator_test_loss)
        test_history['discriminator'].append(discriminator_test_loss)
        print('{0:<22s} | {1:4s} | {2:15s} | {3:5s}'.format(
            'component', *discriminator.metrics_names))
        print('-' * 65)
        ROW_FMT = '{0:<22s} | {1:<4.2f} | {2:<15.2f} | {3:<5.2f}'
        print(ROW_FMT.format('generator (train)',
                         *train_history['generator'][-1]))
        print(ROW_FMT.format('generator (test)',
                         *test_history['generator'][-1]))
        print(ROW_FMT.format('discriminator (train)',
                         *train_history['discriminator'][-1]))
        print(ROW_FMT.format('discriminator (test)',
                         *test_history['discriminator'][-1]))
        # 每个epoch保存一次权重
        generator.save_weights(
            'params_generator_epoch_{0:03d}.hdf5'.format(epoch), True)
        discriminator.save_weights(
            'params_discriminator_epoch_{0:03d}.hdf5'.format(epoch), True)
        # 生成一些可视化虚假数字看演化过程
        noise = np.random.uniform(-1, 1, (100, latent_size))
        sampled_labels = np.array([
            [i] * 10 for i in range(10)
        ]).reshape(-1, 1)
        generated_images = generator.predict(
            [noise, sampled_labels], verbose=0)
        # 整理到一个方格
        img = (np.concatenate([r.reshape(-1, 28)
                           for r in np.split(generated_images, 10)
                           ], axis=-1) * 127.5 + 127.5).astype(np.uint8)
        Image.fromarray(img).save(
            'plot_epoch_{0:03d}_generated.png'.format(epoch))
    pickle.dump({'train': train_history, 'test': test_history},
                open('acgan-history.pkl', 'wb'))

训练结束,创建3类文件。params_discriminator_epoch_{{epoch_number}}.hdf5,判别模型权重参数。params_generator_epoch_{{epoch_number}}.hdf5,生成模型权重参数。plot_epoch_{{epoch_number}}_generated.png 。

生成式对抗网络改进。生成式对抗网络(generative adversarial network,GAN)在无监督学习非常有效。常规生成式对抗网络判别器使用Sigmoid交叉熵损失函数,学习过程梯度消失。Wasserstein生成式对抗网络(Wasserstein generative adversarial network,WGAN),使用Wasserstein距离度量,而不是Jensen-Shannon散度(Jensen-Shannon divergence,JSD)。使用最小二乘生成式对抗网络(least squares generative adversarial network,LSGAN),判别模型用最小平方损失小函数(least squares loss function)。Sebastian Nowozin、Botond Cseke、Ryota Tomioka论文《f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization》。

参考资料: 《TensorFlow技术解析与实战》

欢迎付费咨询(150元每小时),我的微信:qingxingfengzi

© 著作权归作者所有

共有 人打赏支持
利炳根
粉丝 11
博文 60
码字总数 136346
作品 0
深圳
《Adversarial Multi-task Learning for Text Classification》阅读笔记

来源:ACL 2017 链接:link 转载请注明出处:学习ML的皮皮虾 神经网络模型可以通过学习共享层来提取多任务的共同特征,然而共同特征容易被任务的特定特征和其它任务的噪声干扰。本文中,用1...

王明阳
2017/12/03
0
0

hjimce算法类博文目录 个人博客:http://blog.csdn.net/hjimce 个人qq:1393852684 知乎:https://www.zhihu.com/people/huang-jin-chi-28/activities 一、深度学习 深度学习(七十)darknet...

hjimce
2016/01/24
0
0
(zhuan) 深度学习全网最全学习资料汇总之模型介绍篇

This blog from : http://weibo.com/ttarticle/p/show?id=2309351000224077630868614681&u=5070353058&m=4077873754872790&cu=5070353058 深度学习全网最全学习资料汇总之模型介绍篇 雷锋网 ......

wangxiaocvpr
2017/02/22
0
0
(转) 简述生成式对抗网络

简述生成式对抗网络 【转载请注明出处】chenrudan.github.io 本文主要阐述了对生成式对抗网络的理解,首先谈到了什么是对抗样本,以及它与对抗网络的关系,然后解释了对抗网络的每个组成部分...

wangxiaocvpr
2016/11/16
0
0
(转)能根据文字生成图片的 GAN,深度学习领域的又一新星

本文转自:https://mp.weixin.qq.com/s?biz=MzIwMTgwNjgyOQ==&mid=2247484846&idx=1&sn=c2333a9986c19e7106ae94d14a0555b9 能根据文字生成图片的 GAN,深度学习领域的又一新星 2017-01-12 D......

wangxiaocvpr
2017/01/13
0
0
深度学习:对抗网络框架,让机器在“竞争中自我成长”

  深度神经网络在判别模型领域的进步远比在生成模型领域进步快得多,其主要原因就在于相对于生成式模型来说,判别模型目标清晰、逻辑相对简单,实现起来容易。用通俗的比喻来说,判别模型相...

中国机器人
05/24
0
0
(转)【重磅】无监督学习生成式对抗网络突破,OpenAI 5大项目落地

【重磅】无监督学习生成式对抗网络突破,OpenAI 5大项目落地 【新智元导读】“生成对抗网络是切片面包发明以来最令人激动的事情!”LeCun前不久在Quroa答问时毫不加掩饰对生成对抗网络的喜爱...

wangxiaocvpr
2016/10/16
0
0
CVPR2018 | 海康、UCLA、北理联合提出3D DescriptorNet:可按条件生成3D形状,克服模式崩溃

  选自arXiv   作者:Jianwen Xie等   机器之心编译   参与:Huiyuan Zhuo、刘晓坤      近日,海康威视、UCLA、北理工联合提出了新的模型 3D DescriptorNet。该模型通过结合能量...

机器之心
04/11
0
0
业界 | Petuum提出深度生成模型统一的统计学框架

  选自Medium   作者:Zhiting Hu   机器之心编译   参与:刘晓坤、路、邹俏也      Petuum 和 CMU 合作的论文《On Unifying Deep Generative Models》提出深度生成模型的统一框...

机器之心
04/24
0
0
再读IRGAN,聊聊Code与Formulation的差异

之前,我在这篇笔记中稍微聊了一些IRGAN这个优秀的作品:Role of RL in Text Generation by GAN(强化学习在生成对抗网络文本生成中扮演的角色) 前些日子看到这样一个讨论IRGAN的问题:如何看...

胡杨
2017/10/03
0
0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

面试系列-40个Java多线程问题总结

前言 这篇文章主要是对多线程的问题进行总结的,因此罗列了40个多线程的问题。 这些多线程的问题,有些来源于各大网站、有些来源于自己的思考。可能有些问题网上有、可能有些问题对应的答案也...

Ryan-瑞恩
12分钟前
0
0
微信分享的细节

分享的缩略图要求: 一、图片大小小于32k 二、图片的尺寸为 宽度 :128px 高度:128px 分享title 和 description 出现金额等 以上情况存在会导致触发分享按钮 但是页面没有反应...

Js_Mei
18分钟前
0
0
【2018.07.23学习笔记】【linux高级知识 Shell脚本编程练习】

1、编写shell脚本,计算1-100的和; #!/bin/bashsum=0for i in `seq 1 100`do sum=$[$sum+$i]doneecho $sum 2、编写shell脚本,要求输入一个数字,然后计算出从1到输入数字的和,要求...

lgsxp
20分钟前
0
0
xss攻防浅谈

导读 XSS (Cross-Site Script) 攻击又叫跨站脚本攻击, 本质是一种注入攻击. 其原理, 简单的说就是利用各种手段把恶意代码添加到网页中, 并让受害者执行这段脚本. XSS能做用户使用浏览器能做的...

吴伟祥
20分钟前
0
0
js回调的一次应用

function hideBtn(option) { if (option == 1) { $("#addBtn").hide(); $("#addSonBtn").hide(); }}$("body").on("click", "#selectBtn", function () {......

晨猫
26分钟前
0
0
C++_读写ini配置文件

1.WritePrivateProfileString:

一个小妞
27分钟前
0
0
通往阿里,BAT的50+经典Java面试题及答案解析(上)

Java是一个支持并发、基于类和面向对象的计算机编程语言。下面列出了面向对象软件开发的优点: 代码开发模块化,更易维护和修改。 代码复用。 增强代码的可靠性和灵活性。 增加代码的可理解性...

Java大蜗牛
27分钟前
1
0
数据库两大神器【索引和锁】

前言 只有光头才能变强 索引和锁在数据库中可以说是非常重要的知识点了,在面试中也会经常会被问到的。 本文力求简单讲清每个知识点,希望大家看完能有所收获 声明:如果没有说明具体的数据库...

Java3y
31分钟前
0
0
Application Express安装

Application Express安装文档 数据库选择和安装 数据库选择 Oracle建议直接12.2.0.1.0及以上的版本,12.1存在20618595bug(具体可参见官方文档) Oracle 12c 中安装oracle application expr...

youfen
43分钟前
0
0
OpenMessaging概览

序 本文主要研究一下OpenMessaging 架构图 namespace,类似cgroup的namespace,用来进行安全隔离,每个namespace有自己的producer、consumer、topic、queue等 producer,消息生产者有两类,一...

go4it
47分钟前
0
0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

返回顶部
顶部