数据集划分后属性丢失

在一下代码中,dataset有dataset.len,而经过random_split划分为train_dataset和test_dataset后,train_dataset与test_dataset就没有dataset.len了


mport torch
import numpy as np
import matplotlib.pyplot as plt
from torch.nn import Module, Linear, Sigmoid, BCELoss
from torch.utils.data import Dataset, DataLoader, random_split

"""setting parameters"""
batchSize = 32
num_works = 2
epochSize = 100

'''define class Datasets'''


class Diabetes(Dataset):  # type:Dataset
    def __init__(self, filepath):
        super(Diabetes, self).__init__()
        data = np.loadtxt(filepath, delimiter=",", encoding="utf-8", dtype=np.float32, skiprows=1)
        self.len = data.shape[0]  # size of the dataset
        self.texts = torch.from_numpy(data[:, :-1])
        self.labels = torch.from_numpy(data[:, [-1]])

    def __len__(self):
        return self.len

    def __getitem__(self, item):
        return self.texts[item], self.labels[item]  # texts, labels


dataset = Diabetes("diabetes.csv")
train_dataset, test_dataset = random_split(dataset, [round(0.8*dataset.len), round(0.2*dataset.len)])

在PyTorch中,random_split函数返回的是Subset对象,而不是原始的Dataset对象。Subset对象只是原始数据集的一个子集,它并不包含原始数据集的所有属性。因此,你在Subset对象上找不到len属性。

但是,你仍然可以通过调用Python内置的len()函数来获取Subset对象的长度。例如:

print(len(train_dataset))
print(len(test_dataset))

这是因为Subset类实现了__len__()方法,所以你可以使用len()函数来获取它的长度。

如果你需要在Subset对象上使用其他的自定义属性或方法,你可能需要自己实现一个数据集划分的函数,而不是使用random_split。或者,你可以在划分数据集后,将需要的属性或方法添加到Subset对象上。例如:

train_dataset.len = len(train_dataset)
test_dataset.len = len(test_dataset)

这样,你就可以在train_datasettest_dataset上使用len属性了。但是请注意,这种方法可能会使代码变得难以理解和维护,因此只有在必要的时候才使用。