文档章节

如何用Deeplearning4j实现GAN

 冷血狂魔
发布于 08/14 13:10
字数 1065
阅读 920
收藏 7

一、Gan的思想

    Gan的核心所做的事情是在解决一个argminmax的问题,公式:

    1、求解一个Discriminator,可以最大尺度的丈量Generator 产生的数据和真实数据之间的分布距离

    2、求解一个Generator,可以最大程度减小产生数据和真实数据之间的距离

    gan的原始公式如下:

    实际上,我们不可能真求期望,只能sample出data来近似求解,于是,公式变成如下:

    于是,求解V的最大值,变成了一个二分类问题,变成了求交叉熵的最小值。

二、代码

public class Gan {
	static double lr = 0.01;

	public static void main(String[] args) throws Exception {

		final NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new Sgd(lr))
				.weightInit(WeightInit.XAVIER);

		final GraphBuilder graphBuilder = builder.graphBuilder().backpropType(BackpropType.Standard)
				.addInputs("input1", "input2")
				.addLayer("g1",
						new DenseLayer.Builder().nIn(10).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"input1")
				.addLayer("g2",
						new DenseLayer.Builder().nIn(128).nOut(512).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"g1")
				.addLayer("g3",
						new DenseLayer.Builder().nIn(512).nOut(28 * 28).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"g2")
				.addVertex("stack", new StackVertex(), "input2", "g3")
				.addLayer("d1",
						new DenseLayer.Builder().nIn(28 * 28).nOut(256).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"stack")
				.addLayer("d2",
						new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"d1")
				.addLayer("d3",
						new DenseLayer.Builder().nIn(128).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"d2")
				.addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(128).nOut(1)
						.activation(Activation.SIGMOID).build(), "d3")
				.setOutputs("out");

		ComputationGraph net = new ComputationGraph(graphBuilder.build());
		net.init();
		System.out.println(net.summary());
		UIServer uiServer = UIServer.getInstance();
		StatsStorage statsStorage = new InMemoryStatsStorage();
		uiServer.attach(statsStorage);
		net.setListeners(new ScoreIterationListener(100));
		net.getLayers();
		DataSetIterator train = new MnistDataSetIterator(30, true, 12345);
	
		INDArray labelD = Nd4j.vstack(Nd4j.ones(30, 1), Nd4j.zeros(30, 1));

		INDArray labelG = Nd4j.ones(60, 1);

		for (int i = 1; i <= 100000; i++) {
			if (!train.hasNext()) {
				train.reset();
			}
			INDArray trueExp = train.next().getFeatures();
			INDArray z = Nd4j.rand(new long[] { 30, 10 }, new NormalDistribution());
			MultiDataSet dataSetD = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] { z, trueExp },
					new INDArray[] { labelD });
			for(int m=0;m<10;m++){
				trainD(net, dataSetD);
			}
			z = Nd4j.rand(new long[] { 30, 10 }, new NormalDistribution());
			MultiDataSet dataSetG = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] { z, trueExp },
					new INDArray[] { labelG });
			trainG(net, dataSetG);

			if (i % 10000 == 0) {
			   net.save(new File("E:/gan.zip"), true);
			}

		}

	}

	public static void trainD(ComputationGraph net, MultiDataSet dataSet) {
		net.setLearningRate("g1", 0);
		net.setLearningRate("g2", 0);
		net.setLearningRate("g3", 0);
		net.setLearningRate("d1", lr);
		net.setLearningRate("d2", lr);
		net.setLearningRate("d3", lr);
		net.setLearningRate("out", lr);
		net.fit(dataSet);
	}

	public static void trainG(ComputationGraph net, MultiDataSet dataSet) {
		net.setLearningRate("g1", lr);
		net.setLearningRate("g2", lr);
		net.setLearningRate("g3", lr);
		net.setLearningRate("d1", 0);
		net.setLearningRate("d2", 0);
		net.setLearningRate("d3", 0);
		net.setLearningRate("out", 0);
		net.fit(dataSet);
	}
}

    说明:

    1、dl4j并没有提供像keras那样冻结某些层参数的方法,这里采用设置learningrate为0的方法,来冻结某些层的参数

    2、这个的更新器,用的是sgd,不能用其他的(比方说Adam、Rmsprop),因为这些自适应更新器会考虑前面batch的梯度作为本次更新的梯度,达不到不更新参数的目的

    3、这里用了StackVertex,沿着第一维合并张量,也就是合并真实数据样本和Generator产生的数据样本,共同训练Discriminator

    4、训练过程中多次update   Discriminator的参数,以便量出最大距离,让后更新Generator一次

    5、进行10w次迭代

三、Generator生成手写数字

    加载训练好的模型,随机从NormalDistribution取出一些噪音数据,丢给模型,经过feedForward,取出最后一层Generator的激活值,便是我们想要的结果,代码如下:

public class LoadGan {

	public static void main(String[] args) throws Exception {
	    ComputationGraph restored = ComputationGraph.load(new File("E:/gan.zip"), true);
		
		DataSetIterator train = new MnistDataSetIterator(30, true, 12345);
		INDArray trueExp = train.next().getFeatures();
		Map<String, INDArray> map = restored.feedForward(
				new INDArray[] { Nd4j.rand(new long[] { 50, 10 }, new NormalDistribution()), trueExp }, false);
		INDArray indArray = map.get("g3");// .reshape(20,28,28);
		List<INDArray> list = new ArrayList<>();
		for (int j = 0; j < indArray.size(0); j++) {
			list.add(indArray.getRow(j));
		}
	    
		MNISTVisualizer bestVisualizer = new MNISTVisualizer(1, list, "Gan");

		bestVisualizer.visualize();
	}
	
	
	public static class MNISTVisualizer {
		private double imageScale;
		private List<INDArray> digits; // Digits (as row vectors), one per
										// INDArray
		private String title;
		private int gridWidth;

