在一下代码中,dataset有dataset.len,而经过random_split划分为train_dataset和test_dataset后,train_dataset与test_dataset就没有dataset.len了
mport torch
import numpy as np
import matplotlib.pyplot as plt
from torch.nn import Module, Linear, Sigmoid, BCELoss
from torch.utils.data import Dataset, DataLoader, random_split
"""setting parameters"""
batchSize = 32
num_works = 2
epochSize = 100
'''define class Datasets'''
class Diabetes(Dataset): # type:Dataset
def __init__(self, filepath):
super(Diabetes, self).__init__()
data = np.loadtxt(filepath, delimiter=",", encoding="utf-8", dtype=np.float32, skiprows=1)
self.len = data.shape[0] # size of the dataset
self.texts = torch.from_numpy(data[:, :-1])
self.labels = torch.from_numpy(data[:, [-1]])
def __len__(self):
return self.len
def __getitem__(self, item):
return self.texts[item], self.labels[item] # texts, labels
dataset = Diabetes("diabetes.csv")
train_dataset, test_dataset = random_split(dataset, [round(0.8*dataset.len), round(0.2*dataset.len)])
在PyTorch中,random_split
函数返回的是Subset
对象,而不是原始的Dataset
对象。Subset
对象只是原始数据集的一个子集,它并不包含原始数据集的所有属性。因此,你在Subset
对象上找不到len
属性。
但是,你仍然可以通过调用Python内置的len()
函数来获取Subset
对象的长度。例如:
print(len(train_dataset))
print(len(test_dataset))
这是因为Subset
类实现了__len__()
方法,所以你可以使用len()
函数来获取它的长度。
如果你需要在Subset
对象上使用其他的自定义属性或方法,你可能需要自己实现一个数据集划分的函数,而不是使用random_split
。或者,你可以在划分数据集后,将需要的属性或方法添加到Subset
对象上。例如:
train_dataset.len = len(train_dataset)
test_dataset.len = len(test_dataset)
这样,你就可以在train_dataset
和test_dataset
上使用len
属性了。但是请注意,这种方法可能会使代码变得难以理解和维护,因此只有在必要的时候才使用。