利用 pytorch 加载 mnist 数据集的代码如下
- import torchvision
- import torchvision.transforms as transforms
- from torch.utils.data import DataLoader
- train_data = torchvision.datasets.MNIST(
- root='./mnist/',
- train=True, # this is training data
- transform=torchvision.transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to
- # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
- download=True,
- )
- test_data = torchvision.datasets.MNIST(
- root='./mnist/',
- train=False, # this is training data
- transform=torchvision.transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to
- # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
- download=True,
- )
- train_data_loader = DataLoader(train_data, shuffle=True, batch_size = 100)
- test_data_loader = DataLoader(test_data,shuffle=True, batch_size=100)
第一次使用 mnist, 需要下载, 具体方法就是设置 download=True, 然而我运行的时候报错了.
错误: not gzip file
可是明明是. gz 文件啊, 查了几篇博客也没有说清楚原因的, 于是自行下载了四个文件 (训练集, 测试集以及各自的标签), 放入./mnist/raw 文件夹下, 运行, 报错: 找不到文件.
此时, 系统需要找的是./mnist/process 文件夹下的 train.pt 和 test.pt 文件, 这应该是 pytorch 下载原文件后处理生成的, 可是我无处下载, 于是用另一台电脑下载, 程序没有报错, 我把生成的. pt 文件拷贝过来, 可以运行了
来源: http://www.bubuko.com/infodetail-2769528.html