深度学习训练时，DataLoader 在把样本拼接成 batch 的过程中报出下面这个维度不一致的错误，请问该如何解决？

return torch.stack(batch, 0, out=out)

RuntimeError: stack expects each tensor to be equal size, but got [1, 312, 312] at entry 0 and [1, 512, 512] at entry 1

  • 建议你看下这篇博客👉 :RuntimeError: stack expects each tensor to be equal size, but got [3, ] at entry 0 and [1,]at entry1
  • 除此之外，这篇博客《stack expects each tensor to be equal size, but got [3, 224, 224] at entry 0 and [1,224,224] at entry 1》中的“验证方法”部分也许能够解决你的问题，你可以仔细阅读以下内容或跳转源博客中阅读：
  • path_load = glob.glob(r'D:\BaiduNetdiskDownload\pytorch_learning\dataset\dataset2\*.jpg')
    # Image paths: every .jpg file matched by the glob pattern above.
    # Preprocessing applied to each image: resize to a fixed 224x224 so that
    # all samples share one spatial size, then convert to a tensor.
    transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor()
    ])
    
    # Image classes; a class's position in this list is its integer label.
    species = ['cloudy', 'rain', 'shine', 'sunrise']

    # Image labels: exactly one per path, so all_labels[k] always describes
    # path_load[k]. A path containing several class names takes the first one
    # in `species` order; a path containing none is rejected up front, because
    # skipping it would silently shift every following label onto the wrong
    # image.
    all_labels = []
    for img in path_load:
        label = next((i for i, c in enumerate(species) if c in img), None)
        if label is None:
            raise ValueError(
                'no class name from {} found in image path: {!r}'.format(species, img)
            )
        all_labels.append(label)

    # Lookup tables: class name -> label index, and label index -> class name.
    species_to_idx = {c: i for i, c in enumerate(species)}
    label_to_class = {i: c for c, i in species_to_idx.items()}
    class Mydataset(data.Dataset):
        """Dataset that pairs image file paths with their labels.

        Every item is read from disk, converted to RGB and passed through
        ``transform`` before being returned together with its label.
        """

        def __init__(self,root,labels,transform):
            """Store the image paths (``root``), the labels and the transform."""
            super().__init__()
            self.imgs_path, self.labels, self.transform = root, labels, transform

        def __len__(self):
            """Return the number of image paths held by the dataset."""
            return len(self.imgs_path)

        def __getitem__(self,index):
            """Return ``(transformed_image, label)`` for position ``index``."""
            path, target = self.imgs_path[index], self.labels[index]
            # Force RGB so that every image has the same number of channels.
            rgb_image = Image.open(path).convert('RGB')
            return self.transform(rgb_image), target
    
    
    # Wrap the collected paths and labels in the dataset, then serve it in
    # shuffled batches of 16, dropping a final incomplete batch.
    wheather_dataset = Mydataset(root=path_load, labels=all_labels, transform=transform)
    wheather_dl = data.DataLoader(
        wheather_dataset,
        batch_size=16,
        shuffle=True,
        drop_last=True,
    )
    
    
    
    # Preview one batch as a 4x4 grid (16 cells, matching batch_size=16),
    # titling each image with its class name.
    plt.figure(figsize=(12,8))
    imgs_batch,labels_batch = next(iter(wheather_dl))
    for i,(img,label) in enumerate(zip(imgs_batch,labels_batch)):
        # Move axis 0 to the end before plotting
        # (presumably channels-first -> channels-last; confirm against the transform).
        img = img.permute(1,2,0)
        img = img.numpy()
        plt.subplot(4,4,i+1)
        plt.title(label_to_class.get(label.item()))
        plt.imshow(img)

    result: