简介
对于 文件名队列,使用 tf.train.string_input_producer() 函数,该函数需要传入一个 list 。该函数有两个重要的参数,一是 num_epochs ,代表重复次数;二是 shuffle ,代表是否无序读入,True 表示无序,False表示有序。内存队列 tensorflow 会自动创建:
filename=['A.jpg','B.jpg','C.jpg']
tf.train_string_inpit_producer(filename,num_epochs=2,shuffle=True)
如下图,有图片A,B,C。将文件ABC无序的读入文件名队列中,重复两次
在读取完成后,系统会抛出 ‘OutOfRange’ 异常。
在创建完队列后,系统还未真正的将文件加载进去,需要使用函数 tf.train.start_queue_runners() 开始加载。
利用tensorflow 数据加载机制将cifar10二进制数据转为图像
这里需要三个库,tensorflow, os,scipy.misc.toimage。os用来对文件操作,scipy.misc.toimage用于保存图片
import tensorflow as tf
import os
from scipy.misc import toimage
def read_cifar10(filename_queue):
'''
按字节从 filename_queue 读取数据
:param filename_queue: tensorflow 的文件夹队列
:return: 图像列表
'''
class CIFAR10Record(object):
pass
result = CIFAR10Record()
label_bytes = 1
result.height = 32
result.width = 32
result.depth = 3
image_bytes = result.height * result.width * result.depth
record_bytes = label_bytes + image_bytes
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
result.key, value = reader.read(filename_queue)
record_bytes = tf.decode_raw(value, tf.uint8)
result.label = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes], [label_bytes + image_bytes]),
[result.depth, result.height, result.width])
result.uint8image = tf.transpose(depth_major, [1, 2, 0])
return result
def input_origin(path):
'''
:param path: cifar10文件夹路径
:return: 图片队列
'''
filename = [os.path.join(path, 'data_batch_%d.bin' % i) for i in range(1, 6)]
for f in filename:
if not tf.gfile.Exists(f):
raise ValueError('fail')
filename_queue = tf.train.string_input_producer(filename)
read_input = read_cifar10(filename_queue)
reshaped_image = tf.cast(read_input.uint8image, tf.float32)
return reshaped_image
if __name__ == '__main__':
with tf.Session() as sess:
path = './cifar-10'
reshaped_image = input_origin(path)
threads = tf.train.start_queue_runners(sess=sess)
sess.run(tf.global_variables_initializer())
if not os.path.exists('./raw/'):
os.makedirs('./raw')
for i in range(30):
image_array = sess.run(reshaped_image)
toimage(image_array).save('./raw/%d.jpg' % i)