文档章节

PyTorch快速入门教程二(线性回归以及logistic回归)

earnpls
 earnpls
发布于 2017/07/02 09:08
字数 965
阅读 212
收藏 0
点赞 0
评论 0

线性回归

对于线性回归,相信大家都很熟悉了,各种机器学习的书第一个要讲的内容必定有线性回归,这里简单的回顾一下什么是简单的一元线性回归。即给出一系列的点,找一条直线,使得这条直线与这些点的距离之和最小。

什么是线性回归

上图就简单地描绘出了线性回归的基本原理,接下来我们来重点讲讲如何用pytorch写一个简单的线性回归。

Data参数

首先我们需要给出一系列的点作为线性回归的数据,使用numpy来存储这些点。如果没有安装Numpy,可以参考这篇文章:Python如何使用PIP安装numpy

x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],
                    [9.779], [6.182], [7.59], [2.167], [7.042],
                    [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)

y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],
                    [3.366], [2.596], [2.53], [1.221], [2.827],
                    [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)

显示出来就如下图所示

Pytorch线性回归

上一讲我们已经学习如何将numpy转换成Tensor,使用torch.from_numpy()numpy函数即可进行相互转换

x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)

这样我们就把数据转换成了Tensor。

Model

上一节讲了基本的线性回归以及模型框架,按照这个框架就可以写出一个线性回归模型了

# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(1, 1)  # input and output is 1 dimension

    def forward(self, x):
        out = self.linear(x)
        return out
model = LinearRegression()

这里的nn.Linear表示的是 y=w*x+b,里面的两个参数都是1,表示的是x是1维,y也是1维。当然这里是可以根据你想要的输入输出维度来更改的,之前使用的别的框架的同学应该很熟悉。

然后需要定义lossoptimizer,就是误差和优化函数

criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-4)

这里使用的是最小二乘loss,之后我们做分类问题更多的使用的是cross entropy loss,交叉熵。优化函数使用的是随机梯度下降,注意需要将model的参数model.parameters()传进去让这个函数知道他要优化的参数是哪些。

开始训练

我们接着开始训练

num_epochs = 1000
for epoch in range(num_epochs):
    inputs = Variable(x_train)
    target = Variable(y_train)

    # forward
    out = model(inputs) # 前向传播
    loss = criterion(out, target) # 计算loss
    # backward
    optimizer.zero_grad() # 梯度归零
    loss.backward() # 方向传播
    optimizer.step() # 更新参数

    if (epoch+1) % 20 == 0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(epoch+1,num_epochs,loss.data[0]))

第一个循环表示每个epoch,接着开始前向传播,然后计算loss,然后反向传播,接着优化参数,特别注意的是在每次反向传播的时候需要将参数的梯度归零,即

optimzier.zero_grad()

validation

训练完成之后我们就可以开始测试模型了

model.eval()
predict = model(Variable(x_train))
predict = predict.data.numpy()

特别注意的是需要用 model.eval(),让model变成测试模式,这主要是对dropoutbatch normalization的操作在训练和测试的时候是不一样的

最后可以得到这个结果 Pytorch validation 以及loss的结果 Pytorch validation

在这里,我整理发布了Pytorch中文文档,方便大家查询使用,同时也准备了中文论坛,欢迎大家学习交流!

Pytorch中文文档

Pytorch中文论坛

Pytorch中文文档已经发布,完美翻译,更加方便大家浏览:

Pytorch中文网:https://ptorch.com/

Pytorch中文文档:https://ptorch.com/docs/1/

 

本文转载自:https://ptorch.com/news/6.html

共有 人打赏支持
earnpls
粉丝 5
博文 26
码字总数 74
作品 0
昌平
程序员
终于!大家心心念念的PyTorch Windows官方支持来了

  机器之心整理   参与:机器之心编辑部      五个小时前,PyTorch 官方 GitHub 发布 0.4.0 版本,大家心心念念的 Windows 支持终于来了。      GitHub 发布https://github.com/...

机器之心 ⋅ 04/25 ⋅ 0

业界 | 无缝整合PyTorch 0.4与Caffe2,PyTorch 1.0即将问世

  选自Facebook Research   作者:Bill Jia   机器之心编译   参与:思源、晓坤      在 F8 的第二天中,Facebook 正式宣布 PyTorch1.0 即将与大家见面,这是继一周前发布 0.4....

机器之心 ⋅ 05/03 ⋅ 0

PyTorch 重大更新,0.4.0 版本支持 Windows 系统

雷锋网(公众号:雷锋网) AI 研习社最新消息,PyTorch 官方发布 0.4.0 版本,该版本的 PyTorch 有多项重大更新,其中最重要的改进是支持 Window 系统。 2017 年初,Facebook 在机器学习和科学...

孔令双 ⋅ 04/25 ⋅ 0

一文读懂PyTorch张量基础(附代码)

本文介绍了PyTorch中的Tensor类,它类似于Numpy中的ndarray,它构成了在PyTorch中构建神经网络的基础。 我们已经知道张量到底是什么了,并且知道如何用Numpy的ndarray来表示它们,现在我们看...

技术小能手 ⋅ 06/13 ⋅ 0

PyTorch 1.0 正式公开,Caffe2并入PyTorch实现AI研究和生产一条龙

今天,Facebook正式公布PyTorch 1.0,这是将基于Python的PyTorch与Caffe2合并的一个新版本的框架,让开发者可以无缝地将AI模型从研究转到生产,而无需处理迁移 “现在,你只需要使用PyTorch...

技术小能手 ⋅ 05/03 ⋅ 0

pytorch学习1:环境的搭建

环境搭建 ubuntu14.04+anaconda2+python2.7 首先在conda中新建一个环境: conda create --name pytorch_learn python=2.7 进入该环境: source activate pytorch_learn 安装pytorch,(可参考......

chenyue_tju ⋅ 05/06 ⋅ 0

机器学习者必知的5种深度学习框架

雷锋网按:本文为雷锋字幕组编译的技术博客,原标题The 5 Deep Learning Frameworks Every Serious Machine Learner Should Be Familiar With,作者为James Le。 翻译 | 杨恕权 张晓雪 陈明霏...

雷锋字幕组 ⋅ 05/03 ⋅ 0

Caffe2 公布与 PyTorch 合并细节:只为提高开发效率

Caffe2 近日在其博客上公布了与 PyTorch 合并的各项细节,文中表示 Caffe2 的开发重点是性能和跨平台部署,而 PyTorch 则专注于快速原型设计和研究的灵活性。二者的组件在过去一年大量被共享...

王练 ⋅ 05/06 ⋅ 0

融合 Caffe2、ONNX 的新版 PyTorch 发布在即,能否赶超 TensorFlow?

雷锋网(公众号:雷锋网) AI 研习社按,上个月,Caffe2 代码正式并入 PyTorch,就在今天,Facebook AI 系统与平台部(AI Infra and Platform)副总 Bill Jia 发文表示,PyTorch 1.0 发布在即,...

思颖 ⋅ 05/03 ⋅ 0

keras实战(二)——手写数字识别

上一篇博文里,详细解说了keras线性回归的应用,容易看到,相比tensorflow,pytorch等,代码量少,容易快速搭建神经网络。本篇博文尝试用keras实现mnist手写数字的识别。 先说一下什么是mni...

cuicheng01 ⋅ 05/13 ⋅ 0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

tcp/ip详解-链路层

简介 设计链路层的目的: 为IP模块发送和接收IP数据报 为ARP模块发送ARP请求和接收ARP应答 为RARP模块发送RARP请求和接收RARP应答 TCP/IP支持多种链路层协议,如以太网、令牌环往、FDDI、RS-...

loda0128 ⋅ 今天 ⋅ 0

spring.net aop代码例子

https://www.cnblogs.com/haogj/archive/2011/10/12/2207916.html

whoisliang ⋅ 今天 ⋅ 0

发送短信如何限制1小时内最多发送11条短信

发送短信如何限制1小时内最多发送11条短信 场景: 发送短信属于付费业务,有时为了防止短信攻击,需要限制发送短信的频率,例如在1个小时之内最多发送11条短信. 如何实现呢? 思路有两个 截至到当...

黄威 ⋅ 昨天 ⋅ 0

mysql5.7系列修改root默认密码

操作系统为centos7 64 1、修改 /etc/my.cnf,在 [mysqld] 小节下添加一行:skip-grant-tables=1 这一行配置让 mysqld 启动时不对密码进行验证 2、重启 mysqld 服务:systemctl restart mysql...

sskill ⋅ 昨天 ⋅ 0

Intellij IDEA神器常用技巧六-Debug详解

在调试代码的时候,你的项目得debug模式启动,也就是点那个绿色的甲虫启动服务器,然后,就可以在代码里面断点调试啦。下面不要在意,这个快捷键具体是啥,因为,这个keymap是可以自己配置的...

Mkeeper ⋅ 昨天 ⋅ 0

zip压缩工具、tar打包、打包并压缩

zip 支持压缩目录 1.在/tmp/目录下创建目录(study_zip)及文件 root@yolks1 study_zip]# !treetree 11└── 2 └── 3 └── test_zip.txt2 directories, 1 file 2.yum...

蛋黄Yolks ⋅ 昨天 ⋅ 0

聊聊HystrixThreadPool

序 本文主要研究一下HystrixThreadPool HystrixThreadPool hystrix-core-1.5.12-sources.jar!/com/netflix/hystrix/HystrixThreadPool.java /** * ThreadPool used to executed {@link Hys......

go4it ⋅ 昨天 ⋅ 0

容器之上传镜像到Docker hub

Docker hub在国内可以访问,首先要创建一个账号,这个后面会用到,我是用126邮箱注册的。 1. docker login List-1 Username不能使用你注册的邮箱,要用使用注册时用的username;要输入密码 ...

汉斯-冯-拉特 ⋅ 昨天 ⋅ 0

SpringBoot简单使用ehcache

1,SpringBoot版本 2.0.3.RELEASE ①,pom.xml <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>2.0.3.RELE......

暗中观察 ⋅ 昨天 ⋅ 0

Spring源码解析(八)——实例创建(下)

前言 来到实例创建的最后一节,前面已经将一个实例通过不同方式(工厂方法、构造器注入、默认构造器)给创建出来了,下面我们要对创建出来的实例进行一些“加工”处理。 源码解读 回顾下之前...

MarvelCode ⋅ 昨天 ⋅ 0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

返回顶部
顶部