文档章节

Denoising Autoencoder

AllenOR灵感
 AllenOR灵感
发布于 2017/09/10 01:28
字数 690
阅读 11
收藏 1


降噪自编码器(DAE)是另一种自编码器的变种。强烈推荐 Pascal Vincent 的论文,该论文很详细的描述了该模型。降噪自编码器认为,设计一个能够恢复原始信号的自编码器未必是最好的,而能够对 “被污染/破坏” 的原始数据进行编码、解码,然后还能恢复真正的原始数据,这样的特征才是好的。

从数学上来讲,假设原始数据 x 被我们“故意破坏”了,比如加入高斯噪声,或者把某些维度数据抹掉,变成 x',然后在对 x' 进行编码、解码,得到回复信号 xx = g(f(x')) 。该恢复信号尽可能的逼近未被污染的原数据 x 。此时,监督训练的误差函数就从原来的 L(x, g(f(x))) 变成了 L(x, g(f(x')))

从直观上理解,降噪自编码器希望学到的特征尽可能鲁棒,能够在一定程度上对抗原始数据的污染、缺失等情况。Vincent 论文里也对 DAE 提出了基于流行学习的解释,并且在图像数据上进行测试,发现 DAE 能够学出类似 Gabor 边缘提取的特征变换。

DAE 的系统结构如下图所示:


现在使用比较多的噪声主要是 mask noise,即原始数据中部分数据缺失,这是有着强烈的实际意义的,比如图像部分像素被遮挡、文本因记录原因漏掉一些单词等等。

实现代码如下:

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import tensorflow as tf 
import numpy as np 
import input_data

N_INPUT = 784
N_HIDDEN = 100
N_OUTPUT = N_INPUT
corruption_level = 0.3
epoches = 1000

def main(_):

    w_init = np.sqrt(6. / (N_INPUT + N_HIDDEN))
    weights = {
        "hidden": tf.Variable(tf.random_uniform([N_INPUT, N_HIDDEN], minval = -w_init, maxval = w_init)),
        "out": tf.Variable(tf.random_uniform([N_HIDDEN, N_OUTPUT], minval = -w_init, maxval = w_init))
    }

    bias = {
        "hidden": tf.Variable(tf.random_uniform([N_HIDDEN], minval = -w_init, maxval = w_init)),
        "out": tf.Variable(tf.random_uniform([N_OUTPUT], minval = -w_init, maxval = w_init))
    }

    with tf.name_scope("input"):
        # input data
        x = tf.placeholder("float", [None, N_INPUT])
        mask = tf.placeholder("float", [None, N_INPUT])

    with tf.name_scope("input_layer"):
        # from input data to input layer
        input_layer = tf.mul(x, mask)

    with tf.name_scope("hidden_layer"):
        # from input layer to hidden layer
        hidden_layer = tf.sigmoid(tf.add(tf.matmul(input_layer, weights["hidden"]), bias["hidden"]))

    with tf.name_scope("output_layer"):
        # from hidden layer to output layer
        output_layer = tf.sigmoid(tf.add(tf.matmul(hidden_layer, weights["out"]), bias["out"]))

    with tf.name_scope("cost"):
        # cost function
        cost = tf.reduce_sum(tf.pow(tf.sub(output_layer, x), 2))

    optimizer = tf.train.AdamOptimizer().minimize(cost)

    # load MNIST data
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels

    with tf.Session() as sess:

        init = tf.initialize_all_variables()
        sess.run(init)

        for i in range(epoches):
            for start, end in zip(range(0, len(trX), 100), range(100, len(trX), 100)):
                input_ = trX[start:end]
                mask_np = np.random.binomial(1, 1 - corruption_level, input_.shape)
                sess.run(optimizer, feed_dict={x: input_, mask: mask_np})

            mask_np = np.random.binomial(1, 1 - corruption_level, teX.shape)
            print i, sess.run(cost, feed_dict={x: teX, mask: mask_np})

if __name__ == "__main__":
    tf.app.run()

Reference:

《Extracting and Composing Robust Features with Denoising
Autoencoders》

《Stacked Denoising Autoencoders: Learning Useful Representations in a Deep Network with a Local Denoising Criterion》

本文转载自:http://www.jianshu.com/p/f7b9f10b7ac4

共有 人打赏支持
AllenOR灵感
粉丝 11
博文 2635
码字总数 83001
作品 0
程序员
私信 提问
TensorFlow 卷积自编码和去噪自编码

卷积网络的自编码 编码使用卷积核池化操作 解码使用反卷积和反池化操作 去噪自编码Denoising Autoencoder 要想取得好的特征只靠重构输入数据是不够的,实际应用中,还需要让这些特征具有抗干...

阿豪boy
08/08
0
0
机器学习笔记-Deep Learning

林轩田机器学习技法关于特征学习系列,其中涉及到 , , , , , , - , 等。 机器学习笔记-Neural Network 机器学习笔记-Deep Learning 机器学习笔记-Radial Basis Function Network 机器...

robin_Xu_shuai
2017/12/18
0
0
Deep Learning(深度学习)学习笔记整理系列之(四)

目录: 一、概述 二、背景 三、人脑视觉机理 四、关于特征 4.1、特征表示的粒度 4.2、初级(浅层)特征表示 4.3、结构性特征表示 4.4、需要有多少个特征? 五、Deep Learning的基本思想 六、...

云栖希望。
2017/12/04
0
0
深度学习,如何用去噪自编码器预测原始数据?

去噪自编码器(denoising autoencoder, DAE)是一类接受损坏数据作为输入,并训练来预测原始未被损坏数据作为输出的自编码器。 去噪自编码器代价函数的计算图。去噪自编码器被训练为从损坏的...

zlw东南风
2017/12/25
0
0
百度AI3.0让传统质检升级为人工智能,下一个失业的是谁?

当前制造业产品外表检查主要有人工质检和机器视觉质检两种方式,其中人工占90%,机器只占10%,而两者都面临许多挑战。人工质检成本高、误操作多、生产数据无法有效留存,机器视觉质检虽然不存...

安卓绿色联盟
11/27
0
0

没有更多内容

加载失败,请刷新页面

加载更多

mac 下 mysql 8.0.13 安装并记录遇到的问题 以便以后查看

安装 官网mysql 下载地址 安装过程 省去 安装好之后 下载navicat 错误1 链接 遇到 mysql 2003 - Can't connect to MySQL server 错误, 解决方案 重启mysql 服务 #错误2 ERROR 1045: Acces...

杭州-IT攻城狮
34分钟前
3
0

中国龙-扬科
38分钟前
1
0
[Spring4.x]基于spring4.x纯注解的Web工程搭建

在前文中已经说明了如何基于 Spring4.x+ 版本开发纯注解的非web项目,链接如下: https://my.oschina.net/morpheusWB/blog/2985600 本文则主要说明,如何在Web项目中,"基于spring纯注解方式...

morpheusWB
今天
13
0
基础编程题目集-7-13 日K蜡烛图

股票价格涨跌趋势,常用蜡烛图技术中的K线图来表示,分为按日的日K线、按周的周K线、按月的月K线等。以日K线为例,每天股票价格从开盘到收盘走完一天,对应一根蜡烛小图,要表示四个价格:开...

niithub
今天
5
0
Jenkins window 下的安装使用

1.下载:https://jenkins.io/download/ 双击安装完毕,将自动打开浏览器: http://localhost:8080 打开对应位置的文件,将初始密钥粘贴至输入框。 第一个是 安装默认的软件;第二个是 自定义...

狼王黄师傅
今天
7
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部