Running the MixMatch source code on the CIFAR10 dataset raises AttributeError: 'CIFAR10' object has no attribute 'targets'. What is going on?

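  1. When I run the MixMatch program and load CIFAR10 with torchvision.datasets, I get the error AttributeError: 'CIFAR10' object has no attribute 'targets'.

Another question: downloading through torchvision was too slow, so I downloaded the dataset myself first and put it in the data directory. Will this affect the results?
I would appreciate any suggestions or advice. Thanks.

The code that loads the dataset is as follows: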

import numpy as np
import torchvision


def get_cifar10(root, n_labeled,
                 transform_train=None, transform_val=None,
                 download=True):

    base_dataset = torchvision.datasets.CIFAR10(root, train=True, download=download)
    train_labeled_idxs, train_unlabeled_idxs, val_idxs = train_val_split(base_dataset.targets, int(n_labeled/10))

    train_labeled_dataset = CIFAR10_labeled(root, train_labeled_idxs, train=True, transform=transform_train)
    train_unlabeled_dataset = CIFAR10_unlabeled(root, train_unlabeled_idxs, train=True, transform=TransformTwice(transform_train))
    val_dataset = CIFAR10_labeled(root, val_idxs, train=True, transform=transform_val, download=True)
    test_dataset = CIFAR10_labeled(root, train=False, transform=transform_val, download=True)

    print(f"#Labeled: {len(train_labeled_idxs)} #Unlabeled: {len(train_unlabeled_idxs)} #Val: {len(val_idxs)}")
    return train_labeled_dataset, train_unlabeled_dataset, val_dataset, test_dataset


def train_val_split(labels, n_labeled_per_class):
    labels = np.array(labels)
    train_labeled_idxs = []
    train_unlabeled_idxs = []
    val_idxs = []

    for i in range(10):
        idxs = np.where(labels == i)[0]
        np.random.shuffle(idxs)
        train_labeled_idxs.extend(idxs[:n_labeled_per_class])
        train_unlabeled_idxs.extend(idxs[n_labeled_per_class:-500])
        val_idxs.extend(idxs[-500:])
    np.random.shuffle(train_labeled_idxs)
    np.random.shuffle(train_unlabeled_idxs)
    np.random.shuffle(val_idxs)

    return train_labeled_idxs, train_unlabeled_idxs, val_idxs

The error message is as follows:
(base) D:\CSStudy\PycharmProject\MixMatch-pytorch-master>python train.py --gpu 0 --n-labeled 250 --out cifar10@250
==> Preparing cifar10
Using downloaded and verified file: ./data\cifar-10-python.tar.gz
Traceback (most recent call last):
  File "train.py", line 431, in <module>
    main()
  File "train.py", line 88, in main
    train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10('./data', args.n_labeled, transform_train=transform_train, transform_val=transform_val)
  File "D:\CSStudy\PycharmProject\MixMatch-pytorch-master\dataset\cifar10.py", line 21, in get_cifar10
    train_labeled_idxs, train_unlabeled_idxs, val_idxs = train_val_split(base_dataset.targets, int(n_labeled/10))
AttributeError: 'CIFAR10' object has no attribute 'targets'

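This is probably a torch version issue: the attribute names on the dataset object differ between versions (the CIFAR10 class itself comes from torchvision, so it is the torchvision version that matters here). I am trying to fix the same problem myself. I recommend checking the English torch documentation.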
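
To illustrate the version difference, here is a minimal sketch that reads the labels under whichever attribute name the installed version provides. The helper name get_labels and the two stand-in classes are made up for this example and are not part of MixMatch or torchvision. The names train_labels and test_labels are, as far as I know, what older torchvision releases used, so treat them as an assumption and check your installed version.

```python
def get_labels(dataset):
    """Return the labels of a dataset, trying the new attribute name first."""
    # "targets" is the newer name; the other two are assumed legacy names.
    for name in ("targets", "train_labels", "test_labels"):
        labels = getattr(dataset, name, None)
        if labels is not None:
            return labels
    raise AttributeError("dataset has no known label attribute")


# Hypothetical stand-ins for a dataset object, so the sketch runs
# without downloading CIFAR10.
class NewStyleDataset:
    def __init__(self):
        self.targets = [0, 1, 2]


class OldStyleDataset:
    def __init__(self):
        self.train_labels = [3, 4, 5]


print(get_labels(NewStyleDataset()))
print(get_labels(OldStyleDataset()))
```

In get_cifar10, the call would then read train_val_split(get_labels(base_dataset), int(n_labeled/10)) instead of using base_dataset.targets directly.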

My problem turned out to be that my torch version was too old. I had to change data to train_data
and targets to train_labels, and then it worked. Hope this helps~
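
If you would rather not edit every place that uses the names, one option is to alias the old fields to the new names once, right after the dataset is created. This is only a sketch: alias_legacy_fields and LegacyDataset are hypothetical names invented here, and the test_data / test_labels names for train=False are my assumption about older versions.

```python
def alias_legacy_fields(dataset, train=True):
    """Expose old-style fields under the .data / .targets names."""
    prefix = "train" if train else "test"
    # Only add an alias when the newer name is missing.
    if not hasattr(dataset, "data"):
        dataset.data = getattr(dataset, prefix + "_data")
    if not hasattr(dataset, "targets"):
        dataset.targets = getattr(dataset, prefix + "_labels")
    return dataset


# Hypothetical stand-in for a dataset built by an older version.
class LegacyDataset:
    def __init__(self):
        self.train_data = [[0, 0], [1, 1]]
        self.train_labels = [7, 9]


ds = alias_legacy_fields(LegacyDataset())
print(ds.data, ds.targets)
```

The CIFAR10_labeled and CIFAR10_unlabeled classes in the repository presumably rely on the same names, so they would need the same treatment. Upgrading torchvision should also make the error go away, since the repository code expects the newer names.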