利用 TensorFlow 识别 MNIST 中的数字 0 和 1 时,遇到如下报错,该如何解决?

报错:ValueError: Shape mismatch: The shape of labels (received (1,)) should equal the shape of logits except for the last dimension (received (28, 2)).
代码如下:


import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow import data

# Train a small dense network to classify MNIST digits 0 vs 1.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Keep only digits 0 and 1. load_data returns NumPy arrays, so a vectorized
# boolean mask replaces appending images one by one to Python lists.
train_mask = (y_train == 0) | (y_train == 1)
test_mask = (y_test == 0) | (y_test == 1)
# Scale pixel values from [0, 255] to [0, 1]; raw uint8 inputs train poorly.
x_train01 = x_train[train_mask].astype("float32") / 255.0
y_train01 = y_train[train_mask]
x_test01 = x_test[test_mask].astype("float32") / 255.0
y_test01 = y_test[test_mask]
print(x_train01.shape, y_train01.shape)

BATCH_SIZE = 32
# Fix for "Shape mismatch: labels (1,) ... logits (28, 2)":
# from_tensor_slices yields ONE (28, 28) image and ONE scalar label per element.
# Without .batch(), Keras treats the 28 image rows as the batch axis, so the
# model emits (28, 2) logits for a single label. Batching yields images of
# shape (batch, 28, 28) and labels of shape (batch,), which the loss expects.
# fit() ignores its batch_size argument for tf.data inputs, so batch here.
# fit() also does not shuffle a Dataset, so shuffle the training split here.
train_dataset = (
    data.Dataset.from_tensor_slices((x_train01, y_train01))
    .shuffle(len(x_train01))
    .batch(BATCH_SIZE)
)
test_dataset = data.Dataset.from_tensor_slices((x_test01, y_test01)).batch(BATCH_SIZE)

print(test_dataset)
# Network: flatten the 28x28 image -> 28-unit ReLU layer -> 2-class softmax.
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(28, 28)),  # explicit input so the model is built up front
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(28, activation='relu'),
    tf.keras.layers.Dense(2, activation='softmax'),
])
# Labels are integer class ids (0/1) -> sparse categorical cross-entropy.
# The last layer already applies softmax, hence from_logits=False.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])

# Train; batching is handled by the datasets above.
model.fit(train_dataset, epochs=5, validation_data=test_dataset, validation_freq=1)
model.summary()

(报错含义:除最后一个维度外,标签的形状(此处为 (1,))应与 logits 的形状(此处为 (28, 2))相同。)