一个网络的输出作为另一个网络的输入,最后根据损失更新参数,提示我loss.backward()错误

一个网络的输出作为另一个网络的输入,最后根据损失更新参数,提示我loss.backward()错误

img

我试着改成loss.backward(retain_graph=True),虽然两个模型都能更新参数了,但数据量一大,运行就很慢。
有什么办法改进吗?

双线程训练可以不