文档章节

在浏览器中进行深度学习:TensorFlow.js (八)生成对抗网络 (GAN)

naughty
 naughty
发布于 10/18 03:05
字数 1398
阅读 62
收藏 0

Generative Adversarial Network 是深度学习中非常有趣的一种方法。GAN最早源自Ian Goodfellow的这篇论文LeCun对GAN给出了极高的评价:

“There are many interesting recent development in deep learning…The most important one, in my opinion, is adversarial training (also called GAN for Generative Adversarial Networks). This, and the variations that are now being proposed is the most interesting idea in the last 10 years in ML, in my opinion.” – Yann LeCun

那么我们就看看GAN究竟是怎么回事吧:

如上图所示,GAN包含两个互相对抗的网络:G(Generator)和D(Discriminator)。正如它的名字所暗示的那样,它们的功能分别是:

  • Generator是一个生成器的网络,它接收一个随机的噪声,通过这个噪声生成图片,记做G(z)。
  • Discriminator是一个鉴别器网络,判别一张图片或者一个输入是不是“真实的”。它的输入x是数据或者图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。

在训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量把G生成的图片和真实的图片分别开来。这样,G和D构成了一个动态的“博弈过程”。在最理想的状态下,G可以生成足以“以假乱真”的图片G(z)。对于D来说,它难以判定G生成的图片究竟是不是真实的,因此D(G(z)) = 0.5。

最后,我们就可以使用生成器和随机输入来生成不同的数据或者图片了。

上面的描述大家可能都能理解,但是把它变成数学语言,可能你就蒙B了。

“GAN的核心原理”的图片搜索结果

如上图所示,x是输入,z是随机噪声。D(x)是鉴别器的判定数据为真的概率,D(G(z))是判定生成数据为真的概率。生成器希望这个D(G(z))越大越好,这个时候整个表达式的值应该变小。而鉴别器的目的是能够有效区分真实数据和假数据,所以D(x)应该趋向于变大,D(G(z))趋向于变小,整个表达式就变大。也就是说训练过程,生成器和辨别器互相对抗,一个使上述表达式变小,另一个使其变大,最后训练趋向于平衡,而生成器这时候应该生成真假难辨的数据,这就是我们的最终目的。

上图是GAN算法训练的具体过程,这里我们不做过多的解释,直接运行一个例子。

“GAN”的图片搜索结果

我们用MINST数据集来看看如何使用TensorflowJS来训练一个GAN,模拟生成手写数字。

代码见我的codepen

function gen(xs) {
  const l1 = tf.leakyRelu(xs.matMul(G1w).add(G1b));
  const l2 = tf.leakyRelu(l1.matMul(G2w).add(G2b));
  const l3 = tf.tanh(l2.matMul(G3w).add(G3b));
  return l3;
}

function disReal(xs) {
  const l1 = tf.leakyRelu(xs.matMul(D1w).add(D1b));
  const l2 = tf.leakyRelu(l1.matMul(D2w).add(D2b));
  const logits = l2.matMul(D3w).add(D3b);
  const output = tf.sigmoid(logits);
  return [logits, output];
}

function disFake(xs) {
  return disReal(gen(xs));
}

GAN的两个网络分别用gen和disReal创建。gen是生成器网络,disReal是辨别器的网络。disFake是把生成数据用辨别器来辨别。这里的网络使用leakyrelu。使得输出在-inf到+inf,利用sigmoid映射到【0,1】,这是辨别器模型输出一个0-1之间的概率。

“leaky relu”的图片搜索结果

 

通常我们会创建一个比生成器更复杂的鉴别器网络使得鉴别器有足够的分辨能力。但在这个例子里,两个网络的复杂程度类似。

计算损失的函数使用 tf.sigmoidCrossEntropyWithLogits,值得注意的是,在最新的0.13版本中,这个交叉熵被移除了,你需要自己实现该方法。

训练过程如下:

async function trainBatch(realBatch, fakeBatch) {
  const dcost = dOptimizer.minimize(
    () => {
      const [logitsReal, outputReal] = disReal(realBatch);
      const [logitsFake, outputFake] = disFake(fakeBatch);

      const lossReal = sigmoidCrossEntropyWithLogits(ONES_PRIME, logitsReal);
      const lossFake = sigmoidCrossEntropyWithLogits(ZEROS, logitsFake);
      return lossReal.add(lossFake).mean();
    },
    true,
    [D1w, D1b, D2w, D2b, D3w, D3b]
  );
  await tf.nextFrame();
  const gcost = gOptimizer.minimize(
    () => {
      const [logitsFake, outputFake] = disFake(fakeBatch);

      const lossFake = sigmoidCrossEntropyWithLogits(ONES, logitsFake);
      return lossFake.mean();
    },
    true,
    [G1w, G1b, G2w, G2b, G3w, G3b]
  );
  await tf.nextFrame();

  return [dcost, gcost];
}