		public MNISTVisualizer(double imageScale, List<INDArray> digits, String title) {
			this(imageScale, digits, title, 5);
		}

		public MNISTVisualizer(double imageScale, List<INDArray> digits, String title, int gridWidth) {
			this.imageScale = imageScale;
			this.digits = digits;
			this.title = title;
			this.gridWidth = gridWidth;
		}

		public void visualize() {
			JFrame frame = new JFrame();
			frame.setTitle(title);
			frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);

			JPanel panel = new JPanel();
			panel.setLayout(new GridLayout(0, gridWidth));

			List<JLabel> list = getComponents();
			for (JLabel image : list) {
				panel.add(image);
			}

			frame.add(panel);
			frame.setVisible(true);
			frame.pack();
		}

		public List<JLabel> getComponents() {
			List<JLabel> images = new ArrayList<>();
			for (INDArray arr : digits) {
				BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
				for (int i = 0; i < 784; i++) {
					bi.getRaster().setSample(i % 28, i / 28, 0, (int) (255 * arr.getDouble(i)));
				}
				ImageIcon orig = new ImageIcon(bi);
				Image imageScaled = orig.getImage().getScaledInstance((int) (imageScale * 28), (int) (imageScale * 28),
						Image.SCALE_DEFAULT);
				ImageIcon scaled = new ImageIcon(imageScaled);
				images.add(new JLabel(scaled));
			}
			return images;
		}
	}
}

    实际效果,还算比较清晰

 

 

快乐源于分享。

   此博客乃作者原创, 转载请注明出处

© 著作权归作者所有

粉丝 108
博文 47
码字总数 56986
作品 0
杭州
程序员
私信 提问
(zhuan) 深度学习全网最全学习资料汇总之模型介绍篇

This blog from : http://weibo.com/ttarticle/p/show?id=2309351000224077630868614681&u=5070353058&m=4077873754872790&cu=5070353058 深度学习全网最全学习资料汇总之模型介绍篇 雷锋网 ......

wangxiaocvpr
2017/02/22
0
0
Java 工程师快速入门深度学习,就从 Deeplearning4j 开始

作者:万宫玺 随着机器学习、深度学习为主要代表的人工智能技术的逐渐成熟,越来越多的 AI 产品得到了真正的落地。无论是以语音识别和自然语言处理为基础的个人助理软件,还是以人脸识别为基...

GitChat的博客
2018/12/13
0
0
用Keras搭建GAN:图像去模糊中的应用(附代码)

雷锋网(公众号:雷锋网)按:本文为雷锋字幕组编译的技术博客,原标题GAN with Keras: Application to Image Deblurring,作者为Raphaël Meudec。 翻译 | 廖颖 陈俊雅 整理 | 凡江 2014年 Ia...

雷锋字幕组
2018/04/25
0
0
分布式深度学习库--Deeplearning4j

Deeplearning4j(简称DL4J)是为Java和Scala编写的首个商业级开源分布式深度学习库。DL4J与Hadoop和Spark集成,为商业环境(而非研究工具目的)所设计。Skymind是DL4J的商业支持机构。 Deep...

匿名
2016/04/21
17.3K
11
资深算法工程师万宫玺:Java工程师转型AI的秘密法宝——深度学习框架Deeplearning4j | 分享总结

雷锋网AI研习社按:深度学习是人工智能发展最为迅速的领域之一,Google、Facebook、Microsoft等巨头都围绕深度学习重点投资了一系列新兴项目,他们也一直在支持一些开源深度学习框架。目前研...

杨文
2018/01/02
0
0

没有更多内容

加载失败,请刷新页面

加载更多

计算机实现原理专题--二进制减法器(二)

在计算机实现原理专题--二进制减法器(一)中说明了基本原理,现准备说明如何来实现。 首先第一步255-b运算相当于对b进行按位取反,因此可将8个非门组成如下图的形式: 由于每次做减法时,我...

FAT_mt
今天
5
0
好程序员大数据学习路线分享函数+map映射+元祖

好程序员大数据学习路线分享函数+map映射+元祖,大数据各个平台上的语言实现 hadoop 由java实现,2003年至今,三大块:数据处理,数据存储,数据计算 存储: hbase --> 数据成表 处理: hive --> 数...

好程序员官方
今天
7
0
tabel 中含有复选框的列 数据理解

1、el-ui中实现某一列为复选框 实现多选非常简单: 手动添加一个el-table-column,设type属性为selction即可; 2、@selection-change事件:选项发生勾选状态变化时触发该事件 <el-table @sel...

everthing
今天
6
0
【技术分享】TestFlight测试的流程文档

上架基本需求资料 1、苹果开发者账号(如还没账号先申请-苹果开发者账号申请教程) 2、开发好的APP 通过本篇教程,可以学习到ios证书申请和打包ipa上传到appstoreconnect.apple.com进行TestF...

qtb999
今天
10
0
再见 Spring Boot 1.X,Spring Boot 2.X 走向舞台中心

2019年8月6日,Spring 官方在其博客宣布,Spring Boot 1.x 停止维护,Spring Boot 1.x 生命周期正式结束。 其实早在2018年7月30号,Spring 官方就已经在博客进行过预告,Spring Boot 1.X 将维...

Java技术剑
今天
18
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部