gitbook/PyTorch深度学习实战/docs/429048.md
2022-09-03 22:05:03 +08:00

15 KiB
Raw Permalink Blame History

06 | Torchvision数据读取训练开始的第一步

你好,我是方远。

今天起我们进入模型训练篇的学习。如果将模型看作一辆汽车,那么它的开发过程就可以看作是一套完整的生产流程,环环相扣、缺一不可。这些环节包括数据的读取、网络的设计、优化方法与损失函数的选择以及一些辅助的工具等。未来你将尝试构建自己的豪华汽车,或者站在巨人的肩膀上对前人的作品进行优化。

试想一下如果你对这些基础环节所使用的方法都不清楚你还能很好地进行下去吗所以通过这个模块我们的目标是先把基础打好。通过这模块的学习对于PyTorch都为我们提供了哪些丰富的API你就会了然于胸了。

Torchvision 是一个和 PyTorch 配合使用的 Python 包包含很多图像处理的工具。我们先从数据处理入手开始PyTorch的学习的第一步。这节课我们会先介绍Torchvision的常用数据集及其读取方法在后面的两节课里我再带你了解常用的图像处理方法与Torchvision其它有趣的功能。

PyTorch中的数据读取

训练开始的第一步首先就是数据读取。PyTorch为我们提供了一种十分方便的数据读取机制即使用Dataset类与DataLoader类的组合来得到数据迭代器。在训练或预测时数据迭代器能够输出每一批次所需的数据并且对数据进行相应的预处理与数据增强操作。

下面我们分别来看下Dataset类与DataLoader类。

Dataset类

PyTorch中的Dataset类是一个抽象类它可以用来表示数据集。我们通过继承Dataset类来自定义数据集的格式、大小和其它属性后面就可以供DataLoader类直接使用。

其实这就表示无论使用自定义的数据集还是官方为我们封装好的数据集其本质都是继承了Dataset类。而在继承Dataset类时至少需要重写以下几个方法

  • __init__():构造函数,可自定义数据读取方法以及进行数据预处理;
  • __len__():返回数据集大小;
  • __getitem__():索引数据集中的某一个数据。

光看原理不容易理解下面我们来编写一个简单的例子看下如何使用Dataset类定义一个Tensor类型的数据集。

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    # 构造函数
    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
    # 返回数据集大小
    def __len__(self):
        return self.data_tensor.size(0)
    # 返回索引的数据与标签
    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

结合代码可以看到我们定义了一个名字为MyDataset的数据集在构造函数中传入Tensor类型的数据与标签在__len__函数中直接返回Tensor的大小在__getitem__函数中返回索引的数据与标签。

下面我们来看一下如何调用刚才定义的数据集。首先随机生成一个10*3维的数据Tensor然后生成10维的标签Tensor与数据Tensor相对应。利用这两个Tensor生成一个MyDataset的对象。查看数据集的大小可以直接用len()函数,索引调用数据可以直接使用下标。

# 生成数据
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,)) # 标签是0或1

# 将数据封装成Dataset
my_dataset = MyDataset(data_tensor, target_tensor)

# 查看数据集大小
print('Dataset size:', len(my_dataset))
'''
输出:
Dataset size: 10
'''

# 使用索引调用数据
print('tensor_data[0]: ', my_dataset[0])
'''
输出:
tensor_data[0]:  (tensor([ 0.4931, -0.0697,  0.4171]), tensor(0))
'''

DataLoader类

在实际项目中如果数据量很大考虑到内存有限、I/O速度等问题在训练过程中不可能一次性的将所有数据全部加载到内存中也不能只用一个进程去加载所以就需要多进程、迭代加载而DataLoader就是基于这些需要被设计出来的。

DataLoader是一个迭代器最基本的使用方法就是传入一个Dataset对象它会根据参数 batch_size的值生成一个batch的数据节省内存的同时它还可以实现多进程、数据打乱等处理。

DataLoader类的调用方式如下

from torch.utils.data import DataLoader
tensor_dataloader = DataLoader(dataset=my_dataset, # 传入的数据集, 必须参数
                               batch_size=2,       # 输出的batch大小
                               shuffle=True,       # 数据是否打乱
                               num_workers=0)      # 进程数, 0表示只有主进程

# 以循环形式输出
for data, target in tensor_dataloader: 
    print(data, target)
