tensorflow写入tfrecord文件的问题

这个代码基本是按照tensorflow官方教程里面的代码写的,应该是一模一样了,但是却报错了

def _bytes_feature(value):
    if isinstance(value,type(tf.constant(0))):
        value=value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def feature_to_string(feature):
    strings=feature.SerializerToString()
    return strings


n_boservations=int(1e4)

feature0=np.random.choice([False,True],n_boservations)


feature1=np.random.randint(0,5,n_boservations)

strings=np.array([b'cat',b'dog',b'chicken',b'horse',b'goat'])

feature2=strings[feature1]

feature3=np.random.randn(n_boservations)


#构建Example
def serialize_example(feature0,feature1,feature2,feature3):
    feature={
        'feature0':_int64_feature(feature0),
        'feature1':_int64_feature(feature1),
        'feature2':_bytes_feature(feature2),
        'feature3':_float_feature(feature3)
    }
    example_proto=tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()        



features_dataset = tf.data.Dataset.from_tensor_slices((feature0, feature1, feature2, feature3))

def generator():
  for features in features_dataset:
    yield serialize_example(*features)

  #对数据集进行处理  
serialized_features_dataset = tf.data.Dataset.from_generator(
    generator, output_types=tf.string, output_shapes=())



#写入文件
filename = 'test.tfrecord'
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(serialized_features_dataset)       

图片说明

查看features的数据类型,确保传给 serialize_example 的是一个 tuple 对象

yield serialize_example(*features)

我也出现同样的问题,生成serialized_features_dataset都是正确的,但其类型为MapDataset shapes: (), types: tf.string(采用map)或FlatMapDataset shapes: (), types: tf.string(采用generator),写入TFRecord文件时出现类型不匹配,问题可能还是tf.data.experimental.TFRecordWriter还是有点问题,既然还在experimental嘛。直接采用tf.io.TFRecordWriter可以成功,这样看起来感觉更简单,但也要注意将feature0的数据类型(原本为np.bool_)改为bool类型。代码如下:

import tensorflow as tf
import numpy as np

def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# The number of observations in the dataset.
n_observations = int(1e4)

# Boolean feature, encoded as False or True.
feature0 = np.random.choice([False, True], n_observations)

# Integer feature, random from 0 to 4.
feature1 = np.random.randint(0, 5, n_observations)

# String feature
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]

# Float feature, from a standard normal distribution
feature3 = np.random.randn(n_observations)

def serialize_example(feature0, feature1, feature2, feature3):
    """
    Creates a tf.Example message ready to be written to a file.
    """
    # Create a dictionary mapping the feature name to the tf.Example-compatible data type.
    feature = {
      'feature0': _int64_feature(feature0),
      'feature1': _int64_feature(feature1),
      'feature2': _bytes_feature(feature2),
      'feature3': _float_feature(feature3),
    }
  # Create a Features message using tf.train.Example.
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

filename = 'test.tfrecord'
with tf.io.TFRecordWriter(filename) as writer:
    for n in range(n_observations):
        f0, f1, f2, f3 = feature0[n], feature1[n], feature2[n], feature3[n]
        writer.write(serialize_example(bool(f0), f1, f2, f3))