Overview 之前几次推送的全部例程, 使用的都是 tensorflow 预处理过的数据集, 直接载入即可例如:
然而实际中我们使用的通常不会是这种超级经典的数据集, 如果我们有一组图像存储在磁盘上面, 如何以 mini-batch 的形式把它们读取进来然后高效的送进网络训练? 这次推送我们首先用 tensorflow 最底层的 API 处理这个问题, 后面推送介绍高层 API 高层 API 是对底层的进一步封装, 用户可以不必关心过多细节不过了解一下比较底层的 API 还是有好处的当你有一组自己的数据的时候, 你需要经过以下两个步骤:(1)将全部数据写入一个后缀 .tfredords 的文件
这个步骤涉及读入 ->预处理 ->写入 tfrecords, 对你的数据是什么格式没有要求例如, 如果你手中是图像数据, 那用 opencv/PIL 等接口读入; 如果是 matlab 数据(mat 文件), 那可以用 h5py 协议读入, 等等不管如何读入, 最终都要写入到统一的 tfrecords 文件中, 以便用 tensorflow 提供的接口高效读取
(2)以 mini-batch 的形式从 tfrecords 中读取数据, 送到模型的 placeholder 中支持网络训练
实验设置 代码中使用的数据是存在磁盘中的 400 张 png 图像, 也传到了 github 上面, 存在 my_data 路径下面部分如下:
代码实现以下功能: 制备 tfrecords 形式的数据集, 然后再以 mini-batch 读入, 为了测试读入是否成功, 把读入的数据显示在 tensorboard 上面
制备 tfrecords 数据集 在上次推送中(Tensorboard), 大部分代码都是遵循 API 接口的固定模式写就可以, 这次也主要以这种方式进行, 而不过多讨论背后的理论细节两个辅助函数定义这俩辅助函数的目的完全是不想让后面的代码太冗长
读取图像 & 写入 tfrecords 文件
几点说明 (1) 读取图像文件的时候用到了 glob 和 opencv 两个包 glob 是将路径下全部文件名一次性存到一个 list 中, 方面后面逐个读取; opencv 则只是利用 imread 接口读取图像文件的 (2) 和一切文件操作一样, 向 tfrecords 文件中写入内容也需要建立一个 writer 对象, 创建这个对象的是函数 tf.python_io.TFRecordWriter(3)feature 是我们创建的一个字典对象, 这里面可以包含你想记录的任何信息在这里我们存入了三对键值 (key: value):image_raw(图像数据, 这个是核心内容),heigh(高),width(宽) 你也可以加入更多的信息, 例如, 通道数目, 文件名等等这些信息在后面读取数据的时候都可以一并读取出来比如: 在主程序中, 你需要用到图像的尺寸参数, 那么你可以将图像和尺寸参数一起读出
(4)注意数据格式图像数据本身是 8bit 的, 因此我们用前面定义的辅助函数 _bytes_feature_ 把数据转化成 tensorflow 要求的 tf.train.BytesList 格式存入实际中还会碰到图像本身是以 float 形式存储的, 代码就需要相应的变动, 这个下次推送再说
从 tfrecords 中载入 nimi-batch 定义函数: 读取一个样本
几点说明:(1)整个代码过程很烦杂, 因为是调用的底层 API, 不过都是固定写法, 其中的内部原理主页菌一知半解, 不敢在这里随便讲 (2) 特别注意这里这个字典对象的定义方式首先, 这里的三个 key 要和前面制备 tfrecords 时候一致; 其次, 注意数据格式, image_raw 是 8bit 存储的, 所以读取的时候限定 tf.string 类型, 同理, height 和 width 要限定 tf.int64
(3)如前文所说, 字典中存入的信息都可以通过 key 来读取, 上面的代码只读取了图像信息, 如果想获取 height 的值, 可以补充这样一句代码:
height = tf.decode_raw(features[height], tf.int64)
然后在函数返回值中把 height 也返回即可
(4)每一个样本是以一维的形式从数据流中抽取出来的, 所以需要 reshape 成原始尺寸
定义 mini-batch
用前面定义的 read_record 获取一个样本, 然后用 tf.train.shuffle_batch 来封装一个 mini-batchtf.train.shuffle_batch 会多次通过 read_record 抽取样本, 并且开辟一块内存空间建立队列(queue), 将样本洗牌打乱, 空间开辟越大, 数据混乱度会越高控制洗牌的参数是 capacity 和 min_after_dequeue, 官网文档中给出了这俩参数的取值建议, 我粘贴到了代码注释中注意: 从最开始介绍 tensorflow 的时候主页菌就在强调一个事情: 任何东西在用 Session 运行之前都是没有实际值的这里也不例外在主程序部分, 每一个 step 都要这么一句代码:
batch = sess.run(data_batch)
这个 batch 才是实际的数据, 是可以 feed 给 placeholder 的
主程序部分 我们的主程序是读取 mini-batch 然后用 tensorboard 显示
说明: 有四行代码必不可少 session 开头的两行: coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
session 结尾的两行
- coord.request_stop()
- coord.join(threads)
至于内部机理, 文档写的太模糊, 主页菌缺少计算机基础理论知识, 并没有看懂
总结 我相信你可能已经看晕了...... 这部分太过琐碎, 细节很多, 官方文档里面写的也很模糊, 对内部机理解释的不到位面对这种情况, 主页菌最初选择的方法就是, 亲自尝试, 用几乎一整天的时间摸索出了这一套代码的套路虽然对机理还是一知半解, 但是对代码思路十分清晰了, 在自己的项目中能够迅速撸出一套数据预处理的代码所以, 主页菌的建议就是, 亲自调通一套 demo!
下期预告
这次推送的数据是 8bit 的, 然而如果我想用 float 格式存储怎么办?(或者原始数据就是 float 格式的, 总不能截断成 8bit 来存储吧......)虽然这部分内容不多, 但是由于这次推送信息量够大了, 还是放到下次单独说吧艾伯特 (http://www.aibbt.com/) 国内第一家人工智能门户
本次推送对应的源码:
http://www.aibbt.com/a/19073.html
来源: http://www.bubuko.com/infodetail-2522637.html