TensorFlow基础(二)

  • 2019 年 10 月 6 日
  • 筆記

前言

在pandas中也可以读取数据,但是他存在的问题(仅仅以我们学过的角度来看)有:

1、一次读取数据,消耗内存

2、一次性进行训练

而在tensorflow中提供多线程,并行的执行任务,队列(数据的共享),文件(tfrecords)的方式读取数据。来提高解析速度。

你可能听过在Python中没有真正的多线程,原因是python存在GIL锁。但是你可能还不知道在numpy中释放了GIL锁,而机器学习库都是基于numpy的。

因此在学习tensorflow读取数据欠还要了解队列和线程。

队列和线程(tensorflow中)

队列

在tensorflow中对队列进行了封装:

tf.FIFOQueue(capacity,dtypes,name='info_queue')

# 前进先出队列,按顺序出队列

capacity:整数,可能存储此队列的元素上限

简单队列Demo:

# 1、定义队列  # 最多放5个数据,类型为float32  Q = tf.FIFOQueue(5,tf.float32)  # 放入一些数据,这里存放列表写成[[1,2,3,4,5],],防止认为是张量  e_many = Q.enqueue_many([[1,2,3,4,5],])  # 取出数据,乘以2  out_q = Q.dequeue()  data = out_q*2  # 再放入队列  en_q = Q.enqueue(data)  with tf.Session() as sess:      # 初始化队列      sess.run(e_many)      # 处理数据100次,这里只用运行en_q,就会运行      # 去除数据,*2,放入队列这三步操作,Tensorflow中操作具有依赖性      # 运行en_q,en_q依赖data,data依赖out_q      for i in range(100):          sess.run(en_q)      # 取数据      for i in range(Q.size().eval()):          print(sess.run(out_q))  

tf.RandomShuffleQueue:随机出队列

用到的时候再说。

线程

叫做队列管理器,但是是创建线程的作用。

tf.train.QueueRunner(queque,enqueue_ops=None)

queue: 一个队列

enqueue_ops:添加线程的队列操作列表,[]*2代表创建2个线程,[]中写操作

在sess中启动线程

create_threads(sess,coord=None,start=False)

coord:线程协调器,当结束后回收子线程

start:True启动线程,如果为False,还需要条用start()启动线程。

Demo实例:

# 模拟实现子线程读取数据,而主线程进行训练,二者并行    # 定义一个队列  Q = tf.FIFOQueue(1000,tf.float32)  # 定义子线程需要做的事情 循环加1 放入队列  var = tf.Variable(0.0)  # 每次加1,如果不使用assign_add,每次都是0+1会一直是1  data = tf.assign_add(var,1.0)  # 放入队列  en_q = Q.enqueue(data)  # 定义队列管理器op,指定线程做什么  qr = tf.train.QueueRunner(Q,enqueue_ops=[en_q]*2)  # 初始化变量op  init_op = tf.global_variables_initializer()    with tf.Session() as sess:      # 初始化变量      sess.run(init_op)      # 开启线程协调器,当主线程结束回收子线程      coord = tf.train.Coordinator()      # 开启子线程      threads = qr.create_threads(sess,coord=coord,start=True)      # 主线程读取数据,训练      for i in range(1000):          print(sess.run(Q.dequeue()))      # 回收子线程      coord.request_stop()      coord.join(threads)  

注意:其实以上过程以后都不需要自己写。但是要了解。

文件读取

文件读取流程

1、构建一个文件队列

2、读取队列内容

3、解码

4、批处理

文件读取api介绍

构造文件队列

tf.train.string_input_producer(string_tensor)

string_tensor:含有文件名的1阶张量

读取文件内容(不同文件,读取api不同)

文本,csv文件读取:tf.TextLineReader,按行读取

二进制文件:tf.FixedLengthRecordReader(record_bytes)

record_bytes:整型,指定每次读取的字节数

Tfrecords文件:tf.TFrecordReader

解码

解码csv文件:tf.decode_csv(records,record_defaults=None,dileld_delim=None)

将csv转换成张量,和tf.TextLineReader搭配使用。

records:读取的内容

dileld_delim:分隔符,默认为,

record_defaults:张量类型,设置缺少默认值.

解码二进制:tf.decode_raw()

csv文件读取Demo

def csvread(filelist):      # 构造文件队列,返回的是一个队列      file_queue = tf.train.string_input_producer(filelist,shuffle=False)      # 构造csv阅读器读取队列数据,默认按照行读取      reader = tf.TextLineReader()      # 得到读取的数据key是读取的文件名,value是读取的数据      key, value = reader.read(file_queue)      # 解码      # record_defaults指定读取的文件每一列的类型      # 比如csv的第一列数据是1,2,3,第二列为python,java,C      # 也就是第一列为float,第二列为string      # record_defaults就是指定每列的类型,和默认值      # 1.0为float,说明数据第一列为float类型,默认值是1      # ,""为string类型,说明第二列为string类型,默认值是None      records = [[1.0],["None"]]      # 返回为每个列的每个值      rad_num,label = tf.decode_csv(value,record_defaults=records)      # 读取多个数据,批处理      # 参数一:批处理的值      # 参数二:每批次读取多少数据      # 参数三:开启多少线程      # 参数四:队列的大小      rad_num_batch,label_batch = tf.train.batch([rad_num,label],batch_size=9,num_threads=1,capacity=9)      return rad_num_batch,label_batch    if __name__ == "__main__":      # 自己创建csv文件,列数不必太多      # 将文件放入列表      import os      # 去除警告消息      os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'      # os.listdir:返回该目录下文件名的列表      file_name = os.listdir("./csvdata/")      # 拼接路径      filelist = [os.path.join("./csvdata",file) for file in file_name]      rad_num_batch,label_batch = csvread(filelist)      # 开启会话      with tf.Session() as sess:          #  定义线程协调器          coord = tf.train.Coordinator()          # 开启读取文件的线程,不用上面那样麻烦了          threads = tf.train.start_queue_runners(sess, coord=coord)          # 打印读取的内容          print(sess.run([rad_num_batch,label_batch]))          # 回收线程          coord.request_stop()          coord.join(threads)  

读取图像和二进制下篇见。

Exit mobile version