pytorch重写Dataset类,读取的csv数据类型为str,怎样转换为(1,48,48)的矩阵?
class MyDataset(data.Dataset):
def __init__(self, root,transforms=None):
super(MyDataset, self).__init__()
self.root = root
self.transforms = transforms
df_label = pd.read_csv(root, header=None, usecols=[0])
df_path = pd.read_csv(root, header=None, usecols=[1])
self.label = np.array(df_label)[1:, 0]
self.path = np.array(df_path)[1:, 0]
def __getitem__(self, item):
img=self.path[item]
target=self.label[item]
print(type(img), type(img[0]),img)
#img=img.reshape(48,48)
#img = Image.fromarray(img.numpy(), mode='L')
if self.transforms is not None:
img = self.transforms(img)
return img, target
def __len__(self):
return self.path.shape[0]
<class 'str'> <class 'str'> 251 251 251 253 246 217 186 172 162 139 144 113 92 164 209 225 232 234 237 239 237 234 231 233 233 230 228 225 212 203 182 164 148 136 119 108 110 116 129 151 149 129 103 109 99 88 93 87 251 251 251 253 223 193 166 161 136 141 123 80 150 200 219 228 231 236 238 236 237
根据数据类型进行转换一下试试,类似这样:
import numpy as np
s='251 251 251 253 246 217 186 172 162 139 144 113 92 164 209 225 232 234 237 239 237 234 231 233 233 230 228 225 212 203 182 164 148 136 119 108 110 116 129 151 149 129 103 109 99 88 93 87 251 251 251 253 223 193 166 161 136 141 123 80 150 200 219 228 231 236 238 236 237'
a=np.array(list(map(int,s.split()))).reshape(1,3,23)
print(a)
您好,我是有问必答小助手,您的问题已经有小伙伴帮您解答,感谢您对有问必答的支持与关注!