基础概念
在 tensorflow 的官方文档是这样介绍 Dataset 数据对象的:
Dataset 可以用来表示输入管道元素集合 (张量的嵌套结构) 和 "逻辑计划" 对这些元素的转换操作. 在 Dataset 中元素可以是向量, 元组或字典等形式.
另外, Dataset 需要配合另外一个类 Iterator 进行使用, Iterator 对象是一个迭代器, 可以对 Dataset 中的元素进行迭代提取.
看个简单的示例:
- # 创建一个 Dataset 对象
- dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
- # 创建一个迭代器
- iterator = dataset.make_one_shot_iterator()
- #get_next()函数可以帮助我们从迭代器中获取元素
- element = iterator.get_next()
- # 遍历迭代器, 获取所有元素
- with tf.Session() as sess:
- for i in range(9):
- print(sess.run(element))
以上打印结果为: 1 2 3 4 5 6 7 8 9
Dataset 方法
1.from_tensor_slices
from_tensor_slices 用于创建 dataset, 其元素是给定张量的切片的元素.
函数形式: from_tensor_slices(tensors)
参数 tensors: 张量的嵌套结构, 每个都在第 0 维中具有相同的大小.
具体例子
- # 创建切片形式的 dataset
- dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
- # 创建一个迭代器
- iterator = dataset.make_one_shot_iterator()
- #get_next()函数可以帮助我们从迭代器中获取元素
- element = iterator.get_next()
- # 遍历迭代器, 获取所有元素
- with tf.Session() as sess:
- for i in range(3):
- print(sess.run(element))
以上代码运行结果: 1 2 3
2.from_tensors
创建一个 Dataset 包含给定张量的单个元素.
函数形式: from_tensors(tensors)
参数 tensors: 张量的嵌套结构.
具体例子
- dataset = tf.data.Dataset.from_tensors([1,2,3,4,5,6,7,8,9])
- iterator = concat_dataset.make_one_shot_iterator()
- element = iterator.get_next()
- with tf.Session() as sess:
- for i in range(1):
- print(sess.run(element))
以上代码运行结果:[1,2,3,4,5,6,7,8,9]
即 from_tensors 是将 tensors 作为一个整体进行操纵, 而 from_tensor_slices 可以操纵 tensors 里面的元素.
3.from_generator
创建 Dataset 由其生成元素的元素 generator.
函数形式: from_generator(generator,output_types,output_shapes=None,args=None)
参数 generator: 一个可调用对象, 它返回支持该 iter()协议的对象 . 如果 args 未指定, generator 则不得参数; 否则它必须采取与有值一样多的参数 args.
参数 output_types:tf.DType 对应于由元素生成的元素的每个组件的对象的嵌套结构 generator.
参数 output_shapes:tf.TensorShape 对应于由元素生成的元素的每个组件的对象 的嵌套结构 generator
参数 args:tf.Tensor 将被计算并将 generator 作为 NumPy 数组参数传递的对象元组.
具体例子
- # 定义一个生成器
- def data_generator():
- dataset = np.array(range(9))
- for i in dataset:
- yield i
- # 接收生成器, 并生产 dataset 数据结构
- dataset = tf.data.Dataset.from_generator(data_generator, (tf.int32))
- iterator = concat_dataset.make_one_shot_iterator()
- element = iterator.get_next()
- with tf.Session() as sess:
- for i in range(3):
- print(sess.run(element))
以上代码运行结果: 0 1 2
4.batch
batch 可以将数据集的连续元素合成批次.
函数形式: batch(batch_size,drop_remainder=False)
参数 batch_size: 表示要在单个批次中合并的此数据集的连续元素个数.
参数 drop_remainder: 表示在少于 batch_size 元素的情况下是否应删除最后一批 ; 默认是不删除.
具体例子:
- # 创建一个 Dataset 对象
- dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
- '''合成批次'''
- dataset=dataset.batch(3)
- # 创建一个迭代器
- iterator = dataset.make_one_shot_iterator()
- #get_next()函数可以帮助我们从迭代器中获取元素
- element = iterator.get_next()
- # 遍历迭代器, 获取所有元素
- with tf.Session() as sess:
- for i in range(9):
- print(sess.run(element))
以上代码运行结果为:
- [1 2 3]
- [4 5 6]
- [7 8 9]
即把目标对象合成 3 个批次, 返回的对象是传入 Dataset 对象.
5.concatenate
concatenate 可以将两个 Dataset 对象进行合并或连接.
函数形式: concatenate(dataset)
参数 dataset: 表示需要传入的 dataset 对象.
具体例子:
- # 创建 dataset 对象
- dataset_a=tf.data.Dataset.from_tensor_slices([1,2,3])
- dataset_b=tf.data.Dataset.from_tensor_slices([4,5,6])
- # 合并 dataset
- concat_dataset=dataset_a.concatenate(dataset_b)
- iterator = concat_dataset.make_one_shot_iterator()
- element = iterator.get_next()
- with tf.Session() as sess:
- for i in range(6):
- print(sess.run(element))
以上代码运行结果: 1 2 3 4 5 6
6.filter
filter 可以对传入的 dataset 数据进行条件过滤.
函数形式: filter(predicate)
参数 predicate: 条件过滤函数
具体例子
- dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
- # 对 dataset 内的数据进行条件过滤
- dataset=dataset.filter(lambda x:x>3)
- iterator = dataset.make_one_shot_iterator()
- element = iterator.get_next()
- with tf.Session() as sess:
- for i in range(6):
- print(sess.run(element))
以上代码运行结果: 4 5 6 7 8 9
7.map
map 可以将 map_func 函数映射到数据集
函数形式: flat_map(map_func,num_parallel_calls=None)
参数 map_func: 映射函数
参数 num_parallel_calls: 表示要并行处理的数字元素. 如果未指定, 将按顺序处理元素.
具体例子
- dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
- # 进行 map 操作
- dataset=dataset.map(lambda x:x+1)
- iterator = dataset.make_one_shot_iterator()
- element = iterator.get_next()
- with tf.Session() as sess:
- for i in range(6):
- print(sess.run(element))
以上代码运行结果: 2 3 4 5 6 7
8.flat_map
flat_map 可以将 map_func 函数映射到数据集(与 map 不同的是 flat_map 传入的数据必须是一个 dataset).
函数形式: flat_map(map_func)
参数 map_func: 映射函数
具体例子
- dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
- # 进行 flat_map 操作
- dataset=dataset.flat_map(lambda x:tf.data.Dataset.from_tensor_slices(x+[1]))
- iterator = dataset.make_one_shot_iterator()
- element = iterator.get_next()
- with tf.Session() as sess:
- for i in range(6):
- print(sess.run(element))
以上代码运行结果: 2 3 4 5 6 7
9.make_one_shot_iterator
创建 Iterator 用于枚举此数据集的元素.(可自动初始化)
函数形式: make_one_shot_iterator()
具体例子
- dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
- iterator = dataset.make_one_shot_iterator()
- element = iterator.get_next()
- with tf.Session() as sess:
- for i in range(6):
- print(sess.run(element))
- 10.make_initializable_iterator
创建 Iterator 用于枚举此数据集的元素.(使用此函数前需先进行迭代器的初始化操作)
函数形式: make_initializable_iterator(shared_name=None)
参数 shared_name:(可选)如果非空, 则返回的迭代器将在给定名称下共享同一设备的多个会话(例如, 使用远程服务器时)
具体例子
- dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
- iterator = dataset.make_initializable_iterator()
- element = iterator.get_next()
- with tf.Session() as sess:
- #对迭代器进行初始化操作
- sess.run(iterator.initializer)
- for i in range(5):
- print(sess.run(element))
- 11.padded_batch
将数据集的连续元素组合到填充批次中, 此转换将输入数据集的多个连续元素组合为单个元素.
函数形式: padded_batch(batch_size,padded_shapes,padding_values=None,drop_remainder=False)
参数 batch_size: 表示要在单个批次中合并的此数据集的连续元素数.
参数 padded_shapes: 嵌套结构 tf.TensorShape 或 tf.int64 类似矢量张量的对象, 表示在批处理之前应填充每个输入元素的相应组件的形状. 任何未知的尺寸 (例如, tf.Dimension(None) 在一个 tf.TensorShape 或 - 1 类似张量的物体中)将被填充到每个批次中该尺寸的最大尺寸.
参数 padding_values:(可选)标量形状的嵌套结构 tf.Tensor, 表示用于各个组件的填充值. 默认值 0 用于数字类型, 空字符串用于字符串类型.
参数 drop_remainder:(可选)一个 tf.bool 标量 tf.Tensor, 表示在少于 batch_size 元素的情况下是否应删除最后一批 ; 默认行为是不删除较小的批处理.
具体例子
- dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
- dataset=dataset.padded_batch(2,padded_shapes=[])
- iterator = dataset.make_one_shot_iterator()
- element = iterator.get_next()
- with tf.Session() as sess:
- for i in range(6):
- print(sess.run(element))
以上代码运行结果:
- [1 2]
- [3 4]
- 12.repeat
重复此数据集 count 次数
函数形式: repeat(count=None)
参数 count:(可选)表示数据集应重复的次数. 默认行为 (如果 count 是 None 或 - 1) 是无限期重复的数据集.
具体例子
- dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
- # 无限次重复 dataset 数据集
- dataset=dataset.repeat()
- iterator = dataset.make_one_shot_iterator()
- element = iterator.get_next()
- with tf.Session() as sess:
- for i in range(30,35):
- print(sess.run(element))
以上代码运行结果: 1 2 3 4 5
13.shard
将 Dataset 分割成 num_shards 个子数据集. 这个函数在分布式训练中非常有用, 它允许每个设备读取唯一子集.
函数形式: shard( num_shards,index)
参数 num_shards: 表示并行运行的分片数.
参数 index: 表示工人索引.
14.shuffle
随机混洗数据集的元素.
函数形式: shuffle(buffer_size,seed=None,reshuffle_each_iteration=None)
参数 buffer_size: 表示新数据集将从中采样的数据集中的元素数.
参数 seed:(可选)表示将用于创建分布的随机种子.
参数 reshuffle_each_iteration:(可选)一个布尔值, 如果为 true, 则表示每次迭代时都应对数据集进行伪随机重组.(默认为 True.)
具体例子
- dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
- # 随机混洗数据
- dataset=dataset.shuffle(3)
- iterator = dataset.make_one_shot_iterator()
- element = iterator.get_next()
- with tf.Session() as sess:
- for i in range(30,35):
- print(sess.run(element))
以上代码运行结果: 3 2 4
15.skip
生成一个跳过 count 元素的数据集.
函数形式: skip(count)
参数 count: 表示应跳过以形成新数据集的此数据集的元素数. 如果 count 大于此数据集的大小, 则新数据集将不包含任何元素. 如果 count 为 - 1, 则跳过整个数据集.
具体例子
- dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])
- # 跳过前 5 个元素
- dataset=dataset.skip(5)
- iterator = dataset.make_one_shot_iterator()
- element = iterator.get_next()
- with tf.Session() as sess:
- for i in range(30,35):
- print(sess.run(element))
以上代码运行结果: 6 7 8
16.take
提取前 count 个元素形成性数据集
函数形式: take(count)
参数 count: 表示应该用于形成新数据集的此数据集的元素数. 如果 count 为 - 1, 或者 count 大于此数据集的大小, 则新数据集将包含此数据集的所有元素.
具体例子
- dataset = tf.data.Dataset.from_tensor_slices([1,2,2,3,4,5,6,7,8,9])
- # 提取前 5 个元素形成新数据
- dataset=dataset.take(5)
- iterator = dataset.make_one_shot_iterator()
- element = iterator.get_next()
- with tf.Session() as sess:
- for i in range(30,35):
- print(sess.run(element))
以上代码运行结果: 1 2 2
17.zip
将给定数据集压缩在一起
函数形式: zip(datasets)
参数 datesets: 数据集的嵌套结构.
具体例子
- dataset_a=tf.data.Dataset.from_tensor_slices([1,2,3])
- dataset_b=tf.data.Dataset.from_tensor_slices([2,6,8])
- zip_dataset=tf.data.Dataset.zip((dataset_a,dataset_b))
- iterator = dataset.make_one_shot_iterator()
- element = iterator.get_next()
- with tf.Session() as sess:
- for i in range(30,35):
- print(sess.run(element))
以上代码运行结果:
- (1, 2)
- (2, 6)
- (3, 8)
到这里 Dataset 中大部分方法 都在这里做了初步的解释, 当然这些方法的配合使用才能够在建模过程中发挥大作用.
来源: https://www.cnblogs.com/wkslearner/p/9484443.html