训练使用了两个optimizer,

  1. 第一步,计算实际数据的辨别结果和1的交叉熵,以及生成器生成数据的辨别结果和0的交叉熵。也就是说,我们希望辨别器尽可能的判断出生成数据都是假的而实际数据都是真的。使得这两个交叉熵的均值最小。
  2. 第二步开始对抗,要让生成数据尽可能被判别为真。

下图是某个训练过程的损失:

这个是经过1000个迭代后的生成图:

大家可以尝试调整学习率,增加网络复杂度,加大迭代次数来获得更好的生成模型。

GAN的学习其实还是比较复杂的,参数和损失选择都不容易,好在有一些现成的工具可以使用,另外推荐大家去https://poloclub.github.io/ganlab/,提供了很直观的GAN学习的过程。这个也是用TensorflowJS来实现的。

参考:

© 著作权归作者所有

共有 人打赏支持
naughty
粉丝 264
博文 62
码字总数 112619
作品 0
其它
架构师
私信 提问
GAN要取代深度学习了?请不要慌!

计算机视觉顶会盛会CVPR 2018召开在即,从官方现在接收的论文类型来看,这届会议展现出了一个奇怪的现象:生成对抗网络GAN,正在成为新的“深度学习”。MMP,深度学习还没学会,难道我又要被...

【方向】
06/09
0
0
深度学习(五十四)图片翻译WGAN实验测试

版权声明:本文为博主原创文章,欢迎转载,转载请注明原文地址、作者信息。 https://blog.csdn.net/hjimce/article/details/60346089 图片翻译WGAN实验测试 博客:http://blog.csdn.net/hjim...

hjimce
2017/03/04
0
0
【Python】利用GAN生成MNIST数据集

本文转载至知乎ID:Charles(白露未晞)知乎个人专栏 导语 利用Python搭建简单的GAN网络来生成MNIST数据集。其中GAN,即生成对抗网络。 英文全称: Generative Adversarial Networks 偷闲入门...

W3Cschool小编
08/07
0
0
用Keras搭建GAN:图像去模糊中的应用(附代码)

雷锋网(公众号:雷锋网)按:本文为雷锋字幕组编译的技术博客,原标题GAN with Keras: Application to Image Deblurring,作者为Raphaël Meudec。 翻译 | 廖颖 陈俊雅 整理 | 凡江 2014年 Ia...

雷锋字幕组
04/25
0
0
让机器“析毫剖厘”:图像理解与编辑|VALSE2018之三

编者按:李白在《秋登宣城谢脁北楼》中曾写道: “江城如画里,山晓望晴空。 两水夹明镜,双桥落彩虹。” 通过对视野内景物位置关系的描写,一幅登高远眺的秋色美景图宛在眼前。而在计算机视...

xwukefr2tnh4
05/09
0
0

没有更多内容

加载失败,请刷新页面

加载更多

利用cefSharp实现网页自动注册登录的需要注册的一些事项

最近朋友有个需要自动注册登录点击的事,我帮着写了写,好久没写过这东西了,在写的过程中总结了需要注意的一些事项。 一、换IP之后要测试一下速度,我目前用的最简单的测试方法就是20-30秒加...

我退而结网
15分钟前
1
0
Go语言中使用 BoltDB数据库

boltdb 是使用Go语言编写的开源的键值对数据库,Github的地址如下: https://github.com/boltdb/bolt boltdb 存储数据时 key 和 value 都要求是字节数据,此处需要使用到 序列化和反序列化。...

Oo若离oO
16分钟前
1
0
zookeeper分布式锁

//lock 锁 定义分布式锁public interface Lock {//获取锁public void getLock();//释放锁public void unLock();} public abstract class ZookeeperAbstractLock implements Loc......

熊猫你好
23分钟前
0
0
mysql_事务隔离机制

事务隔离机制 事务就是要保证一组数据库操作,要么全部成功,要么全部失败。在mysql中,事务支持是在引擎层实现的。mysql是一个支持多引擎的系统,但并不是所有引擎都支持事务,比如mysql...

grace_233
25分钟前
0
0
不学无数——Java中IO和NIO

JAVA中的I/O和NIO I/O 问题是任何编程语言都无法回避的问题,可以说 I/O 问题是整个人机交互的核心问题,因为 I/O 是机器获取和交换信息的主要渠道。在当今这个数据大爆炸时代,I/O 问题尤其...

不学无数的程序员
31分钟前
0
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部