'''
输出:
tensor([[-0.1781, -1.1019, -0.1507],
        [-0.6170,  0.2366,  0.1006]]) tensor([0, 0])
tensor([[ 0.9451, -0.4923, -1.8178],
        [-0.4046, -0.5436, -1.7911]]) tensor([0, 0])
tensor([[-0.4561, -1.2480, -0.3051],
        [-0.9738,  0.9465,  0.4812]]) tensor([1, 0])
tensor([[ 0.0260,  1.5276,  0.1687],
        [ 1.3692, -0.0170, -1.6831]]) tensor([1, 0])
tensor([[ 0.0515, -0.8892, -0.1699],
        [ 0.4931, -0.0697,  0.4171]]) tensor([1, 0])
'''
 
# 输出一个batch
print('One batch tensor data: ', iter(tensor_dataloader).next())
'''
输出:
One batch tensor data:  [tensor([[ 0.9451, -0.4923, -1.8178],
        [-0.4046, -0.5436, -1.7911]]), tensor([0, 0])]
'''

结合代码我们梳理一下DataLoader中的几个参数它们分别表示

  • datasetDataset类型输入的数据集必须参数
  • batch_sizeint类型每个batch有多少个样本
  • shufflebool类型在每个epoch开始的时候是否对数据进行重新打乱
  • num_workersint类型加载数据的进程数0意味着所有的数据都会被加载进主进程默认为0。

什么是Torchvision

PyTroch官方为我们提供了一些常用的图片数据集如果你需要读取这些数据集那么无需自己实现只需要利用Torchvision就可以搞定。

Torchvision 是一个和 PyTorch 配合使用的 Python 包。它不只提供了一些常用数据集还提供了几个已经搭建好的经典网络模型以及集成了一些图像数据处理方面的工具主要供数据预处理阶段使用。简单地说Torchvision 库就是常用数据集+常见网络模型+常用图像处理方法

Torchvision的安装方式同样非常简单可以使用conda安装命令如下

conda install torchvision -c pytorch

或使用pip进行安装命令如下

pip install torchvision

Torchvision中默认使用的图像加载器是PIL因此为了确保Torchvision正常运行我们还需要安装一个Python的第三方图像处理库——Pillow库。Pillow提供了广泛的文件格式支持强大的图像处理能力主要包括图像储存、图像显示、格式转换以及基本的图像处理操作等。

使用conda安装Pillow的命令如下

conda install pillow

使用pip安装Pillow的命令如下

pip install pillow

利用Torchvision读取数据

安装好Torchvision之后我们再来接着看看。Torchvision库为我们读取数据提供了哪些支持。

Torchvision库中的torchvision.datasets包中提供了丰富的图像数据集的接口。常用的图像数据集例如MNIST、COCO等这个模块都为我们做了相应的封装。

下表中列出了torchvision.datasets包所有支持的数据集。各个数据集的说明与接口,详见链接https://pytorch.org/vision/stable/datasets.html

图片

这里我想提醒你注意,torchvision.datasets这个包本身并不包含数据集的文件本身,它的工作方式是先从网络上把数据集下载到用户指定目录,然后再用它的加载器把数据集加载到内存中。最后,把这个加载后的数据集作为对象返回给用户。

为了让你进一步加深对知识的理解我们以MNIST数据集为例来说明一下这个模块具体的使用方法。

MNIST数据集简介

MNIST数据集是一个著名的手写数字数据集因为上手简单在深度学习领域手写数字识别是一个很经典的学习入门样例。

MNIST数据集是NIST数据集的一个子集MNIST 数据集你可以通过这里下载。它包含了四个部分,我用表格的方式为你做了梳理。

图片

MNIST数据集是ubyte格式存储我们先将“训练集图片”解析成图片格式来直观地看一看数据集具体是什么样子的。具体怎么解析我在后面数据预览再展开。

图片

数据读取

接下来我们看一下如何使用Torchvision来读取MNIST数据集。

对于torchvision.datasets所支持的所有数据集它都内置了相应的数据集接口。例如刚才介绍的MNIST数据集torchvision.datasets就有一个MNIST的接口接口内封装了从下载、解压缩、读取数据、解析数据等全部过程。

这些接口的工作方式差不多,都是先从网络上把数据集下载到指定目录,然后再用加载器把数据集加载到内存中,最后将加载后的数据集作为对象返回给用户。

