以前打印网络模型,都是用summary函数来打印,但是这次改进的网络报错,不过直接用print(model)能打印出来,这是怎么回事
也用过其他博主提出的解决方法,如下
#修改前
summary[m_key]["input_shape"] = list(input[0].size())
#修改后
if isinstance(input[0], torch.Tensor):
summary[m_key]["input_shape"] = list(input[0].size())
elif isinstance(input[0], list):
summary[m_key]["input_shape"] = list(np.array(input[0]).shape)
不过这个针对input出错,而我报错是在output那里,所以一直没有找到解决问题的办法
from TransUnet import *
from torchsummary import summary
model = get_transNet(2)
print(model)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
print(model)
summary(model, (1, 512, 512))
在用print(model)可以正常打印网络
但是用summary却不能正常打印模型结构
可以使用summary打印网络结构
你用的是哪个函数
训segformer的时候 x,H, W = self.my_patch_embed1(x)这行也报错了,在summary里print一下 如果是int就pass,HW是Int没法在summary的时候列出来
请问解决了吗 我的也是在output的时候出现问题 [-1] + list(o.size())[1:] for o in output
summary[m_key]['output_shape'] = [[-1] + list(np.array(o,dtype=object).shape if isinstance(o,list) else o.size())[1:] for o in output]
可以这样修改试试