如下图所示,其功能是将代码封装进dataset并转化我可迭代格式,但是在执行预处理map()函数的时候报错:
ValueError: Tensor conversion requested dtype float32 for Tensor with dtype uint8: 'Tensor("arg0:0", shape=(28, 28), dtype=uint8)'
然而在删除map后正常运行,说明不是转换格式的问题,求问各位大神这是为什么呢?
报错代码:
(x,y),(x_val,y_val)=datasets.mnist.load_data()
def trans(x,y):
x=tf.convert_to_tensor(x,dtype=tf.float32)
y=tf.convert_to_tensor(y,dtype=tf.int32)
y=tf.one_hot(y,depth=10)
return x,y
train_db=tf.data.Dataset.from_tensor_slices((x,y))
train_db.map(trans)
train_db.shuffle(10000).batch(32)
正常运行:
(x,y),(x_val,y_val)=datasets.mnist.load_data()
x=tf.convert_to_tensor(x,dtype=tf.float32)
y=tf.convert_to_tensor(y,dtype=tf.int32)
y=tf.one_hot(y,depth=10)
train_db=tf.data.Dataset.from_tensor_slices((x,y))
train_db.shuffle(10000).batch(32)
error提示说的是数据类型不匹配,'Tensor("arg0:0", shape=(28, 28), dtype=uint8)'应该是在说变量x的类型是uint8,但是你定义的是float32
x=tf.convert_to_tensor(x,dtype=tf.float32)
你可以尝试改一下dtype,比如说dtype=tf.int8
我也是新手所以不知道对不对,你可以试试哈
引为train_db已经为tensor类型,所以应该使用cast函数转换数据类型,而不是使用convert_to_tensor函数;