tensorflow2+ 版本梯度带断流问题

大致思路如下:
模型A 预测输入x的标签y_pred
模型B 根据输入的y_pred和真实标签y输出一个数值loss_A作为模型A 的损失
根据loss_A计算模型A的梯度,并更新模型A
更新的A重新预测x的标签,为y_pred_new
此时计算y和y_pred_new的交叉熵损失loss,更新模型B,但是在loss对模型B的求梯度时,梯度全为none

请问该如何修改下面代码才能实现这个功能?

代码如下


ModelA(x)  #  param: theta
# 网络# 
return  y_pred


ModelB(y, y_pred)  # param: beta
# 网络 # 
return z  # 该代码中z作为modelA的loss更新A


with tf.GradientTape() as tape_1:
    with tf.GrandientTape() as tape:
        y_pred = ModelA(x)
        loss_A = ModelB(y, y_pred)
    grads = tape.gradient(loss_A, ModelA.trainable_variables)
    optimizer.apply_gradients(zip(grads, ModelA.trainable_variables)
    y_pred_new = ModelA(x)
    loss = categorical_crossentropy(y, y_pred_new)
grads_1 = tape_1.gradient(loss, ModelB.trainable_variables)  #  !!!!此处出现问题,梯度全为none
optimizer.apply_gradients(zip(grads_1, ModelB.trainable_variables)

公式过程大致如下

img

不愿意花时间看你的代码,仅看问题描述比较简单,直接定义AB模型,手动更新和控制梯度即可。