tensorflow2.x 梯度带求导断流问题

大致思路如下:
模型A 预测输入 x 的标签 y_pred
模型B 根据输入的 y_pred 和真实标签y输出一个数值 z 作为模型A 的损失
根据z计算模型A的梯度,并更新模型A
更新的A重新预测 x 的标签记为 y_pred_new
此时计算 y_pred 和 y_pred_new 的交叉熵损失loss,更新模型B,但是在loss对模型B的求梯度时,梯度全为none

代码如下

ModelA(x)  #  param: θ , 预测输入 x 的标签 y_pred
# 网络# 
return  y_pred
 
ModelB(y, y_pred)  # param: beta,根据输入的 y_pred 和真实标签y输出一个数值 z 作为模型A 的损失
# 网络 # 
return z  
 
with tf.GradientTape() as tape_1:
    with tf.GrandientTape() as tape:
        y_pred = ModelA(x)                                   
        z = ModelB(y, y_pred)                                
    grads = tape.gradient(z, ModelA.trainable_variables)
    optimizer.apply_gradients(zip(grads, ModelA.trainable_variables)   # 根据z计算模型A的梯度,并更新模型A 
    y_pred_new = ModelA(x)                                  # 更新的A重新预测 x 的标签记为y_pred_new
    loss = categorical_crossentropy(y, y_pred_new)        # 计算 y_pred 和 y_pred_new 的交叉熵损失loss,用以更新模型B
grads_1 = tape_1.gradient(loss, ModelB.trainable_variables)   #  !!!!此处出现问题,梯度全为none
optimizer.apply_gradients(zip(grads_1, ModelB.trainable_variables)

公式过程如下图
推测问题在于③位置上的对θ更新时求导,因为grads=tape.gradient()求出来是tensor,相当于βx变成了tensor,不是variable了,导致⑤位置求导的时候无法对β求导,我该如何解决这个问题?

img