tensorflow中datasets.map()报错

如下图所示,其功能是将代码封装进dataset并转化我可迭代格式,但是在执行预处理map()函数的时候报错:

ValueError: Tensor conversion requested dtype float32 for Tensor with dtype uint8: 'Tensor("arg0:0", shape=(28, 28), dtype=uint8)'

然而在删除map后正常运行,说明不是转换格式的问题,求问各位大神这是为什么呢?

报错代码:

(x,y),(x_val,y_val)=datasets.mnist.load_data()

def trans(x,y):
    x=tf.convert_to_tensor(x,dtype=tf.float32)
    y=tf.convert_to_tensor(y,dtype=tf.int32)
    y=tf.one_hot(y,depth=10)
    return x,y

train_db=tf.data.Dataset.from_tensor_slices((x,y))
train_db.map(trans)
train_db.shuffle(10000).batch(32)

正常运行:

(x,y),(x_val,y_val)=datasets.mnist.load_data()

x=tf.convert_to_tensor(x,dtype=tf.float32)
y=tf.convert_to_tensor(y,dtype=tf.int32)
y=tf.one_hot(y,depth=10)

train_db=tf.data.Dataset.from_tensor_slices((x,y))
train_db.shuffle(10000).batch(32)

error提示说的是数据类型不匹配,'Tensor("arg0:0", shape=(28, 28), dtype=uint8)'应该是在说变量x的类型是uint8,但是你定义的是float32

 x=tf.convert_to_tensor(x,dtype=tf.float32)

你可以尝试改一下dtype,比如说dtype=tf.int8

我也是新手所以不知道对不对,你可以试试哈

引为train_db已经为tensor类型,所以应该使用cast函数转换数据类型,而不是使用convert_to_tensor函数;