Creating TFRecord-format data with TensorFlow

  • November 8, 2019
  • Notes

The tf.Example message

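TensorFlow provides a unified file format, TFRecord (.tfrecord), for storing data such as images. It relies on Google's own Protocol Buffers (protobuf): the image data is serialized into binary records whose structure you define.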

To read data efficiently it can be helpful to serialize your data and store it in a set of files (100-200MB each) that can each be read linearly. This is especially true if the data is being streamed over a network. This can also be useful for caching any data-preprocessing.

The TFRecord format is a simple format for storing a sequence of binary records.
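
The TensorFlow documentation describes the on-disk framing of each record. A record consists of a little-endian uint64 length, a masked CRC-32C of that length, the data bytes, and a masked CRC-32C of the data. The sketch below writes and reads this framing in pure Python so that the layout is visible without TensorFlow. The function names (crc32c, masked_crc, write_records, read_records) are hypothetical. The bit-by-bit CRC is slow and intended only as an illustration. In practice, use tf.io.TFRecordWriter and tf.data.TFRecordDataset.

```python
import io
import struct

_CRC32C_POLY = 0x82F63B78  # reflected Castagnoli polynomial used by CRC-32C


def crc32c(data: bytes) -> int:
    """Bitwise CRC-32C (slow; for illustration only)."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ (_CRC32C_POLY if crc & 1 else 0)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """TFRecord stores the CRC rotated right by 15 bits plus a constant, modulo 2**32."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_records(stream, records):
    """Frame each bytes record as: length, crc(length), data, crc(data)."""
    for data in records:
        length = struct.pack("<Q", len(data))
        stream.write(length)
        stream.write(struct.pack("<I", masked_crc(length)))
        stream.write(data)
        stream.write(struct.pack("<I", masked_crc(data)))


def read_records(stream):
    """Yield the bytes of each record, checking both CRCs."""
    while True:
        header = stream.read(8)
        if not header:
            return  # clean end of stream
        (length,) = struct.unpack("<Q", header)
        (len_crc,) = struct.unpack("<I", stream.read(4))
        if len_crc != masked_crc(header):
            raise ValueError("corrupted length field")
        data = stream.read(length)
        (data_crc,) = struct.unpack("<I", stream.read(4))
        if data_crc != masked_crc(data):
            raise ValueError("corrupted record data")
        yield data


buf = io.BytesIO()
write_records(buf, [b"hello", b"", b"tfrecord"])
buf.seek(0)
print(list(read_records(buf)))  # [b'hello', b'', b'tfrecord']
```

Each record therefore adds 16 bytes of overhead (8 + 4 + 4). TensorFlow itself stores arbitrary bytes in these records, usually serialized tf.train.Example messages.
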
The protobuf message definitions (from feature.proto) are:
https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/core/example/feature.proto

message BytesList {
  repeated bytes value = 1;
}
message FloatList {
  repeated float value = 1 [packed = true];
}
message Int64List {
  repeated int64 value = 1 [packed = true];
}

// Containers for non-sequential data.
message Feature {
  // Each feature can be exactly one kind.
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

message Features {
  map<string, Feature> feature = 1;
};

message FeatureList {
  repeated Feature feature = 1;
};

message FeatureLists {
  map<string, FeatureList> feature_list = 1;
};
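
The following sketch builds a tf.train.Example by hand to show how these messages nest: Example, then Features, then a map of Feature, then exactly one of the three lists. It serializes the message and parses it back. The feature names 'label' and 'img' mirror the script later in this post. The 'score' feature and all values are made up for illustration.

```python
import tensorflow as tf

# Features.feature is a map<string, Feature>; each Feature sets exactly one oneof field.
example = tf.train.Example(features=tf.train.Features(feature={
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[3])),
    "img": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"\x00\x01\x02"])),
    "score": tf.train.Feature(float_list=tf.train.FloatList(value=[0.5])),
}))

serialized = example.SerializeToString()  # the bytes that become one TFRecord record
restored = tf.train.Example.FromString(serialized)

for name, feature in sorted(restored.features.feature.items()):
    kind = feature.WhichOneof("kind")  # 'bytes_list', 'float_list' or 'int64_list'
    print(name, kind, list(getattr(feature, kind).value))
```
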

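A tf.train.Example wraps a Features message, which is a map of the form {"string": tf.train.Feature}.
A tf.train.Feature holds exactly one of three basic list types. Each accepts the following value types: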

  • tf.train.BytesList
    • string
    • byte
  • tf.train.FloatList
    • float(float32)
    • double(float64)
  • tf.train.Int64List
    • bool
    • enum
    • int32
    • uint32
    • int64
    • uint64
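
The list above has a few practical consequences, sketched below with made-up values:

- Bools and other integer types end up in an Int64List. The bools are converted to 0/1 with int() here to be explicit.
- A Python float (float64) is stored as float32 in a FloatList and loses precision.
- Text must be encoded to bytes before it goes into a BytesList.

```python
import numpy as np
import tensorflow as tf

# bool, int32 and other integer types all go into an Int64List (bools as 0/1).
flags = tf.train.Int64List(value=[int(True), int(False), np.int32(7)])

# FloatList is declared as `repeated float` (float32), so a Python float (float64)
# is rounded; a serialize/parse round trip makes the stored value visible.
floats = tf.train.FloatList.FromString(
    tf.train.FloatList(value=[0.1]).SerializeToString())

# BytesList only holds bytes: encode text explicitly (UTF-8 here).
texts = tf.train.BytesList(value=["猫".encode("utf-8")])

print(list(flags.value))                    # plain integers
print(floats.value[0])                      # close to, but not exactly, 0.1
print(floats.value[0] == np.float32(0.1))   # equals the float32 rounding of 0.1
print(texts.value[0].decode("utf-8"))       # decode back to str when reading
```
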

Reference: the official TensorFlow documentation.

Converting your own data to TFRecord format

Complete code

import argparse
import os

import cv2 as cv
import tensorflow as tf


def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def convert_to_tfexample(img, label):
    """Convert one image matrix and its label into a tf.train.Example."""
    img_raw = img.tobytes()  # ndarray.tostring() is deprecated; tobytes() returns the same raw bytes
    example = tf.train.Example(features=tf.train.Features(feature={
        'label': _int64_feature(label),
        'img': _bytes_feature(img_raw)}))
    return example


def read_dataset(path):
    """Walk `path` and return (imgs, labels) for every a.png that has a matching a.txt."""
    imgs = []
    labels = []
    for root, dirs, files in os.walk(path):
        for one_file in files:
            # Join with `root`, not `path`, so files in subdirectories resolve correctly.
            one_file = os.path.join(root, one_file)
            if one_file.endswith(".png"):
                # Swap only the extension (str.replace would also hit "png" elsewhere in the path).
                label_file = os.path.splitext(one_file)[0] + ".txt"
                if not os.path.isfile(label_file):
                    continue
                img = cv.imread(one_file)
                if img is None:  # unreadable image: skip it so imgs and labels stay aligned
                    continue
                # The class index is the first whitespace-separated field on the first line.
                with open(label_file) as f:
                    class_index = int(f.readline().split()[0])
                labels.append(class_index)
                imgs.append(img)
    return imgs, labels


def arg_parse():
    parser = argparse.ArgumentParser(
        description='ex: python create_tfrecord.py -d /home/sc/disk/data/lishui/1 -o train.tfrecord')
    # No required=True here, so the defaults below actually take effect.
    parser.add_argument('-d', '--dir', type=str, default='./data',
                        help='directory containing the image/label files')
    parser.add_argument('-o', '--output', type=str, default='./outdata.tfrecord',
                        help='output tfrecord file name')
    args = parser.parse_args()
    return args


def main():
    args = arg_parse()
    imgs, labels = read_dataset(args.dir)
    # The context manager closes (and flushes) the output file even if an error occurs.
    with tf.io.TFRecordWriter(args.output) as writer:
        for example in map(convert_to_tfexample, imgs, labels):
            writer.write(example.SerializeToString())
    print("write done")


if __name__ == '__main__':
    # usage: python create_tfrecord.py -d [data_path] -o [outrecordfile_path]
    # ex:    python create_tfrecord.py -d /home/sc/disk/data/lishui/1 -o train.tfrecord
    main()
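
The script writes everything into a single file. The TensorFlow guide quoted earlier suggests a set of files of roughly 100-200MB each, so here is a minimal sketch of spreading serialized records over several shards round-robin. The helper name write_sharded and the "-00000-of-0000N" file naming are assumptions for illustration and are not part of the original script.

```python
import tensorflow as tf


def write_sharded(serialized_records, output_prefix, num_shards):
    """Round-robin serialized records into num_shards TFRecord files.

    Hypothetical helper; returns the list of file paths it wrote.
    """
    paths = [f"{output_prefix}-{i:05d}-of-{num_shards:05d}.tfrecord"
             for i in range(num_shards)]
    writers = [tf.io.TFRecordWriter(p) for p in paths]
    try:
        for i, record in enumerate(serialized_records):
            writers[i % num_shards].write(record)
    finally:
        for w in writers:
            w.close()
    return paths


# Possible use with the script's own functions:
# imgs, labels = read_dataset(args.dir)
# records = (convert_to_tfexample(i, l).SerializeToString() for i, l in zip(imgs, labels))
# paths = write_sharded(records, "train", num_shards=4)
# dataset = tf.data.TFRecordDataset(paths)
```

Choose num_shards from the total data size so that each shard lands near the suggested file size.
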

First, we need helper functions that convert values of types such as bytes, string, float and int into a tf.train.Feature:

def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

Next, for each image matrix and its label, we call the helper functions above to convert the single image and its label into a tf.train.Example message.

def convert_to_tfexample(img, label):
    """Convert one image matrix and its label into a tf.train.Example."""
    img_raw = img.tobytes()  # ndarray.tostring() is deprecated; tobytes() returns the same raw bytes
    example = tf.train.Example(features=tf.train.Features(feature={
        'label': _int64_feature(label),
        'img': _bytes_feature(img_raw)}))
    return example
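
The raw pixel bytes from img.tobytes() (or tostring()) do not carry the image's height, width, channel count or dtype, so a reader has to know these in advance. One common remedy is to store the shape as extra int64 features. The sketch below illustrates this as an assumption and is not the original design. The feature names height/width/channels and the helper _int64 are hypothetical, and the dtype (uint8 here) still has to be known by the reader.

```python
import numpy as np
import tensorflow as tf


def _int64(v):
    """Wrap one integer in a tf.train.Feature (hypothetical helper)."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(v)]))


def convert_to_tfexample_with_shape(img, label):
    """Like convert_to_tfexample, but also stores the image shape (hypothetical variant)."""
    height, width, channels = img.shape  # assumes HxWxC, as cv.imread returns for color images
    return tf.train.Example(features=tf.train.Features(feature={
        "label": _int64(label),
        "height": _int64(height),
        "width": _int64(width),
        "channels": _int64(channels),
        "img": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
    }))


# Round trip on a synthetic 2x3 image with 3 uint8 channels.
img = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3)
parsed = tf.train.Example.FromString(
    convert_to_tfexample_with_shape(img, 1).SerializeToString())
f = parsed.features.feature
shape = [f[k].int64_list.value[0] for k in ("height", "width", "channels")]
restored = np.frombuffer(f["img"].bytes_list.value[0], dtype=np.uint8).reshape(shape)
```
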

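In my data, each image and its label file sit in the same directory. For example, the directory contains an image a.png and the corresponding label file a.txt, whose first line starts with the class index.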

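def read_dataset(path):
    """Walk `path` and return (imgs, labels) for every a.png that has a matching a.txt."""
    imgs = []
    labels = []
    for root, dirs, files in os.walk(path):
        for one_file in files:
            # Join with `root`, not `path`, so files in subdirectories resolve correctly.
            one_file = os.path.join(root, one_file)
            if one_file.endswith(".png"):
                # Swap only the extension (str.replace would also hit "png" elsewhere in the path).
                label_file = os.path.splitext(one_file)[0] + ".txt"
                if not os.path.isfile(label_file):
                    continue
                img = cv.imread(one_file)
                if img is None:  # unreadable image: skip it so imgs and labels stay aligned
                    continue
                # The class index is the first whitespace-separated field on the first line.
                with open(label_file) as f:
                    class_index = int(f.readline().split()[0])
                labels.append(class_index)
                imgs.append(img)
    return imgs, labels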

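This walks the data directory and reads every image and its label. If your data is organized differently, just modify this function. It should still return imgs, labels.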
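
As an example of adapting read_dataset, suppose (hypothetically) that the images are instead sorted into one subdirectory per class, such as data/0/*.png and data/1/*.png, with the folder name serving as the label. A sketch that keeps the same (imgs, labels) return value could look like the following. The loader parameter is an addition. It defaults to cv2.imread and makes the function easy to test with a stand-in loader.

```python
import os


def read_dataset_by_folder(path, loader=None):
    """Return (imgs, labels) for a layout like path/<class_index>/<name>.png (hypothetical)."""
    if loader is None:
        import cv2 as cv  # imported lazily so a custom loader works without OpenCV
        loader = cv.imread
    imgs, labels = [], []
    for class_dir in sorted(os.listdir(path)):
        class_path = os.path.join(path, class_dir)
        if not (os.path.isdir(class_path) and class_dir.isdigit()):
            continue  # skip stray files and non-numeric folder names
        for name in sorted(os.listdir(class_path)):
            if not name.endswith(".png"):
                continue
            img = loader(os.path.join(class_path, name))
            if img is None:  # cv.imread returns None for unreadable files
                continue
            imgs.append(img)
            labels.append(int(class_dir))
    return imgs, labels
```

The rest of the script is unchanged, because convert_to_tfexample only needs the two lists.
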

Finally, tf.io.TFRecordWriter writes each tf.train.Example message, serialized with SerializeToString(), to the output file.

def main():
    args = arg_parse()
    imgs, labels = read_dataset(args.dir)
    # The context manager closes (and flushes) the output file even if an error occurs.
    with tf.io.TFRecordWriter(args.output) as writer:
        for example in map(convert_to_tfexample, imgs, labels):
            writer.write(example.SerializeToString())
    print("write done")
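
To check the result, you can read the file back with tf.data.TFRecordDataset and tf.io.parse_single_example. The script stores only 'label' and the raw 'img' bytes, so the reader must supply the image shape. IMG_SHAPE below is a placeholder assumption that must match your data. The uint8 dtype matches what cv.imread returns for ordinary 8-bit images.

```python
import tensorflow as tf

IMG_SHAPE = (64, 64, 3)  # assumption: every image in the file has this height, width, channels

feature_description = {
    "label": tf.io.FixedLenFeature([], tf.int64),
    "img": tf.io.FixedLenFeature([], tf.string),
}


def parse_example(serialized):
    """Decode one serialized tf.train.Example written by the script above."""
    parsed = tf.io.parse_single_example(serialized, feature_description)
    img = tf.io.decode_raw(parsed["img"], tf.uint8)  # raw bytes -> flat uint8 tensor
    img = tf.reshape(img, IMG_SHAPE)
    return img, parsed["label"]


def load_dataset(path):
    """Lazily read and decode every record in a TFRecord file."""
    return tf.data.TFRecordDataset(path).map(parse_example)


# Usage:
# for img, label in load_dataset("train.tfrecord").take(1):
#     print(img.shape, int(label))
```

If the images have different sizes, a fixed IMG_SHAPE will fail in tf.reshape. In that case, storing the shape alongside the pixels (as in the shape-preserving variant sketched earlier) is one way around it.
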