文档章节

Keras之03-用MNIST数据集训练一个CNN

LevineHuang
 LevineHuang
发布于 2017/02/25 11:28
字数 841
阅读 1268
收藏 1

Keras之03-用MNIST数据集训练一个CNN


模型code

# -*- coding: utf-8 -*-

'''Trains a simple convnet on the MNIST dataset.

Gets to 99.25% test accuracy after 12 epochs
(there is still a lot of margin for parameter tuning).
16 seconds per epoch on a GRID K520 GPU.
'''

from __future__ import print_function
import numpy as np
np.random.seed(1337)  # for reproducibility

from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D
from keras.utils import np_utils
# Keras的底层库使用Theano或TensorFlow
from keras import backend as K

batch_size = 128
nb_classes = 10
nb_epoch = 12

# input image dimensions
img_rows, img_cols = 28, 28
# number of convolutional filters to use
nb_filters = 32
# size of pooling area for max pooling
pool_size = (2, 2)
# convolution kernel size
kernel_size = (3, 3)

# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# 在如何表示一组彩色图片的问题上,Theano和TensorFlow发生了分歧.
# ’th’模式,也即Theano模式会把100张RGB三通道的16×32(高为16宽为32)彩色图表示为下面这种形式(100,3,16,32),Caffe采取的也是这种方式。第0个维度是样本维,代表样本的数目,第1个维度是通道维,代表颜色通道数。后面两个就是高和宽了。
# 而TensorFlow,即’tf’模式的表达形式是(100,16,32,3),即把通道维放在了最后。

# 根据backend模式reshape输入数据
if K.image_dim_ordering() == 'th':
    X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
    X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
    X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)

model = Sequential()

# 卷积层    
# 二维卷积层对二维输入进行滑动窗卷积
# keras.layers.convolutional.Convolution2D(nb_filter, nb_row, nb_col, init='glorot_uniform', activation='linear', weights=None, border_mode='valid', subsample=(1, 1), dim_ordering='th', W_regularizer=None, b_regularizer=None, activity_regularizer=None, W_constraint=None, b_constraint=None, bias=True)

# nb_filter:卷积核的数目,(即输出的维度)
# nb_row:卷积核的行数
# nb_col:卷积核的列数
# border_mode:边界模式,为“valid”,“same”或“full”,full需要以theano为后端

model.add(Convolution2D(nb_filters, kernel_size[0], kernel_size[1],
                        border_mode='valid',
                        input_shape=input_shape))
model.add(Activation('relu'))
model.add(Convolution2D(nb_filters, kernel_size[0], kernel_size[1]))
model.add(Activation('relu'))

# keras.layers.convolutional.MaxPooling2D(pool_size=(2, 2), strides=None, border_mode='valid', dim_ordering='th')
# 空域信号施加最大值池化
model.add(MaxPooling2D(pool_size=pool_size))
model.add(Dropout(0.25))

# Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。Flatten不影响batch的大小。
model.add(Flatten())
model.add(Dense(128))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(nb_classes))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer='adadelta',
              metrics=['accuracy'])

model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
          verbose=1, validation_data=(X_test, Y_test))
score = model.evaluate(X_test, Y_test, verbose=0)
print('Test score:', score[0])
print('Test accuracy:', score[1])

模型运行结果

Using TensorFlow backend.
X_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
Train on 60000 samples, validate on 10000 samples
Epoch 1/12
60000/60000 [==============================] - 46s - loss: 0.3732 - acc: 0.8859 - val_loss: 0.0886 - val_acc: 0.9719
Epoch 2/12
60000/60000 [==============================] - 45s - loss: 0.1350 - acc: 0.9597 - val_loss: 0.0627 - val_acc: 0.9796
Epoch 3/12
60000/60000 [==============================] - 45s - loss: 0.1027 - acc: 0.9697 - val_loss: 0.0562 - val_acc: 0.9822
Epoch 4/12
60000/60000 [==============================] - 45s - loss: 0.0884 - acc: 0.9741 - val_loss: 0.0438 - val_acc: 0.9858
Epoch 5/12
60000/60000 [==============================] - 45s - loss: 0.0779 - acc: 0.9772 - val_loss: 0.0415 - val_acc: 0.9867
Epoch 6/12
60000/60000 [==============================] - 46s - loss: 0.0709 - acc: 0.9786 - val_loss: 0.0379 - val_acc: 0.9869
Epoch 7/12
60000/60000 [==============================] - 45s - loss: 0.0650 - acc: 0.9811 - val_loss: 0.0360 - val_acc: 0.9889
Epoch 8/12
60000/60000 [==============================] - 45s - loss: 0.0609 - acc: 0.9813 - val_loss: 0.0354 - val_acc: 0.9883
Epoch 9/12
60000/60000 [==============================] - 45s - loss: 0.0557 - acc: 0.9838 - val_loss: 0.0330 - val_acc: 0.9885
Epoch 10/12
60000/60000 [==============================] - 45s - loss: 0.0541 - acc: 0.9836 - val_loss: 0.0318 - val_acc: 0.9897
Epoch 11/12
60000/60000 [==============================] - 45s - loss: 0.0497 - acc: 0.9857 - val_loss: 0.0322 - val_acc: 0.9897
Epoch 12/12
60000/60000 [==============================] - 45s - loss: 0.0476 - acc: 0.9856 - val_loss: 0.0327 - val_acc: 0.9893
Test score: 0.0326897691154
Test accuracy: 0.9893

