目前在用pytorch做图像分类,但是在训练和预测程序运行后,发现预测结果全都是同一个分类结果,我怀疑是在定义数据处理类的时候给图像加标签加的不对,但是我小白一枚,实在看不出哪里有错误,还请大神指点!!!
下面是我定义的数据处理类。我的目标是要把我的训练集分14类,比如文件名有a1的标签为0,文件名有a2的标签为1,等等。
#定义数据处理类
class MyDataset(Dataset):
def __init__(self, file_list, dir, mode='train', transform = None):
self.file_list = file_list
self.dir = dir
self.mode= mode
self.transform = transform
if self.mode == 'train':
if 'a1' in self.file_list[0]:
self.label = 0
if 'a2' in self.file_list[0]:
self.label = 1
if 'a3' in self.file_list[0]:
self.label = 2
if 'a4' in self.file_list[0]:
self.label = 3
if 'b1' in self.file_list[0]:
self.label = 4
if 'b2' in self.file_list[0]:
self.label = 5
if 'b3' in self.file_list[0]:
self.label = 6
if 'b4' in self.file_list[0]:
self.label = 7
if 'c1' in self.file_list[0]:
self.label = 8
if 'c2' in self.file_list[0]:
self.label = 9
if 'c3' in self.file_list[0]:
self.label = 10
if 'c4' in self.file_list[0]:
self.label = 11
if 'c6' in self.file_list[0]:
self.label = 12
else:
self.label = 13
def __len__(self):
return len(self.file_list)
def __getitem__(self, idx):
img = Image.open(os.path.join(self.dir, self.file_list[idx])).convert('L')
if self.transform:
img = self.transform(img)
if self.mode == 'train':
img = img.numpy()
return img.astype('float32'), self.label
else:
img = img.numpy()
return img.astype('float32'), self.file_list[idx]
只有部分代码没有数据,真的很难啊。。