文档章节

【深度学习框架】使用PyTorch进行数据处理

o
 osc_w9s1w4o0
发布于 2019/03/30 17:48
字数 1492
阅读 5
收藏 0

「深度学习福利」大神带你进阶工程师,立即查看>>>

  在深度学习中,数据的处理对于神经网络的训练来说十分重要,良好的数据(包括图像、文本、语音等)处理不仅可以加速模型的训练,同时也直接关系到模型的效果。本文以处理图像数据为例,记录一些使用PyTorch进行图像预处理和数据加载的方法


一、数据的加载

  在PyTorch中,数据加载需要自定义数据集类,并用此类来实例化数据对象,实现自定义的数据集需要继承torch.utils.data包中的Dataset类。   在继承Dataset实现自己的类时,需要实现以下两个Python魔法方法:

  • __getitem__(index): 返回一个样本数据,当使用obj[index]时实际就是在调用obj.__getitem__(index)
  • __len__():返回样本的数量,当使用len(obj)时实际就是在调用obj.__len__()

  例如,以猫狗大战的二分类数据集为例,其加载过程如下:

import os
import torch as t
from torch.utils import data
from PIL import Image
import numpy as np

class dogCat(data.Dataset):
    def __init__(self,root): # root为数据存放目录
        imgs = os.listdir(root) #列出当前路径下所有的文件
        self.imgs = [os.path.join(root,img) for img in imgs] # 所有图片的路径
        #print(self.imgs)

	"""返回一个样本数据"""
    def __getitem__(self, item): 
        img_path = self.imgs[item] # 第item张图片的路径
        #dog 1 cat 0
        label = 1 if 'dog' in img_path.split('\\')[-1] else 0 # 获取标签信息
        #print(label)
        pil_img = Image.open(img_path) #读入图片
        print(type(pil_img))
        array = np.asarray(pil_img) # 转为numpy.array类型
        data = t.from_numpy(array) # 转为tensor类型
        return data,label #返回图片对应的tensor及其标签

	"""样本的数量"""
    def __len__(self):
        return len(self.imgs)

if __name__ == '__main__':
    dogcat = dogCat('D:\pycode\dogsVScats\data\catvsdog\\train') #数据集对象
    data,label = dogcat[0] # 返回第0张图片的信息
    print(data.size())
    print(label)
    print(len(dogcat))

二、计算机视觉工具包:torchvision

  对于图像数据来说,以上的数据加载时不完善的,因为只是将图片读入,而没有进行相关的处理,如每张图片的大小和形状,样本的数值归一化等等。   为了解决这一问题,PyTorch开发了一个视觉工具包torchvision,这个包独立于torch,需要通过pip install torchvision来单独安装。   torchvision有三个部分组成:

  • models提供各种经典的网络结构和预训练好的模型,如AlexNet、VGG、ResNet、Inception等
from torchvision import models
from torch import nn
resnet34 = models.resnet34(pretrained=True,num_classes=1000) # 加载预训练模型
resnet34.fc=nn.Linear(512,10) # 修改全连接层为10分类
  • datasets提供了常用的数据集,如MNIST、CIFAR10/100、ImageNet、COCO等
from torchvision import datasets
dataset = datasets.MNIST('data/',download=True,train=False,transform=transform)

  除了常用数据集外,需要特别注意的是ImageFolder,ImageFolder假设所有的文件按文件夹存放,每个文件夹下面存储同一类的图片,文件夹的名字为这一类别的名字。这是我们经常用到的一种数据组织形式。

# 使用方法:
ImageFolder(root,transform=None,target_transform=None,loader=default_loader)
# 参数:文件夹路径,对图像做什么样的转换,对标签做什么样的转换,如何加载图片

from torchvision.datasets import ImageFolder
dataset = ImageFolder('data\\')
print(dataset.class_to_idx) # class_to_idx ,label和id的对应关系,从0开始
print(dataset.imgs) # 数据和标签对应
  • transforms: 提供常用的数据预处理操作,主要是对Tensor和PIL Image对象的处理操作

  对PIL Image的操作:Resize、CenterCrop、RandomCrop、RandomsizedCrop、Pad、ToTensor等。

  对Tensor的操作:Normalize、ToPILImage等。

  如果要进行多个操作,可以通过transforms.Compose([])将操作拼接起来。但是需要注意的是需要首先构建转换操作,然后再执行转换操作。

import os
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms as T

transform = T.Compose([T.Resize(224),T.CenterCrop(224),T.ToTensor(),T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])])  # 构建转换操作

class dogCat(data.Dataset):
    def __init__(self,root,transforms):
        imgs = os.listdir(root)
        #print(imgs)
        self.imgs = [os.path.join(root,img) for img in imgs]
        #print(self.imgs)
        self.transforms = transforms

    def __getitem__(self, item):
        img_path = self.imgs[item]
        #dog 1 cat 0
        label = 1 if 'dog' in img_path.split('\\')[-1] else 0
        #print(label)
        pil_img = Image.open(img_path)
        if self.transforms:
            pil_img = self.transforms(pil_img)  #执行准换操作
        return pil_img,label,item

    def __len__(self):
        return len(self.imgs)

