pytorch重写Dataset类,用于读取csv数据

问题遇到的现象和发生背景

pytorch重写Dataset类,读取的csv数据类型为str,怎样转换为(1,48,48)的矩阵?

问题相关代码,请勿粘贴截图
class MyDataset(data.Dataset):
    def __init__(self, root,transforms=None):
        super(MyDataset, self).__init__()
        self.root = root
        self.transforms = transforms
        df_label = pd.read_csv(root, header=None, usecols=[0])
        df_path = pd.read_csv(root, header=None, usecols=[1])
        self.label = np.array(df_label)[1:, 0]
        self.path = np.array(df_path)[1:, 0]

    def __getitem__(self, item):
        img=self.path[item]
        target=self.label[item]
        print(type(img), type(img[0]),img)
        #img=img.reshape(48,48)
        #img = Image.fromarray(img.numpy(), mode='L')
        if self.transforms is not None:
            img = self.transforms(img)
        return img, target

    def __len__(self):
        return self.path.shape[0]

运行结果及报错内容

<class 'str'> <class 'str'> 251 251 251 253 246 217 186 172 162 139 144 113 92 164 209 225 232 234 237 239 237 234 231 233 233 230 228 225 212 203 182 164 148 136 119 108 110 116 129 151 149 129 103 109 99 88 93 87 251 251 251 253 223 193 166 161 136 141 123 80 150 200 219 228 231 236 238 236 237

我的解答思路和尝试过的方法
我想要达到的结果




根据数据类型进行转换一下试试,类似这样:

import numpy as np
s='251 251 251 253 246 217 186 172 162 139 144 113 92 164 209 225 232 234 237 239 237 234 231 233 233 230 228 225 212 203 182 164 148 136 119 108 110 116 129 151 149 129 103 109 99 88 93 87 251 251 251 253 223 193 166 161 136 141 123 80 150 200 219 228 231 236 238 236 237'
a=np.array(list(map(int,s.split()))).reshape(1,3,23)
print(a)

您好,我是有问必答小助手,您的问题已经有小伙伴帮您解答,感谢您对有问必答的支持与关注!
PS:问答VIP年卡 【限时加赠:IT技术图书免费领】,了解详情>>> https://vip.csdn.net/askvip?utm_source=1146287632