文档章节

基于PyTorch实现MNIST手写字识别

o
 osc_zoa3moe9
发布于 2019/12/08 14:20
字数 772
阅读 12
收藏 0

精选30+云产品,助力企业轻松上云!>>>

本篇不涉及模型原理,只是分享下代码。想要了解模型原理的可以去看网上很多大牛的博客。

目前代码实现了CNN和LSTM两个网络,整个代码分为四部分:

  • Config:项目中涉及的参数;

  • CNN:卷积神经网络结构;

  • LSTM:长短期记忆网络结构;

  • TrainProcess

    模型训练及评估,参数model控制训练何种模型(CNN or LSTM)。

完整代码

Talk is cheap, show me the code.

# -*- coding: utf-8 -*-

# @author: Awesome_Tang
# @date: 2019-04-05
# @version: python3.7

import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from datetime import datetime


class Config:
    batch_size = 64
    epoch = 10
    alpha = 1e-3

    print_per_step = 100  # 控制输出


class CNN(nn.Module):

    def __init__(self):
        super(CNN, self).__init__()
        """
        Conv2d参数:
        第一位:input channels  输入通道数
        第二位:output channels 输出通道数
        第三位:kernel size 卷积核尺寸
        第四位:stride 步长,默认为1
        第五位:padding size 默认为0,不补
        """
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        self.fc1 = nn.Sequential(
            nn.Linear(64 * 5 * 5, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )

        self.fc2 = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),  # 加快收敛速度的方法(注:批标准化一般放在全连接层后面,激活函数层的前面)
            nn.ReLU()
        )

        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()

        self.lstm = nn.LSTM(
            input_size=28,
            hidden_size=64,
            num_layers=1,
            batch_first=True,
        )

        self.output = nn.Linear(64, 10)

    def forward(self, x):
        r_out, (_, _) = self.lstm(x, None)

        out = self.output(r_out[:, -1, :])
        return out


class TrainProcess:

    def __init__(self, model="CNN"):
        self.train, self.test = self.load_data()
        self.model = model
        if self.model == "CNN":
            self.net = CNN()
        elif self.model == "LSTM":
            self.net = LSTM()
        else:
            raise ValueError('"CNN" or "LSTM" is expected, but received "%s".' % model)
        self.criterion = nn.CrossEntropyLoss()  # 定义损失函数
        self.optimizer = optim.Adam(self.net.parameters(), lr=Config.alpha)

    @staticmethod
    def load_data():
        print("Loading Data......")
        """加载MNIST数据集,本地数据不存在会自动下载"""
        train_data = datasets.MNIST(root='./data/',
                                    train=True,
                                    transform=transforms.ToTensor(),
                                    download=True)

        test_data = datasets.MNIST(root='./data/',
                                   train=False,
                                   transform=transforms.ToTensor())

        # 返回一个数据迭代器
        # shuffle:是否打乱顺序
        train_loader = torch.utils.data.DataLoader(dataset=train_data,
                                                   batch_size=Config.batch_size,
                                                   shuffle=True)

        test_loader = torch.utils.data.DataLoader(dataset=test_data,
                                                  batch_size=Config.batch_size,
                                                  shuffle=False)
        return train_loader, test_loader

    def train_step(self):
        steps = 0
        start_time = datetime.now()

        print("Training & Evaluating based on '%s'......" % self.model)
        for epoch in range(Config.epoch):
            print("Epoch {:3}.".format(epoch + 1))

            for data, label in self.train:
                data, label = Variable(data.cpu()), Variable(label.cpu())
                # LSTM输入为3维,CNN输入为4维
                if self.model == "LSTM":
                    data = data.view(-1, 28, 28)
                self.optimizer.zero_grad()  # 将梯度归零
                outputs = self.net(data)  # 将数据传入网络进行前向运算
                loss = self.criterion(outputs, label)  # 得到损失函数
                loss.backward()  # 反向传播
                self.optimizer.step()  # 通过梯度做一步参数更新

                # 每100次打印一次结果
                if steps % Config.print_per_step == 0:
                    _, predicted = torch.max(outputs, 1)
                    correct = int(sum(predicted == label))  # 计算预测正确个数
                    accuracy = correct / Config.batch_size  # 计算准确率
                    end_time = datetime.now()
                    time_diff = (end_time - start_time).seconds
                    time_usage = '{:3}m{:3}s'.format(int(time_diff / 60), time_diff % 60)
                    msg = "Step {:5}, Loss:{:6.2f}, Accuracy:{:8.2%}, Time usage:{:9}."
                    print(msg.format(steps, loss, accuracy, time_usage))

                steps += 1

        test_loss = 0.
        test_correct = 0
        for data, label in self.test:
            data, label = Variable(data.cpu()), Variable(label.cpu())
            if self.model == "LSTM":
                data = data.view(-1, 28, 28)
            outputs = self.net(data)
            loss = self.criterion(outputs, label)
            test_loss += loss * Config.batch_size
            _, predicted = torch.max(outputs, 1)
            correct = int(sum(predicted == label))
            test_correct += correct

        accuracy = test_correct / len(self.test.dataset)
        loss = test_loss / len(self.test.dataset)
        print("Test Loss: {:5.2f}, Accuracy: {:6.2%}".format(loss, accuracy))

        end_time = datetime.now()
        time_diff = (end_time - start_time).seconds
        print("Time Usage: {:5.2f} mins.".format(time_diff / 60.))