以MNIST为例我们可以用如下方式调用

# 以MNIST为例
import torchvision
mnist_dataset = torchvision.datasets.MNIST(root='./data',
                                       train=True,
                                       transform=None,
                                       target_transform=None,
                                       download=True)

torchvision.datasets.MNIST是一个类对它进行实例化即可返回一个MNIST数据集对象。构造函数包括包含5个参数

  • root是一个字符串用于指定你想要保存MNIST数据集的位置。如果download是Flase则会从目标位置读取数据集
  • download是布尔类型表示是否下载数据集。如果为True则会自动从网上下载这个数据集存储到root指定的位置。如果指定位置已经存在数据集文件则不会重复下载
  • train是布尔类型表示是否加载训练集数据。如果为True则只加载训练数据。如果为False则只加载测试数据集。这里需要注意,并不是所有的数据集都做了训练集和测试集的划分,这个参数并不一定是有效参数,具体需要参考官方接口说明文档
  • transform用于对图像进行预处理操作例如数据增强、归一化、旋转或缩放等。这些操作我们会在下节课展开讲解
  • target_transform用于对图像标签进行预处理操作。

运行上述的代码我们可以得到下图所示的效果。从图中我们可以看出程序首先去指定的网址下载了MNIST数据集然后进行了解压缩等操作。如果你再次运行相同的代码则不会再有下载的过程。

图片

看到这你可能还有疑问好奇我们得到的mnist_dataset是什么呢

如果你用type函数查看一下mnist_dataset的类型就可以得到torchvision.datasets.mnist.MNIST 而这个类是之前我们介绍过的Dataset类的派生类。相当于torchvision.datasets 它已经帮我们写好了对Dataset类的继承完成了对数据集的封装我们直接使用即可。

这里我们主要以MNIST为例进行了说明。其它的数据集使用方法类似调用的时候你只要需要将类名“MNIST”换成其它数据集名字即可。

对于不同的数据集,数据格式都不尽相同,而torchvision.datasets则帮助我们完成了各种不同格式的数据的解析与读取,可以说十分便捷。而对于那些没有官方接口的图像数据集,我们也可以使用以torchvision.datasets.ImageFolder接口来自行定义在图像分类的实战篇中就是使用ImageFolder进行数据读取的你可以到那个时候再看一看。

数据预览

完成了数据读取工作我们得到的是对应的mnist_dataset刚才已经讲过了这是一个封装了的数据集。

如果想要查看mnist_dataset中的具体内容我们需要把它转化为列表。如果IOPub data rate超限可以只加载测试集数据令train=False

mnist_dataset_list = list(mnist_dataset)
print(mnist_dataset_list)

执行结果如下图所示。

图片

从运行结果中可以看出,转换后的数据集对象变成了一个元组列表,每个元组有两个元素,第一个元素是图像数据,第二个元素是图像的标签。

这里图像数据是PIL.Image.Image类型的这种类型可以直接在Jupyter中显示出来。显示一条数据的代码如下。

display(mnist_dataset_list[0][0])
print("Image label is:", mnist_dataset_list[0][1])

运行结果如下图所示。可以看出数据集mnist_dataset中的第一条数据是图片手写数字“7”对应的标签是“7”。

图片

好,如果你也得到了上面的运行结果,说明你的操作没问题,恭喜你成功完成了读取操作。

小结

恭喜你完成了这节课的学习。我们已经迈出了模型训练的第一步,学会了如何读取数据。

今天的重点就是掌握两种读取数据的方法,也就是自定义和读取常用图像数据集

最通用的数据读取方法就是自己定义一个Dataset的派生类。而读取常用的图像数据集就可以利用PyTorch提供的视觉包Torchvision。

Torchvision库为我们读取数据提供了丰富的图像数据集的接口。我用手写数字识别这个经典例子给你示范了如何使用Torchvision来读取MNIST数据集。

torchvision.datasets继承了Dataset 类,它在预定义许多常用的数据集的同时,还预留了数据预处理与数据增强的接口。在下一节课中,我们就会接触到这些数据增强函数,并学习如何进行数据增强。

每课一练

在PyTorch中我们要定义一个数据集应该继承哪一个类呢

欢迎你在留言区和我交流互动,也推荐你把这节课内容分享给更多的朋友、同事,跟他一起学习进步。