深度学习测试时,加载模型出问题。

代码:

 concat_mask = True if 'MST_shanghaitech' in args.PATH else False
    model = MST(config, concat_mask)
    model.load()
    model.inference(args.image_path, args.mask_path, config.valid_th, config.mask_th,
                    not_obj_remove=args.not_obj_remove)

报错结果为:

Traceback (most recent call last):
  File "test_single.py", line 52, in <module>
    model.load()
  File "E:\code\MST_inpainting-main\src\MST_model.py", line 102, in load
    self.inpaint_decoder.generator.load_state_dict(
  File "D:\Anaconda3\envs\torch18\lib\site-packages\torch\nn\modules\module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for InpaintGateGenerator:
        size mismatch for encoder.1.gate_conv.weight: copying a param with shape torch.Size([128, 6, 7, 7]) from checkpoint, the shape in current model is torch.Size([128, 7, 7, 7]).

这个该怎么去改它的参数呢?

ckp和模型的维度数目不匹配,具体的你可以看看这个看下能不能改
https://blog.csdn.net/qq_45128278/article/details/116588153

怎么感觉缩进不对。。 concat_mask 这个参数是后来加的么 原来这个MST模型里面有么,加载这个模型要保持和原来参数一致吧