关于torch中使用summary无法打印网络结构AttributeError: 'list' object has no attribute 'size'

问题遇到的现象和发生背景

以前打印网络模型,都是用summary函数来打印,但是这次改进的网络报错,不过直接用print(model)能打印出来,这是怎么回事
也用过其他博主提出的解决方法,如下

#修改前
summary[m_key]["input_shape"] = list(input[0].size())
#修改后
if isinstance(input[0], torch.Tensor):
   summary[m_key]["input_shape"] = list(input[0].size())
elif isinstance(input[0], list):
    summary[m_key]["input_shape"] = list(np.array(input[0]).shape)


不过这个针对input出错,而我报错是在output那里,所以一直没有找到解决问题的办法

问题相关代码
from TransUnet import *
from torchsummary import summary

model = get_transNet(2)
print(model)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
print(model)
summary(model, (1, 512, 512))

运行结果及报错内容

在用print(model)可以正常打印网络

img

但是用summary却不能正常打印模型结构

img

我想要达到的结果

可以使用summary打印网络结构

你用的是哪个函数

训segformer的时候 x,H, W = self.my_patch_embed1(x)这行也报错了,在summary里print一下 如果是int就pass,HW是Int没法在summary的时候列出来

请问解决了吗 我的也是在output的时候出现问题 [-1] + list(o.size())[1:] for o in output

summary[m_key]['output_shape'] = [[-1] + list(np.array(o,dtype=object).shape if isinstance(o,list) else o.size())[1:] for o in output]
可以这样修改试试