import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
import matplotlib.pyplot as plt
import numpy as np
# 展示图像的函数
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
# 获取随机数据
dataiter = iter(trainloader)
images, labels = dataiter.next()
# 展示图像
imshow(torchvision.utils.make_grid(images))
# 显示图像标签
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
AttributeError Traceback (most recent call last)
Input In [3], in <cell line: 14>()
12 # 获取随机数据
13 dataiter = iter(trainloader)
---> 14 images, labels = dataiter.next()
16 # 展示图像
17 imshow(torchvision.utils.make_grid(images))
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute 'next'
_MultiProcessingDataLoaderIter作为DataLoader的iters,应该具有next属性啊
目前torch: 1.13.0+cu117 torchvision : 0.14.0+cu117
next()函数实际上调用了传入函数的.__next()__成员函数。所以,如果传入的函数没有这个成员,则会报错
参考代码理解
import torch
# 生成一些测试数据
X = torch.normal(0, 1, (1000, 2)) # x: sample size = 1000, feature_dim = 2
y = torch.normal(0, 1, (1000, 1)) # y: sample size = 1000, dim = 1
# 定义一个函数,返回dataloader
def load_array(data_and_label, batch_size, is_train=True):
"""Construct a PyTorch data iterator."""
dataset = data.TensorDataset(*data_and_label)
return data.DataLoader(dataset, batch_size, shuffle=is_train)
batch_size = 10
data_iter = load_array((X, y), batch_size)
#这句会报错:next(data_iter)
next(iter(data_iter))
"""
输出前10组数据
"""
这里,为什么 next(data_iter) 报错,而 next(iter(data_iter)) 可以返回数据呢?这是因为,pytorch的DataLoader函数没有 next 成员,但有 iter 成员(见源文件)。所以,需要首先通过 iter() 函数返回一个 iter 成员,再找这个 iter 的 next
如果官方运行没问题的话,应该是torch的问题,你这torch1.13是最新的版本,有些api之类的变动是正常的事情,如果你只是学习,我比较建议使用1.8.2的LST版本,至少所有的库都会考虑优先支持LST的版本。
另外你这个报的错多线程没有next错误,将num_workers设置为0看下,我设置为在1.8.2下面设置成0就可以了
直接copy你的代码直接报错啊?“trainloader”哪声明的?