if __name__ == "__main__":
    p = TrainProcess(model='CNN')
    p.train_step()


Peace~~

o
粉丝 1
博文 500
码字总数 0
作品 0
私信 提问
加载中
请先登录后再评论。
MNIST数据集深度学习实践汇总

Why MNIST MNIST数据集对深度学习初学者来说应该是最友好的数据集了: 拿来即用,你只需要专注于模型搭建就好(数据处理真的很费时间); 数据集不大,很适合普通玩家,一般的PC都能跑的动,...

Awesome_Tang
2019/05/02
0
0
基于Pytorch的简单小案例

  神经网络的理论知识不是本文讨论的重点,假设读者们都是已经了解RNN的基本概念,并希望能用一些框架做一些简单的实现。这里推荐神经网络必读书目:邱锡鹏《神经网络与深度学习》。本文基...

osc_neajl7oq
2019/04/17
5
0
深度学习引擎-PyTorch资源集锦

《深度学习引擎-PyTorch资源集锦》包含大量与 PyTorch (https://pytorch.org/)相关的资源链接,带你快速玩转基于神经网络的深度学习,进入人工智能的神秘领地。链接包括:入门教程,应用实...

openthings
2019/03/09
181
0
S01: 手写深度学习框架

手写深度学习框架 笔者手撸了简单的深度学习框架,这个小项目源于笔者学习pytorch的过程中对autograd的探索。项目名称为kitorch。 该项目基于numpy实现,代码的执行效率比cpu的pytorch要慢。...

oio328Loio
03/31
0
0
吐血整理:PyTorch项目代码与资源列表 | 资源下载

http://www.sohu.com/a/164171974_741733 本文收集了大量基于 PyTorch 实现的代码链接,其中有适用于深度学习新手的“入门指导系列”,也有适用于老司机的论文代码实现,包括 Attention Base...

osc_tek5189e
2018/03/02
7
0

没有更多内容

加载失败,请刷新页面

加载更多

PO设计模式-实现移动端自动化测试

开发环境:python 3.6.5 + selenium 2.48.0 + pytest框架 + Android 5.1 工具:pycharm + Appium + Genymotion 测试机型:Samsung Galaxy S6 #需求:设计3个测试用例#1.实现点击设置->显示-...

osc_cl1ufvfd
20分钟前
20
0
Android之TabLayout和ViewPager组合跳转到指定页面

1 问题 TabLayout和ViewPager组合跳转到具体一个页面 2 解决办法 viewPager?.setCurrentItem(index) index为0说明是第一页,如果是1的话就是第二页,以此类推。...

osc_w4g8kpwc
22分钟前
17
0
Android之解决多语言适配部分TextView内容左对齐和内容一行不排满就到第二行问题

1 问题 1、多语言适配部分TextView内容左对齐 2、内容一行不排满就到第二行问题 2 解决办法 问题1、在TextView里面加入下面参数 android:gravity="center" 问题2、 import android.conte...

osc_u61lmlkv
23分钟前
17
0
SpringBoot2.0+Shiro+MyBatisPlus权限管理系统

项目描述 Hi,大家好,今天分享的项目是《SpringBoot+Shiro权限管理系统》,这是一个SpringBoot+Layui后台管理系统,使用Shiro安全框架,加入访问权限,对不同角色有不同的访问权限,其他管理...

ericxu1116
24分钟前
19
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部