TensorDatase

img

ImportError: cannot import name 'TensorDataset' from 'torch.utils.data' (unknown location)请问这是什么原因,搜也搜不到

你的torch.utils.data下可能没有TensorDataset,所以你改成这样试下,只导入Dataset,然后自己写个TensorDataset即可:

import torch
from torch.utils.data import Dataset


class TensorDataset(Dataset):
    """
    TensorDataset继承Dataset, 重载了__init__(), __getitem__(), __len__()
    实现将一组Tensor数据对封装成Tensor数据集
    能够通过index得到数据集的数据,能够通过len,得到数据集大小
    """
    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

    def __len__(self):
        return self.data_tensor.size(0)

# 生成数据
data_tensor = torch.randn(4, 3)
target_tensor = torch.rand(4)

# 将数据封装成Dataset
tensor_dataset = TensorDataset(data_tensor, target_tensor)