© 著作权归作者所有

LevineHuang

LevineHuang

粉丝 5
博文 9
码字总数 14787
作品 0
东城
私信 提问
易用的深度学习框架Keras简介及使用

Keras是基于Python的一个深度学习框架,内核采用Theano和Tensorflow,可以进行切换。 1. Keras简介 Keras是基于Theano的一个深度学习框架,它的设计参考了Torch,用Python语言编写,是一个高...

openthings
2016/01/10
1K
0
Bengio终结Theano不是偶然,其性能早在Keras支持的四大框架中垫底

作者 | Jasmeet Bhatia 编译 | KK4SBB 本文将对目前流行的几种Keras支持的深度学习框架性能做一次综述性对比,包括Tensorflow、CNTK、MXNet和Theano。作者Jasmeet Bhatia是微软的数据与人工智...

AI科技大本营
2017/10/12
0
0
车牌识别-Mask_RCNN定位车牌+手写方法分割字符+CNN单个字符识别

simple-car-plate-recognition 简单车牌识别-Mask_RCNN定位车牌+手写方法分割字符+CNN单个字符识别 数据准备 准备用于车牌定位的数据集,要收集250张车辆图片,200张用于训练,50张用于测试,...

airxiechao
2018/10/10
3.7K
0
我们建了个模型,搞定了 MNIST 数字识别任务

雷锋网(公众号:雷锋网)按:本文为雷锋字幕组编译的技术博客,原标题 A simple 2D CNN for MNIST digit recognition ,作者为 Sambit Mahapatra 。 翻译 | 王祎 霍雷刚 整理 | MY 对于图像分...

雷锋字幕组
2018/07/09
0
0
用简单的 2D CNN 进行 MNIST 数字识别

雷锋网 AI 研习社按:本文为雷锋网(公众号:雷锋网)字幕组编译的技术博客,原标题 A simple 2D CNN for MNIST digit recognition,作者为 Sambit Mahapatra。 翻译 | 王祎 校对 | 霍雷刚 整理...

雷锋字幕组
2018/07/23
0
0

没有更多内容

加载失败,请刷新页面

加载更多

OSChina 周日乱弹 —— 我,小小编辑,食人族酋长

Osc乱弹歌单(2019)请戳(这里) 【今日歌曲】 @宇辰OSC :分享娃娃的单曲《飘洋过海来看你》: #今日歌曲推荐# 《飘洋过海来看你》- 娃娃 手机党少年们想听歌,请使劲儿戳(这里) @宇辰OSC...

小小编辑
今天
219
9
MongoDB系列-- SpringBoot 中对 MongoDB 的 基本操作

SpringBoot 中对 MongoDB 的 基本操作 Database 库的创建 首先 在MongoDB 操作客户端 Robo 3T 中 创建数据库: 增加用户User: 创建 Collections 集合(类似mysql 中的 表): 后面我们大部分都...

TcWong
今天
2
0
spring cloud

一、从面试题入手 1.1、什么事微服务 1.2、微服务之间如何独立通讯的 1.3、springCloud和Dubbo有哪些区别 1.通信机制:DUbbo基于RPC远程过程调用;微服务cloud基于http restFUL API 1.4、spr...

榴莲黑芝麻糊
今天
2
0
Executor线程池原理与源码解读

线程池为线程生命周期的开销和资源不足问题提供了解决方 案。通过对多个任务重用线程,线程创建的开销被分摊到了多个任务上。 线程实现方式 Thread、Runnable、Callable //实现Runnable接口的...

小强的进阶之路
昨天
6
0
maven 环境隔离

解决问题 即 在 resource 文件夹下面 ,新增对应的资源配置文件夹,对应 开发,测试,生产的不同的配置内容 <resources> <resource> <directory>src/main/resources.${deplo......

之渊
昨天
8
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部