数据集和数据加载器

PyTorch提供了两个数据原语:torch.utils.data.DataLoadertorch.utils.data.Dataset ,让您使用预加载数据集,以及您自己的数据

加载数据集

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.FashionMNIST(
    root="data", # 是存储训练/测试数据的路径
    train=True, # 指定训练或测试数据集
    download=True, # 如果数据不可用,则从 Internet 下载数据root
    transform=ToTensor() # transform并target_transform指定特征和标签转换
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

最后更新于