我使用tensorflow将数据集转换为tfrecords格式。数据集主要是包括两个部分,一个就是jpg图像,这个图像直接使用tf.io.read file进行读取,读成bytes就可以顺利的转化为tfrecords,转换后的数据体积不会明显膨胀。另一部分是一个二进制文件,我不得不对他进行处理转换为numpy数组,我通过将np数组转化为byte存入tfrecords后,体积巨大,请问有没有什么好一点的方法能解决呢?
class Cifar(object):
def __init__(self):
# 初始化操作
self.height=32
self.width=32
self.channels=3
# 字节数
self.image_bytes=self.height*self.width*self.channels # 图片像素数
self.label_bytes=1 # 标签数
self.all_bytes=self.label_bytes+self.image_bytes # 总字节数
def read_and_decode(self,file_list):
# 1、构造文件名队列
file_queue=tf.train.string_input_producer(file_list)
# 2、读取与解码
# 读取阶段
reader=tf.FixedLengthRecordReader(self.all_bytes)
# key 文件名,value一个样本
key,value=reader.read(file_queue)
# 解码阶段
decode=tf.decode_raw(value,tf.uint8)
# 将目标值和特征值切片分开,即标签和通道分开。tf.slice(data,起始位置,个数)
label=tf.slice(decode,[0],[self.label_bytes])
image=tf.slice(decode,[self.label_bytes],[self.image_bytes])
# 调整图片形状
image_reshaped=tf.reshape(image,shape = [self.channels,self.height,self.width])
# 转置,转成tf图片的表示格式 height,width,channels
image_transposed=tf.transpose(image_reshaped,[1,2,0])
# 跳转图像类型,uint8转为float32
image_cast=tf.cast(image_transposed,tf.float32)
# 3、批处理
label_batch,image_batch=tf.train.batch([label,image_cast],batch_size = 100,num_threads = 1,capacity = 100)
# 开启会话
with tf.Session() as sess:
print('------------------开启会话------------------')
# 开启线程
coord=tf.train.Coordinator() # 协调器
threads=tf.train.start_queue_runners(sess=sess,coord = coord)
label_batch_new,image_batch_new=sess.run([label_batch,image_batch])
# 回收线程
coord.request_stop()
coord.join(threads)
return label_batch_new,image_batch_new
def write_to_tfrecords(self,label_batch,image_batch):
# 将样本的特征值和目标值写入tfrecords文件
with tf.python_io.TFRecordWriter('./temp/cifar10/cifar10.tfrecords') as tfWriter:
# 循环构造example对象,并序列化写入文件
for i in range(label_batch.size):
image=image_batch[i].tostring() # 序列化
label=label_batch[i][0] # [i][0]取出一维数组的值
example = tf.train.Example(features = tf.train.Features(feature = {
"image": tf.train.Feature(bytes_list = tf.train.BytesList(value=[image])),
"label": tf.train.Feature(int64_list = tf.train.Int64List(value=[label]))
}))
# 将序列化后的example写入到cifar10.tfrecords文件中
tfWriter.write(example.SerializeToString())
if __name__ == '__main__':
file_name=os.listdir('./data/cifar-10-batches-bin')
# 构造路径 + 文件名的列表
file_list=[os.path.join('./data/cifar-10-batches-bin',file) for file in file_name if file[-3:]=='bin']
print('file_llist: ',file_list)
#实例化Cifar类
cifar=Cifar()
label_batch,image_batch=cifar.read_and_decode(file_list)
cifar.write_to_tfrecords(label_batch,image_batch)