Why are the computed gradients None?

I wrote the following code to train a neural network for MNIST classification:

import tensorflow as tf
print(tf.__version__)


batch_size = 128
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = tf.data.Dataset.from_tensor_slices(
        x_train).batch(batch_size)
y_train = tf.data.Dataset.from_tensor_slices(
        y_train).batch(batch_size)

classifier = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=[28, 28]),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Compile
# model.compile(optimizer='adam',
#               loss='sparse_categorical_crossentropy',
#               metrics=['accuracy'])
# Train the model
# model.fit(x_train, y_train, epochs=5, batch_size=64)

# Rewritten from the code above; the commented-out version above runs without errors
cla_loss = tf.keras.metrics.SparseCategoricalCrossentropy()

learning_rate = 1e-4
cla_opt = tf.keras.optimizers.Adam(learning_rate)

for batch_x, batch_y in zip(x_train, y_train):
    with tf.GradientTape() as tape:
        pred_label = classifier(batch_x)
        loss = cla_loss(batch_y, pred_label)

        print('loss:', loss)

    gradients = tape.gradient(loss, classifier.trainable_variables)
    print('gradients:', gradients)
    cla_opt.apply_gradients(zip(gradients, classifier.trainable_variables))

The error output:

gradients: [None, None, None, None]
Traceback (most recent call last):
  File "/home//文档/CAE/text_a.py", line 43, in 
    cla_opt.apply_gradients(zip(gradients, classifier.trainable_variables))
  File "/home//.local/lib/python3.8/site-packages/keras/optimizer_v2/optimizer_v2.py", line 633, in apply_gradients
    grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
  File "/home//.local/lib/python3.8/site-packages/keras/optimizer_v2/utils.py", line 73, in filter_empty_gradients
    raise ValueError(f"No gradients provided for any variable: {variable}. "
ValueError: No gradients provided for any variable: (['dense/kernel:0', 'dense/bias:0', 'dense_1/kernel:0', 'dense_1/bias:0'],). 

Why are the computed gradients None?

Maybe the gradients exploded? (Exploding gradients would show up as Inf/NaN values, though, not as None.)
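It is not a gradient explosion. The likely cause is that `tf.keras.metrics.SparseCategoricalCrossentropy` is a *metric*, not a loss: calling a metric accumulates running state via assign ops, so the value it returns is disconnected from the `GradientTape`, and `tape.gradient` returns `None` for every variable. A minimal self-contained sketch of the fix, using `tf.keras.losses.SparseCategoricalCrossentropy` and a random stand-in batch instead of the MNIST pipeline:

```python
import tensorflow as tf

# Use the *loss* class, not the metric of the same name. The metric
# aggregates state with assign ops, which breaks the tape connection.
cla_loss = tf.keras.losses.SparseCategoricalCrossentropy()

classifier = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=[28, 28]),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])
cla_opt = tf.keras.optimizers.Adam(1e-4)

# Stand-in batch so the sketch runs on its own; replace with real data.
batch_x = tf.random.uniform([8, 28, 28])
batch_y = tf.random.uniform([8], maxval=10, dtype=tf.int32)

with tf.GradientTape() as tape:
    pred_label = classifier(batch_x, training=True)
    loss = cla_loss(batch_y, pred_label)

gradients = tape.gradient(loss, classifier.trainable_variables)
cla_opt.apply_gradients(zip(gradients, classifier.trainable_variables))
print(all(g is not None for g in gradients))
```

With the loss class, every entry of `gradients` is a real tensor and `apply_gradients` succeeds. (Separately, normalizing the inputs with `x_train / 255.0` is worth doing, but it is unrelated to the `None` gradients.)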