paddle数据封装问题,x的维度是(28224,480,1),y的维度是(1,480,1),定义了如下的类,我用train_dataset=MyDataset(x,y),得到的train_dataset是什么格式的维度的?
class MyDataset(paddle.io.Dataset):#飞浆框架中,构建数据集类,必须集成paddle.io.Dataset父类
def __init__(self, x,y):#类的初始化方法
self.x = x
self.y = y
def __getitem__(self, index):
#获取第index行数据与标签
a=self.x
b=self.y
return a, b#返回第index行的数据与标签
def __len__(self):#获取数据集长度
return len(self.x)#返回数据集长度
train_dataset=MyDataset(x,y)
这样才能返回想要的index数据,记得采纳哈
from paddle.io import Dataset
from paddle.io import DataLoader
import numpy as np
class LoadImg(Dataset):
def __init__(self, x,y):
"""载入图像数据"""
super(LoadImg, self).__init__()
self.x = x
self.y = y
def __getitem__(self, index):
data = self.x[index]
label= self.y[index]
return data,label
def __len__(self):
return self.x.shape[0]
# 测试方法可行性
def test_img():
x = np.random.random((20,3,448,448))
y = np.random.randint(0,10,(20))
train_dataset = LoadImg(x,y)
inputs, label = train_dataset.__getitem__(10)
print(inputs.shape, type(inputs))
print(label, type(label))
loader_train_t = DataLoader(train_dataset, batch_size=3, shuffle=True, drop_last=False)
for (inputs, label) in loader_train_t:
print(np.reshape(label.numpy(), (-1)), inputs.shape)
if __name__ == "__main__":
test_img()
没问题啊