前言
只有光头才能变强.
文本已收录至我的 GitHub 仓库, 欢迎 Star: https://github.com/ZhongFuCheng3y/3y
回顾前面:
从零开始学 TensorFlow[01 - 搭建环境, HelloWorld 篇]
什么是 TensorFlow?
众所周知, 要训练出一个模型, 首先我们得有数据. 我们第一个例子中, 直接使用 dataset 的 API 去加载 mnist 的数据.(minst 的数据要么我们是提前下载好, 放在对应的目录上, 要么就根据他给的 url 直接从网上下载).
一般来说, 我们使用 TensorFlow 是从 TFRecord 文件中读取数据的.
TFRecord 文件格式是一种面向记录的简单 二进制格式 , 很多 TensorFlow 应用采用此格式来训练数据
所以, 这篇文章来聊聊怎么 读取 TFRecord 文件的数据.
一, 入门对数据集的数据进行读和写
首先, 我们来体验一下怎么造一个 TFRecord 文件, 怎么从 TFRecord 文件中读取数据, 遍历 (消费) 这些数据.
1.1 造一个 TFRecord 文件
现在, 我们还没有 TFRecord 文件, 我们可以自己简单写一个:
- def write_sample_to_tfrecord():
- gmv_values = np.arange(10)
- click_values = np.arange(10)
- label_values = np.arange(10)
- with tf.python_io.TFRecordWriter("/Users/zhongfucheng/data/fashin/demo.tfrecord", options=None) as writer:
- for _ in range(10):
- feature_internal = {
- "gmv": tf.train.Feature(float_list=tf.train.FloatList(value=[gmv_values[_]])),
- "click": tf.train.Feature(int64_list=tf.train.Int64List(value=[click_values[_]])),
- "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label_values[_]]))
- }
- features_extern = tf.train.Features(feature=feature_internal)
- # 使用 tf.train.Example 将 features 编码数据封装成特定的 PB 协议格式
- # example = tf.train.Example(features=tf.train.Features(feature=features_extern))
- example = tf.train.Example(features=features_extern)
- # 将 example 数据系列化为字符串
- example_str = example.SerializeToString()
- # 将系列化为字符串的 example 数据写入协议缓冲区
- writer.write(example_str)
- if __name__ == '__main__':
- write_sample_to_tfrecord()
我相信大家代码应该是能够看得懂的, 其实就是分了几步:
生成 TFRecord Writer
tf.train.Feature 生成协议信息
使用 tf.train.Example 将 features 编码数据封装成特定的 PB 协议格式
将 example 数据系列化为字符串
将系列化为字符串的 example 数据写入协议缓冲区
参考资料:
https://zhuanlan.zhihu.com/p/31992460
ok, 现在我们就有了一个 TFRecord 文件啦.
1.2 读取 TFRecord 文件
tf.data.TFRecordDataset
demo 代码如下:
- import tensorflow as tf
- def read_tensorflow_tfrecord_files():
- # 定义消费缓冲区协议的 parser, 作为 dataset.map()方法中传入的 lambda:
- def _parse_function(single_sample):
- features = {
- "gmv": tf.FixedLenFeature([1], tf.float32),
- "click": tf.FixedLenFeature([1], tf.int64), # ()或者 [] 没啥影响
- "label": tf.FixedLenFeature([1], tf.int64)
- }
- parsed_features = tf.parse_single_example(single_sample, features=features)
- # 对 parsed 之后的值进行 cast.
- gmv = tf.cast(parsed_features["gmv"], tf.float64)
- click = tf.cast(parsed_features["click"], tf.float64)
- label = tf.cast(parsed_features["label"], tf.float64)
- return gmv, click, label
- # 开始定义 dataset 以及解析 tfrecord 格式
- filenames = tf.placeholder(tf.string, shape=[None])
- # 定义 dataset 和 一些列 trasformation method
- dataset = tf.data.TFRecordDataset(filenames)
- parsed_dataset = dataset.map(_parse_function) # 消费缓冲区需要定义在 dataset 的 map 函数中
- batchd_dataset = parsed_dataset.batch(3)
- # 创建 Iterator
- sample_iter = batchd_dataset.make_initializable_iterator()
- # 获取 next_sample
- gmv, click, label = sample_iter.get_next()
- training_filenames = [
- "/Users/zhongfucheng/data/fashin/demo.tfrecord"]
- with tf.Session() as session:
- # 初始化带参数的 Iterator
- session.run(sample_iter.initializer, feed_dict={filenames: training_filenames})
- # 读取文件
- print(session.run(gmv))
- if __name__ == '__main__':
- read_tensorflow_tfrecord_files()
无意外的话, 我们可以输出这样的结果:
[[0.] [1.] [2.]]
ok, 现在我们已经大概知道怎么写一个 TFRecord 文件, 以及怎么读取 TFRecord 文件的数据, 并且消费这些数据了.
二, epoch 和 batchSize 术语解释
我在学习 TensorFlow 翻阅资料时, 经常看到一些机器学习的术语, 由于自己没啥机器学习的基础, 所以很多时候看到一些专业名词就开始懵逼了.
2.1epoch
当一个完整的数据集通过了神经网络一次并且返回了一次, 这个过程称为一个 epoch.
这可能使我们跟 dataset.repeat() 方法联系起来, 这个方法可以使当前数据集 重复 一遍. 比如说, 原有的数据集是 [1,2,3,4,5] , 如果我调用 dataset.repeat(2) 的话, 那么我们的数据集就变成了 [1,2,3,4,5],[1,2,3,4,5]
所以会有个说法: 假设原先的数据是一个 epoch, 使用 repeat(5)就可以将之变成 5 个 epoch
2.2batchSize
一般来说我们的数据集都是比较大的, 无法一次性 将整个数据集的数据喂进神经网络中, 所以我们会将数据集分成好几个部分. 每次喂多少条样本进神经网络, 这个叫做 batchSize.
在 TensorFlow 也提供了方法给我们设置: dataset.batch() , 在 API 中是这样介绍 batchSize 的:
representing the number of consecutive elements of this dataset to combine in a single batch
我们一般在每次训练之前, 会将 整个数据集的顺序打乱 , 提高我们模型训练的效果. 这里我们用到的 API 是: dataset.shffle();
三, 再来聊聊 dataset
我从官网的介绍中截了一个 dataset 的方法图(部分):
dataset 的功能主要有以下三种:
创建 dataset 实例
通过文件创建(比如 TFRecord)
通过内存创建
对数据集的数据进行变换
比如上面的 batch(), 常见的
map(),flat_map(),zip(),repeat()
等等
文档中一般都有给出 例子 , 跑一下一般就知道对应的意思了.
创建迭代器, 遍历数据集的数据
3.1 聊聊迭代器
迭代器可以分为四种:
单次. 对数据集进行一次迭代, 不支持参数化
可初始化迭代
使用前需要进行初始化, 支持传入参数 . 面向的是同一个 DataSet
可重新初始化: 同一个 Iterator 从不同的 DataSet 中读取数据
DataSet 的对象具有相同的结构, 可以使用
tf.data.Iterator.from_structure
来进行初始化
问题: 每次 Iterator 切换时, 数据都从头开始打印了
可馈送(也是通过对象相同的结果来创建的迭代器)
可让您在 两个数据集之间切换 的可馈送迭代器
通过一个 string handler 来实现.
可馈送的 Iterator 在不同的 Iterator 切换的时候, 可以做到不从头开始 .
简单总结:
1, 单次 Iterator , 它最简单, 但无法重用, 无法处理数据集参数化的要求.
2, 可以初始化的 Iterator , 它可以满足 Dataset 重复加载数据, 满足了参数化要求.
3, 可重新初始化的 Iterator, 它可以对接不同的 Dataset, 也就是可以从不同的 Dataset 中读取数据.
4, 可馈送的 Iterator, 它可以通过 feeding 的方式, 让程序在运行时候选择正确的 Iterator, 它和可重新初始化的 Iterator 不同的地方就是它的数据在不同的 Iterator 切换时, 可以做到不重头开始读取数据 .
string handler(可馈送的 Iterator)这种方式是最常使用的, 我当时也写了一个 Demo 来使用了一下, 代码如下:
- def read_tensorflow_tfrecord_files():
- # 开始定义 dataset 以及解析 tfrecord 格式.
- train_filenames = tf.placeholder(tf.string, shape=[None])
- vali_filenames = tf.placeholder(tf.string, shape=[None])
- # 加载 train_dataset batch_inputs 这个方法每个人都不一样的, 这个方法我就不给了.
- train_dataset = batch_inputs([
- train_filenames], batch_size=5, type=False,
- num_epochs=2, num_preprocess_threads=3)
- # 加载 validation_dataset batch_inputs 这个方法每个人都不一样的, 这个方法我就不给了.
- validation_dataset = batch_inputs([vali_filenames
- ], batch_size=5, type=False,
- num_epochs=2, num_preprocess_threads=3)
- # 创建出 string_handler()的迭代器(通过相同数据结构的 dataset 来构建)
- handle = tf.placeholder(tf.string, shape=[])
- iterator = tf.data.Iterator.from_string_handle(
- handle, train_dataset.output_types, train_dataset.output_shapes)
- # 有了迭代器就可以调用 next 方法了.
- itemid = iterator.get_next()
- # 指定哪种具体的迭代器, 有单次迭代的, 有初始化的.
- training_iterator = train_dataset.make_initializable_iterator()
- validation_iterator = validation_dataset.make_initializable_iterator()
- # 定义出 placeholder 的值
- training_filenames = [
- "/Users/zhongfucheng/tfrecord_test/data01aa"]
- validation_filenames = ["/Users/zhongfucheng/tfrecord_validation/part-r-00766"]
- with tf.Session() as sess:
- # 初始化迭代器
- training_handle = sess.run(training_iterator.string_handle())
- validation_handle = sess.run(validation_iterator.string_handle())
- for _ in range(2):
- sess.run(training_iterator.initializer, feed_dict={train_filenames: training_filenames})
- print("this is training iterator ----")
- for _ in range(5):
- print(sess.run(itemid, feed_dict={handle: training_handle}))
- sess.run(validation_iterator.initializer,
- feed_dict={vali_filenames: validation_filenames})
- print("this is validation iterator")
- for _ in range(5):
- print(sess.run(itemid, feed_dict={vali_filenames: validation_filenames, handle: validation_handle}))
- if __name__ == '__main__':
- read_tensorflow_tfrecord_files()
参考资料:
3.2 dataset 参考资料
在翻阅资料时, 发现写得不错的一些博客:
https://www.jianshu.com/p/91803a119f18
最后
乐于输出 干货 的 Java 技术公众号: Java3y. 公众号内有 200 多篇 原创 技术文章, 海量视频资源, 精美脑图, 不妨来 关注 一下!
下一篇文章打算讲讲如何理解 axis~
觉得我的文章写得不错, 不妨点一下 赞 !
来源: http://www.tuicool.com/articles/MFzaeq3