RuntimeError: output with shape [] doesn't match the broadcast shape [3, 189, 275]请各位看看。

RuntimeError: output with shape [] doesn't match the broadcast shape [3, 189, 275]

TestDataset
class TestDataset(data.Dataset):
print("调用__init__")
def init(self, dirname, scale):
super(TestDataset, self).init()

    self.name  = dirname.split("/")[-1]
    self.scale = scale
    
    if "DIV" in self.name:
        self.hr = glob.glob(os.path.join("{}_HR".format(dirname), "*.png"))
        print("self.hr", self.hr)
        self.lr = glob.glob(os.path.join("{}_LR_bicubic".format(dirname), 
                                         "X{}/*.png".format(scale)))
        print("self.lr", self.lr)
    else:
        all_files = glob.glob(os.path.join(dirname, "x{}/*.png".format(scale)))
        self.hr = [name for name in all_files if "HR" in name]
        self.lr = [name for name in all_files if "LR" in name]

    self.hr.sort()
    self.lr.sort()

    self.transform = transforms.Compose([
        transforms.ToTensor()  # 归一化到(0,1),简单直接除以255  range [0, 255] -> [0.0,1.0]
    ])
    print("__init__结束")
print("调用getitem")
def __getitem__(self, index):
    hr = Image.open(self.hr[index])
    print("hr",hr)
    lr = Image.open(self.lr[index])
    print("lr",lr)
    hr = hr.convert("RGB")
    lr = lr.convert("RGB")
    print("hr,lr",hr,lr)
    filename = self.hr[index].split("/")[-1]
    print("filename", filename)
    print("geititem返回的参数hr", self.transform(hr))
    print("geititem返回的参数lr", self.transform(lr))
    print("geititem返回的参数", filename)
    return self.transform(hr), self.transform(lr), filename
print("调用__len__")
def __len__(self):
    return len(self.hr)
print("__len__结束")

sample*
def sample(net, device, dataset, cfg):
scale = cfg.scale
print("scale",scale)
for step, (hr, lr, name) in enumerate(dataset):
print("hr,lr,name",hr,lr,name)
if "DIV2K" in dataset.name:
t1 = time.time()
print("t1",t1)

        h, w = lr.size()[1:]
        print("h,w",h,w)

        h_half, w_half = int(h/2), int(w/2)
        print("h_half,w_half",h_half,w_half)

        h_chop, w_chop = h_half + cfg.shave, w_half + cfg.shave
        print("h_chop,w_chop",h_chop,w_chop)

        lr_patch = torch.tensor((4, 3, h_chop, w_chop), dtype=torch.float)
        print("lr_patch",lr_patch)
        lr_patch[0].copy_(lr[:, 0:h_chop, 0:w_chop])
        lr_patch[1].copy_(lr[:, 0:h_chop, w-w_chop:w])
        lr_patch[2].copy_(lr[:, h-h_chop:h, 0:w_chop])
        lr_patch[3].copy_(lr[:, h-h_chop:h, w-w_chop:w])
        lr_patch = lr_patch.to(device)
        
        sr = net(lr_patch, cfg.scale).detach()
        
        h, h_half, h_chop = h*scale, h_half*scale, h_chop*scale
        w, w_half, w_chop = w*scale, w_half*scale, w_chop*scale

        result = torch.tensor((3, h, w), dtype=torch.float).to(device)
        result[:, 0:h_half, 0:w_half].copy_(sr[0, :, 0:h_half, 0:w_half])
        result[:, 0:h_half, w_half:w].copy_(sr[1, :, 0:h_half, w_chop-w+w_half:w_chop])
        result[:, h_half:h, 0:w_half].copy_(sr[2, :, h_chop-h+h_half:h_chop, 0:w_half])
        result[:, h_half:h, w_half:w].copy_(sr[3, :, h_chop-h+h_half:h_chop, w_chop-w+w_half:w_chop])
        sr = result
        t2 = time.time()
    else:
        t1 = time.time()
        lr = lr.unsqueeze(0).to(device)
        sr = net(lr, cfg.scale).detach().squeeze(0)
        lr = lr.squeeze(0)
        t2 = time.time()
    
    model_name = cfg.ckpt_path.split(".")[0].split("/")[-1]
    sr_dir = os.path.join(cfg.sample_dir,
                          model_name, 
                          cfg.test_data_dir.split("/")[-1],
                          "x{}".format(cfg.scale),
                          "SR")
    hr_dir = os.path.join(cfg.sample_dir,
                          model_name, 
                          cfg.test_data_dir.split("/")[-1],
                          "x{}".format(cfg.scale),
                          "HR")
    
    os.makedirs(sr_dir, exist_ok=True)
    os.makedirs(hr_dir, exist_ok=True)

    sr_im_path = os.path.join(sr_dir, "{}".format(name.replace("HR", "SR")))
    hr_im_path = os.path.join(hr_dir, "{}".format(name))

    save_image(sr, sr_im_path)
    save_image(hr, hr_im_path)
    print("Saved {} ({}x{} -> {}x{}, {:.3f}s)"
        .format(sr_im_path, lr.shape[1], lr.shape[2], sr.shape[1], sr.shape[2], t2-t1))

main
def main(cfg):
module = importlib.import_module("model.{}".format(cfg.model))
net = module.Net(multi_scale=True,
group=cfg.group)
print(json.dumps(vars(cfg), indent=4, sort_keys=True))

state_dict = torch.load(cfg.ckpt_path)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k
    # name = k[7:] # remove "module."
    new_state_dict[name] = v

net.load_state_dict(new_state_dict)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)

dataset = TestDataset(cfg.test_data_dir, cfg.scale)
print("dataset",dataset)
sample(net, device, dataset, cfg)

** 错误**

img

建议你看下这篇博客RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

https://blog.csdn.net/qq_41804812/article/details/124163950
这篇文章看看有没有帮助