三、使用DataLoader进行数据再处理

  通过上述描述,我们通过自定义数据集类,使用视觉工具包进行图像的转换等操作,最终得到的是一个dataset的数据集对象,使用此对象可以一次返回一个样本。   但是,我们应该清楚:训练神经网络时,一般采用的是小批量的梯度下降,因此我们是对一批数据进行处理,也就是一个batch,同时,数据还需要进行打乱(shuffle)和并行加速等。PyTorch提供了DataLoader来实现这些功能。   DataLoader定义如下:

DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,num_workers=0,collate_fn=default_collate,pin_memory=False,drop_last=False)

  参数含义如下:

  • dataset:加载的数据集
  • batch_zize: 批大小
  • shuffle: 是否将数据打乱
  • sampler:样本抽样,常用的有随机采样RandomSampler,shuffle=True时自动调用随机采样,默认是顺序采样,还有一个常用的是:WeightedRandomSampler,按照样本的权重进行采样。
  • num_workers: 使用的进程数,0代表不使用多进程。
  • collate_fn: 拼接方式。
  • pin_memory: 是否将数据保存在pin memory区。
  • drop_last: 是否将多出来的不足一个batch的丢弃。

  调用DataLoader得到的结果是一个可迭代的对象,可以和使用迭代器一样使用它。

from torchvision import transforms as T
from torch.utils.data import DataLoader

transform = T.Compose([T.Resize(224),T.CenterCrop(224),T.ToTensor(),T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])])

if __name__ == '__main__':
    dogcat = dogCat('D:\pycode\dogsVScats\data\catvsdog\\train', transform)
    data, label, index = dogcat[0]
    
    dataloader = DataLoader(dogcat,batch_size=3,shuffle=False,num_workers=0,drop_last=False)
    for batchDatas,batchLabels in dataloader: 
        train()

总结

  本文记录了使用PyTorch进行数据预处理的相关操作流程,重点是掌握Dataset和DataLoader两个类的使用,另外,视觉工具包torchvision的三个模块灵活运用,会对数据处理过程有很好的帮助。

o
粉丝 0
博文 500
码字总数 0
作品 0
私信 提问
加载中
请先登录后再评论。
用vertx实现高吞吐量的站点计数器

工具:vertx,redis,mongodb,log4j 源代码地址:https://github.com/jianglibo/visitrank 先看架构图: 如果你不熟悉vertx,请先google一下。我这里将vertx当作一个容器,上面所有的圆圈要...

jianglibo
2014/04/03
4.2K
3
CDH5: 使用parcels配置lzo

一、Parcel 部署步骤 1 下载: 首先需要下载 Parcel。下载完成后,Parcel 将驻留在 Cloudera Manager 主机的本地目录中。 2 分配: Parcel 下载后,将分配到群集中的所有主机上并解压缩。 3 激...

cloud-coder
2014/07/01
6.8K
1
beego API开发以及自动化文档

beego API开发以及自动化文档 beego1.3版本已经在上个星期发布了,但是还是有很多人不了解如何来进行开发,也是在一步一步的测试中开发,期间QQ群里面很多人都问我如何开发,我的业余时间实在...

astaxie
2014/06/25
2.7W
22
半同步/半异步的Tcp Server--LightningServer

这是一个半同步/半异步的Tcp Server. 支持以下特性: 1.使用了libevent库,支持大并发网络请求; 2.网络操作与数据处理分离; 3.使用线程池进行数据处理; 4.目前支持tcp数据流的解包操作: 4....

扫帚的影子
2012/12/24
2.8K
0
权限控制框架--authorityFilter

基于java 过滤器(Filter)实现对权限控制的框架。 依赖jar:log4j.jar,fastjson.jar 软件由三部分组成: 权限过滤器AuthorityFilter # 负责过滤url并执行权限检查器中的权限验证方法(check...

寻觅一只耳朵
2013/05/05
3.1K
0

没有更多内容

加载失败,请刷新页面

加载更多

如何使用jQuery获取元素的ID? - How can I get the ID of an element using jQuery?

问题: <div id="test"></div><script> $(document).ready(function() { alert($('#test').id); }); </script> Why doesn't the above work, and how should I do this? 为什么上......

技术盛宴
31分钟前
11
0
为什么在允许某些Unicode字符的注释中执行Java代码?

问题: The following code produces the output "Hello World!" 以下代码生成输出“Hello World!” (no really, try it). (不,真的,试试吧)。 public static void main(String... args......

富含淀粉
今天
8
0
字符串格式:%与.format - String formatting: % vs. .format

问题: Python 2.6 introduced the str.format() method with a slightly different syntax from the existing % operator. Python 2.6引入了str.format()方法,其语法与现有的%运算符略有不......

javail
今天
22
0
什么是按位移位(位移)运算符以及它们如何工作? - What are bitwise shift (bit-shift) operators and how do they work?

问题: I've been attempting to learn C in my spare time, and other languages (C#, Java, etc.) have the same concept (and often the same operators) ... 我一直在尝试在业余时间学习......

法国红酒甜
今天
32
0
OSChina 周二乱弹 —— 卧槽 李荣浩的契约兽啊

Osc乱弹歌单(2020)请戳(这里) 【今日歌曲】 @薛定谔的兄弟 :分享洛神有语创建的歌单「我喜欢的音乐」: 《红色的回忆》- 痛仰乐队 手机党少年们想听歌,请使劲儿戳(这里) 动弹, 又好多...

小小编辑
今天
61
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部