文档章节

PyTorch快速入门教程二(线性回归以及logistic回归)

earnpls
 earnpls
发布于 2017/07/02 09:08
字数 965
阅读 287
收藏 0

线性回归

对于线性回归,相信大家都很熟悉了,各种机器学习的书第一个要讲的内容必定有线性回归,这里简单的回顾一下什么是简单的一元线性回归。即给出一系列的点,找一条直线,使得这条直线与这些点的距离之和最小。

什么是线性回归

上图就简单地描绘出了线性回归的基本原理,接下来我们来重点讲讲如何用pytorch写一个简单的线性回归。

Data参数

首先我们需要给出一系列的点作为线性回归的数据,使用numpy来存储这些点。如果没有安装Numpy,可以参考这篇文章:Python如何使用PIP安装numpy

x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],
                    [9.779], [6.182], [7.59], [2.167], [7.042],
                    [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)

y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],
                    [3.366], [2.596], [2.53], [1.221], [2.827],
                    [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)

显示出来就如下图所示

Pytorch线性回归

上一讲我们已经学习如何将numpy转换成Tensor,使用torch.from_numpy()numpy函数即可进行相互转换

x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)

这样我们就把数据转换成了Tensor。

Model

上一节讲了基本的线性回归以及模型框架,按照这个框架就可以写出一个线性回归模型了

# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(1, 1)  # input and output is 1 dimension

    def forward(self, x):
        out = self.linear(x)
        return out
model = LinearRegression()

这里的nn.Linear表示的是 y=w*x+b,里面的两个参数都是1,表示的是x是1维,y也是1维。当然这里是可以根据你想要的输入输出维度来更改的,之前使用的别的框架的同学应该很熟悉。

然后需要定义lossoptimizer,就是误差和优化函数

criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-4)

这里使用的是最小二乘loss,之后我们做分类问题更多的使用的是cross entropy loss,交叉熵。优化函数使用的是随机梯度下降,注意需要将model的参数model.parameters()传进去让这个函数知道他要优化的参数是哪些。

开始训练

我们接着开始训练

num_epochs = 1000
for epoch in range(num_epochs):
    inputs = Variable(x_train)
    target = Variable(y_train)

    # forward
    out = model(inputs) # 前向传播
    loss = criterion(out, target) # 计算loss
    # backward
    optimizer.zero_grad() # 梯度归零
    loss.backward() # 方向传播
    optimizer.step() # 更新参数

    if (epoch+1) % 20 == 0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(epoch+1,num_epochs,loss.data[0]))

第一个循环表示每个epoch,接着开始前向传播,然后计算loss,然后反向传播,接着优化参数,特别注意的是在每次反向传播的时候需要将参数的梯度归零,即

optimzier.zero_grad()

validation

训练完成之后我们就可以开始测试模型了

model.eval()
predict = model(Variable(x_train))
predict = predict.data.numpy()

特别注意的是需要用 model.eval(),让model变成测试模式,这主要是对dropoutbatch normalization的操作在训练和测试的时候是不一样的

最后可以得到这个结果 Pytorch validation 以及loss的结果 Pytorch validation

在这里,我整理发布了Pytorch中文文档,方便大家查询使用,同时也准备了中文论坛,欢迎大家学习交流!

Pytorch中文文档

Pytorch中文论坛

Pytorch中文文档已经发布,完美翻译,更加方便大家浏览:

Pytorch中文网:https://ptorch.com/

Pytorch中文文档:https://ptorch.com/docs/1/

 

本文转载自:https://ptorch.com/news/6.html

共有 人打赏支持
earnpls
粉丝 6
博文 26
码字总数 74
作品 0
昌平
程序员
私信 提问
PyTorch 你想知道的都在这里

本文转载地址,并进行了加工。本文适用于深度学习新手的“入门指导系列”,也有适用于老司机的论文代码实现,包括 Attention Based CNN、A3C、WGAN、BERT等等。所有代码均按照所属技术领域分...

readilen
10/20
0
0
PyTorch:60分钟入门学习

最近在学习PyTorch这个深度学习框架,在这里做一下整理分享给大家,有什么写的不对或者不好的地方,还请大侠们见谅啦~~~ 写在前面 本文就是主要是对PyTorch的安装,以及入门学习做了记录,...

与阳光共进早餐
01/15
0
0
库、教程、论文实现,这是一份超全的PyTorch资源列表(Github 2.2K星)

选自 Github,作者:bharathgs,机器之心编译。 机器之心发现了一份极棒的 PyTorch 资源列表,该列表包含了与 PyTorch 相关的众多库、教程与示例、论文实现以及其他资源。在本文中,机器之心...

机器之心
10/22
0
0
PyTorch 1.0 稳定版来啦

雷锋网 AI 科技评论消息,NeurIPS 正在加拿大召开,会上,Facebook 宣布正式推出 PyTorch 1.0 稳定版,在 Facebook code 博客上,也一并同步了这一消息。雷锋网(公众号:雷锋网) AI 科技评论...

汪思颖
12/08
0
0
Keras vs PyTorch:谁是「第一」深度学习框架?

  选自Deepsense.ai   作者:Rafa Jakubanis、Piotr Migdal   机器之心编译   参与:路、李泽南、李亚洲      「第一个深度学习框架该怎么选」对于初学者而言一直是个头疼的问题...

机器之心
06/30
0
0

没有更多内容

加载失败,请刷新页面

加载更多

mybatis学习(1)

JDBC连接方式: 1.底层没有使用连接池,操作数据库需要频繁的创建和关闭连接,消耗资源。 2.写原生的JDBC代码在JAVA中,一旦需要修改SQL的话(比如表增加字段),JAVA需要整体重新编译,不利...

杨健-YJ
38分钟前
2
0
怎么组织文档

可以从以下几个方面考虑组织文档: ☐ 各种分支的界面截图和对应的类及文件 ☐ 框架或类图 ☐ 流程图 ☐ 时序图 ☐ 注意事项

-___-
49分钟前
3
0
分布式之数据库和缓存双写一致性方案解析

引言 为什么写这篇文章? 首先,缓存由于其高并发和高性能的特性,已经在项目中被广泛使用。在读取缓存方面,大家没啥疑问,都是按照下图的流程来进行业务操作。 但是在更新缓存方面,对于更...

别打我会飞
52分钟前
10
0
我的oracle11G,12c OCM之路

ocm认证感悟 ---------------------- 距离拿到ocm证书已经过了1年的时间,当初拿到证书的心情到现在还记得。其实在每个DBA心里都有一个成为强者的梦想,需要被认可,我也一样。我干过开发,做...

hnairdb
52分钟前
2
1
手动部署kubernetes集群(1.13.1最新版)

一、机器规划 使用五台机子部署k8s集群,规划如下: master节点3台(同时也是etcd节点) node节点2台 ip分配如下: ip:192.168.10.101,主机名:k8s-etcd01 ip:192.168.10.102,主机名:k8s...

人在艹木中
57分钟前
31
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部