有关torch.load()的问题

问题遇到的现象和发生背景

我在训练数据后保存了模型,同样我想利用此模型进行测试。使用torch.load()发生错误。

问题相关代码,请勿粘贴截图

这是保存模型的代码:

torch.save(net, "net_{}.pth".format(epoch_index))

这是1读取模型的代码(NLEDN是模型的网络结构):

net = NLEDN() 

net.load_state_dict(torch.load('net_9.pth'))
运行结果及报错内容

出现错误:

img

希望给予指导,非常感谢!

torch.save(net.state_dict(), "net_{}.pth".format(epoch_index))
这样保存模型