PyTorch介绍|DATSETS&DATALOADERS

博客 动态
0 263
羽尘
羽尘 2022-01-28 13:54:37
悬赏:0 积分 收藏

PyTorch 介绍 | DATSETS & DATALOADERS

用于处理数据样本的代码可能会变得凌乱且难以维护;理想情况下,我们希望数据集代码和模型训练代码解耦(分离),以获得更好的可读性和模块性。PyTorch提供了两个data primitives:torch.utils.data.DataLoadertorch.utils.data.Dataset,允许你使用预加载的datasets和你自己的data。Dataset 存储样本及其对应的标签,DataLoaderDataset 包装了一个迭代器,以便访问样本。

PyTorch库提供了一些预加载的数据集(如FashionMNIST),它们是 torch.utils.data.Dataset 的子类,特定的数据对应特定的实现函数。它们可以用来原型化和基准化你的模型。你可以在这里查看它们:Image Datasets, Text Datasets, and Audio Datasets。

加载数据集

这是一个怎样从TorchVision加载Fashion-MNIST数据集的例子。Fashion-MNIST来自于Zalando的文章,由60000张训练样本和10000张测试样本组成。每一个样本包含一个28x28
的灰度图片和对应的10类中的1个类的标签。

我们用以下参数加载FashionMNIST Dataset

  • root 是训练/测试数据的保存路径
  • train 指定是训练集还是测试集
  • download=True 如果 root 中没有,则从网上下载
  • transformtarget_transform 指定样本的变换
import torchfrom torch.utils.data import Datasetfrom torchvision import datasetsfrom torchvision.transforms import ToTensorimport matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(    root='data',    train=True,    download=True,    transform=ToTensor())test_data = datasets.FashionMNIST(    root='data',    train=False,    download=True    transform=ToTensor())

输出:

点击查看代码
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gzDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gzExtracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/rawDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gzDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gzExtracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/rawDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gzDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gzExtracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/rawDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gzDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gzExtracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

迭代和数据集可视化

我们可以像list一样索引Datasets:training_data[index]。使用 matplotlib 可视化一些训练集的样本。

labels_map = {    0: "T-Shirt",    1: "Trouser",    2: "Pullover",    3: "Dress",    4: "Coat",    5: "Sandal",    6: "Shirt",    7: "Sneaker",    8: "Bag",    9: "Ankle Boot",}figure = plt.figure(figsize=(8, 8))cols, rows = 3, 3for i in range(1, cols * rows + 1):    sample_idx = torch.randint(len(training_data), size=(1,)).item()    img, label = training_data[sample_idx]    figure.add_subplot(rows, cols, i)    plt.title(labels_map[label])    plt.axis("off")    # torch.squeeze():删除维数为1的维度    plt.imshow(img.squeeze(), cmap="gray")plt.show()

创建自定义数据集

一个自定义的数据集类必须实现三个函数:init,len,getitem。查看下面的实现过程,FashionMNIST图片保存在 img_dir,它们的标签分别保存在一个CSV文件(逗号分隔值文件) annotations_file 中。

下一节,我们将分解每个函数做了什么的。

import osimport pandas as pdfrom torchvision.io import read_imageclass CustomImageDataset(Dataset):    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):        # 利用pandas读取csv并转换为DataFrame        self.img_labels = pd.read_csv(annotations_file)        self.img_dir = img_dir        self.transform = transform        self.target_transform = target_transform        def __len__(self):        return len(self.img_labels)    def __getitem__(self, idx):        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])        image = read_image(img_path)        label = self.img_labels.iloc[idx, 1]        if self.transform:            image = self.transform(image)        if self.target_transform:            label = self.target_transform(label)        return image, label

init

一旦实例化Datase对象,函数__init__ 就会立即运行:初始化包含图片的目录,标签文件,以及两个转换(下一节有更详细的介绍)

labels.csv类似这样:

tshirt1.jpg, 0tshirt2.jpg, 0...anleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):    # 这里指定了列名    self.img_labels = pd.read_csv(annotations_file, names=['file_name', 'labels'])    self.img_dir = img_dir    self.transform = transform    self.target_transform = target_transform

len

__len__ 函数返回数据集的样本数

例如:

def __len__(self):    return len(self.img_labels)

getitem

__getitem__函数加载和返回数据集中给定索引 idx 的样本。根据索引,它获得了硬盘上图片的位置,利用 read_image 转换为tensor,在 self.img_labels ,从csv中检索相应的标签,并调用转换函数(如果可用),返回一个包含图片和对应标签张量的元组。

def __getitem__(self, idx):    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])    image = read_image(img_path)    label = self.img_labels.iloc[idx, 1]    if self.transform:        image = self.transform(image)    if self.target_transform:        label = self.target_transform(label)    return image, label

利用DataLoader为训练准备你的数据

Dataset只能同时检索一个样本的数据特征和标签。当训练模型时,通常需要传递“minibatches”样本,每一个epoch重复打乱数据减少过拟合,并使用Python的 multiprocessing 加速数据检索。

DataLoader 是一个迭代器。

from torch.utils.data import DataLoadertrain_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

通过DataLoader迭代

我们已经将该数据集加载到 DataLoader,根据需要可以对数据集进行迭代。每次迭代返回一个 train_featurestrain_labels 的batch(分别包含 batch_size=64的特征和标签)。因为我们指定了 shuffle=True, 在我们迭代完所有的batch之后,数据就会被打乱(为了对数据加载顺序进行更细致的控制,参阅Samplers)

# Display image and label.train_features, train_labels = next(iter(train_dataloader))print(f"Feature batch shape: {train_features.size()}")print(f"Labels batch shape: {train_labels.size()}")img = train_features[0].squeeze()label = train_labels[0]plt.imshow(img, cmap="gray")plt.show()print(f"Label: {label}")

输出:

Feature batch shape: torch.Size([64, 1, 28, 28])Labels batch shape: torch.Size([64])Label: 7

延伸阅读

  • torch.utils.data API
posted @ 2022-01-28 12:52 Deep_RS 阅读(6) 评论(0) 编辑 收藏 举报
回帖
    羽尘

    羽尘 (王者 段位)

    2335 积分 (2)粉丝 (11)源码

     

    温馨提示

    亦奇源码

    最新会员