tensorflow数据读取

  • 梦影无痕丶
  • 3 Minutes
  • October 6, 2018

简介

对于 文件名队列,使用 tf.train.string_input_producer() 函数,该函数需要传入一个 list 。该函数有两个重要的参数,一是 num_epochs ,代表重复次数;二是 shuffle ,代表是否无序读入,True 表示无序,False表示有序。内存队列 tensorflow 会自动创建:

filename=['A.jpg','B.jpg','C.jpg']
tf.train_string_inpit_producer(filename,num_epochs=2,shuffle=True)

如下图,有图片A,B,C。将文件ABC无序的读入文件名队列中,重复两次

在读取完成后,系统会抛出 ‘OutOfRange’ 异常。

在创建完队列后,系统还未真正的将文件加载进去,需要使用函数 tf.train.start_queue_runners() 开始加载。

利用tensorflow 数据加载机制将cifar10二进制数据转为图像

这里需要三个库,tensorflow, os,scipy.misc.toimage。os用来对文件操作,scipy.misc.toimage用于保存图片

import tensorflow as tf
import os
from scipy.misc import toimage

def read_cifar10(filename_queue):
    '''
    按字节从 filename_queue 读取数据
    :param filename_queue: tensorflow 的文件夹队列
    :return: 图像列表
    '''

    class CIFAR10Record(object):
        pass

    result = CIFAR10Record()

    label_bytes = 1
    result.height = 32
    result.width = 32
    result.depth = 3
    image_bytes = result.height * result.width * result.depth
    record_bytes = label_bytes + image_bytes
    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
    result.key, value = reader.read(filename_queue)
    record_bytes = tf.decode_raw(value, tf.uint8)
    result.label = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
    depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes], [label_bytes + image_bytes]),
                             [result.depth, result.height, result.width])
    result.uint8image = tf.transpose(depth_major, [1, 2, 0])
    return result

def input_origin(path):
    '''

    :param path: cifar10文件夹路径
    :return: 图片队列
    '''

    filename = [os.path.join(path, 'data_batch_%d.bin' % i) for i in range(1, 6)]
    for f in filename:
        if not tf.gfile.Exists(f):
            raise ValueError('fail')
    filename_queue = tf.train.string_input_producer(filename)
    read_input = read_cifar10(filename_queue)
    reshaped_image = tf.cast(read_input.uint8image, tf.float32)
    return reshaped_image


if __name__ == '__main__':
    with tf.Session() as sess:
        path = './cifar-10'
        reshaped_image = input_origin(path)
        threads = tf.train.start_queue_runners(sess=sess)
        sess.run(tf.global_variables_initializer())
        if not os.path.exists('./raw/'):
            os.makedirs('./raw')
        for i in range(30):
            image_array = sess.run(reshaped_image)
            toimage(image_array).save('./raw/%d.jpg' % i)