文档章节

小白学Tensorflow之多层神经网络

AllenOR灵感
 AllenOR灵感
发布于 2017/09/10 01:23
字数 673
阅读 4
收藏 0

在本博客中,我们将利用Tensorflow来构建一个多层神经网络。因为本博客是为了学习目的,所以我们就来构建一个四层神经网络,即一个输入层,两个隐藏层,一个输出层。第一,我们需要定义层与层之间的转移矩阵

# 定义输入层到第一个隐藏层之间的连接矩阵
w_layer_1 = init_weights([784, 625])

# 定义第一个隐藏层到第二个隐藏层之间的连接矩阵
w_layer_2 = init_weights([625, 625])

# 定义第二个隐藏层到输出层之间的连接矩阵
w_layer_3 = init_weights([625, 10])

第二,构建模型。在此模型中,我们加入了dropout参数,该参数是为了防止过拟合。也就是说,如果在某一层中使用了dropout参数,那么该层只有一部分神经元放电。比如,dropout = 0.8,那么只有80%的神经元是出于放电状态的,其他都是关闭状态。

def model(X, w_layer_1, w_layer_2, w_layer_3, p_keep_input, p_keep_hidden): 
  X = tf.nn.dropout(X, p_keep_input) 
  hidden_1 = tf.nn.relu(tf.matmul(X, w_layer_1)) 
  hidden_1 = tf.nn.dropout(hidden_1, p_keep_hidden) 
  hidden_2 = tf.nn.relu(tf.matmul(hidden_1, w_layer_2)) 
  hidden_2 = tf.nn.dropout(hidden_2, p_keep_hidden) 
  return tf.matmul(hidden_2, w_layer_3)

完整代码,如下:

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import tensorflow as tf 
import input_data

def init_weights(shape):
    return tf.Variable(tf.random_normal(shape, stddev = 0.01))

def model(X, w_layer_1, w_layer_2, w_layer_3, p_keep_input, p_keep_hidden):
    X = tf.nn.dropout(X, p_keep_input)
    hidden_1 = tf.nn.relu(tf.matmul(X, w_layer_1))

    hidden_1 = tf.nn.dropout(hidden_1, p_keep_hidden)
    hidden_2 = tf.nn.relu(tf.matmul(hidden_1, w_layer_2))

    hidden_2 = tf.nn.dropout(hidden_2, p_keep_hidden)

    return tf.matmul(hidden_2, w_layer_3)

# 导入数据
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels

X = tf.placeholder("float", [None, 784])
Y = tf.placeholder("float", [None, 10])

# 在该模型中我们一共有4层,一个输入层,两个隐藏层,一个输出层
# 定义输入层到第一个隐藏层之间的连接矩阵
w_layer_1 = init_weights([784, 625])

# 定义第一个隐藏层到第二个隐藏层之间的连接矩阵
w_layer_2 = init_weights([625, 625])

# 定义第二个隐藏层到输出层之间的连接矩阵
w_layer_3 = init_weights([625, 10])

# dropout 系数
# 定义有多少有效的神经元将作为输入神经元,比如 p_keep_intput = 0.8,那么只有80%的神经元将作为输入
p_keep_input = tf.placeholder("float")

# 定义有多少的有效神经元将在隐藏层被激活
p_keep_hidden = tf.placeholder("float")

# 构建模型
py_x = model(X, w_layer_1, w_layer_2, w_layer_3, p_keep_input, p_keep_hidden)


cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(py_x, Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)

with tf.Session() as sess:

    init = tf.initialize_all_variables()
    sess.run(init)

    for i in xrange(100):
        for start, end in zip(range(0, len(trX), 128), range(128, len(trX), 128)):
            sess.run(train_op, feed_dict = {X: trX[start:end], Y: trY[start:end],
                                            p_keep_input: 0.8, p_keep_hidden: 0.5})
        print i, np.mean(np.argmax(teY, axis = 1) == sess.run(predict_op, 
                        feed_dict = {X: teX, Y: teY, p_keep_input: 1.0, p_keep_hidden: 1.0}))

本文转载自:http://www.jianshu.com/p/f481b19031a6

共有 人打赏支持
AllenOR灵感
粉丝 10
博文 2634
码字总数 82983
作品 0
程序员
人人都会深度学习之Tensorflow基础快速入门

《Tensorflow基础快速入门》课程的目的是帮助广大的深度学习爱好者,逐层深入,步步精通当下最流行的深度学习框架Tensorflow。该课程包含Tensorflow运行原理,Tensor上面常见的操作,常见API...

liwei2000
07/05
0
0
送书&优惠丨对深度学习感兴趣的你,不了解这些就太OUT了!

点击上方“程序人生”,选择“置顶公众号” 第一时间关注程序猿(媛)身边的故事 TensorFlow是什么? TensorFlow的前身是谷歌大脑(google brain)团队研发的DistBelief。自创建以来,它便被...

csdnsevenn
05/03
0
0
TensorFlow——MNIST手写数字识别

MNIST手写数字识别 MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/ 一、数据集介绍: MNIST是一个入门级的计算机视觉数据集 下载下来的数据集被分成两部分:60000行的训练数据集...

飞天小橘子
04/27
0
0
史上最全TensorFlow学习资源汇总

来源 悦动智能(公众号ID:aibbtcom) 本篇文章将为大家总结TensorFlow纯干货学习资源,非常适合新手学习,建议大家收藏。 ▌一 、TensorFlow教程资源 1)适合初学者的TensorFlow教程和代码示...

悦动智能
04/12
0
0
tensorflow入门---第三章

tensorflow程序分为两个阶段: 第一个阶段:定义计算图所有的计算 第二个阶段:执行计算 第一节:计算模型—–计算图 第二节:数据模型—–张量 第三节:运行模型—–会话 第一节:计算图 计...

cttacm
05/05
0
0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

python3.6 取余运算

python中取余运算逻辑如下: 如果a 与d 是整数,d 非零,那么余数 r 满足这样的关系: a = qd + r , q 为整数,且0 ≤ |r| < |d|。 经过测试可发现,python3.6中取余运算得到的 r 是正整数;...

colinux
13分钟前
1
0
[雪峰磁针石博客]软件测试专家工具包1web测试

web测试 本章主要涉及功能测试、自动化测试(参考: 软件自动化测试初学者忠告) 、接口测试(参考:10分钟学会API测试)、跨浏览器测试、可访问性测试和可用性测试的测试工具列表。 安全测试工具...

python测试开发人工智能安全
今天
3
0
JS:异步 - 面试惨案

为什么会写这篇文章,很明显不符合我的性格的东西,原因是前段时间参与了一个面试,对于很多程序员来说,面试时候多么的鸦雀无声,事后心里就有多么的千军万马。去掉最开始毕业干了一年的Jav...

xmqywx
今天
3
0
Win10 64位系统,PHP 扩展 curl插件

执行:1. 拷贝php安装目录下,libeay32.dll、ssleay32.dll 、 libssh2.dll 到 C:\windows\system32 目录。2. 拷贝php/ext目录下, php_curl.dll 到 C:\windows\system32 目录; 3. p...

放飞E梦想O
今天
1
0
谈谈神秘的ES6——(五)解构赋值【对象篇】

上一节课我们了解了有关数组的解构赋值相关内容,这节课,我们接着,来讲讲对象的解构赋值。 解构不仅可以用于数组,还可以用于对象。 let { foo, bar } = { foo: "aaa", bar: "bbb" };fo...

JandenMa
今天
1
0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

返回顶部
顶部