参考《TensorFlow从零开始学》写了一个Fashion-Mnist数据集的判定,但是在使用model.fit进行训练的时候出现了错误,python3.7, TensorFlow2.3.1,cuda10.1,cudnn已安装对应版本,代码全文如下
import tensorflow as tf
import numpy as np
import gzip
def get_data():
# 文件获取
train_image = r"E:/Fashion-MNIST/data/train-images-idx3-ubyte.gz"
test_image = r"E:/Fashion-MNIST/data/t10k-images-idx3-ubyte.gz"
train_label = r"E:/Fashion-MNIST/data/train-labels-idx1-ubyte.gz"
test_label = r"E:/Fashion-MNIST/data/t10k-labels-idx1-ubyte.gz"
paths = [train_label, train_image, test_label,test_image]
with gzip.open(paths[0], 'rb') as lbpath:
y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
with gzip.open(paths[1], 'rb') as imgpath:
x_train = np.frombuffer(
imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
with gzip.open(paths[2], 'rb') as lbpath:
y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)
with gzip.open(paths[3], 'rb') as imgpath:
x_test = np.frombuffer(
imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)
return x_train/255, y_train, x_test/255, y_test
#归一化,除以像素值255
class MnistData:
def __init__(self):
self.train_image,self.train_label,self.test_image,self.test_label=get_data()
if __name__ == '__main__':
data=MnistData()
# 建立网络
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(512, activation=tf.nn.relu),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.summary()
# print(data.test_image_path.shape)
model.fit(data.train_image,data.train_label,epochs=1)
下面为报错信息
Traceback (most recent call last):
File "E:/Fashion-MNIST/main.py", line 50, in <module>
model.fit(data.train_image,data.train_label,epochs=1)
File "E:\python\lib\site-packages\tensorflow\python\keras\engine\training.py", line 108, in _method_wrapper
return method(self, *args, **kwargs)
File "E:\python\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1098, in fit
tmp_logs = train_function(iterator)
File "E:\python\lib\site-packages\tensorflow\python\eager\def_function.py", line 780, in __call__
result = self._call(*args, **kwds)
File "E:\python\lib\site-packages\tensorflow\python\eager\def_function.py", line 823, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "E:\python\lib\site-packages\tensorflow\python\eager\def_function.py", line 697, in _initialize
*args, **kwds))
File "E:\python\lib\site-packages\tensorflow\python\eager\function.py", line 2855, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "E:\python\lib\site-packages\tensorflow\python\eager\function.py", line 3213, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "E:\python\lib\site-packages\tensorflow\python\eager\function.py", line 3075, in _create_graph_function
capture_by_value=self._capture_by_value),
File "E:\python\lib\site-packages\tensorflow\python\framework\func_graph.py", line 986, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "E:\python\lib\site-packages\tensorflow\python\eager\def_function.py", line 600, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "E:\python\lib\site-packages\tensorflow\python\framework\func_graph.py", line 973, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
E:\python\lib\site-packages\tensorflow\python\keras\engine\training.py:806 train_function *
return step_function(self, iterator)
E:\python\lib\site-packages\tensorflow\python\keras\engine\training.py:796 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
E:\python\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1211 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
E:\python\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2585 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
E:\python\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2945 _call_for_each_replica
return fn(*args, **kwargs)
E:\python\lib\site-packages\tensorflow\python\keras\engine\training.py:789 run_step **
outputs = model.train_step(data)
E:\python\lib\site-packages\tensorflow\python\keras\engine\training.py:749 train_step
y, y_pred, sample_weight, regularization_losses=self.losses)
E:\python\lib\site-packages\tensorflow\python\keras\engine\compile_utils.py:204 __call__
loss_value = loss_obj(y_t, y_p, sample_weight=sw)
E:\python\lib\site-packages\tensorflow\python\keras\losses.py:149 __call__
losses = ag_call(y_true, y_pred)
E:\python\lib\site-packages\tensorflow\python\keras\losses.py:253 call **
return ag_fn(y_true, y_pred, **self._fn_kwargs)
E:\python\lib\site-packages\tensorflow\python\util\dispatch.py:201 wrapper
return target(*args, **kwargs)
E:\python\lib\site-packages\tensorflow\python\keras\losses.py:1535 categorical_crossentropy
return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
E:\python\lib\site-packages\tensorflow\python\util\dispatch.py:201 wrapper
return target(*args, **kwargs)
E:\python\lib\site-packages\tensorflow\python\keras\backend.py:4687 categorical_crossentropy
target.shape.assert_is_compatible_with(output.shape)
E:\python\lib\site-packages\tensorflow\python\framework\tensor_shape.py:1134 assert_is_compatible_with
raise ValueError("Shapes %s and %s are incompatible" % (self, other))
ValueError: Shapes (32, 1) and (32, 10) are incompatible