今天在写 NCF 代码的时候, 发现网络上的代码有一种新的数据读取方式, 这里将对应的片段剪出来给大家分享下.
NCF 的文章参考: https://www.jianshu.com/p/6173dbde4f53
原始数据
我们的原始数据保存在 npy 文件中, 是一个字典类型, 有三个 key, 分别是 user,item 和 label:
- data = np.load('data/test_data.npy').item()
- print(type(data))
- #output
- <class 'dict'>
构建 tf 的 Dataset
使用 tf.data.Dataset.from_tensor_slices 方法, 将我们的数据变成 tensorflow 的 DataSet:
- dataset = tf.data.Dataset.from_tensor_slices(data)
- print(type(dataset))
- #output
- <class 'tensorflow.python.data.ops.dataset_ops.TensorSliceDataset'>
进一步, 将我们的 Dataset 变成一个 BatchDataset, 这样的话, 在迭代数据的时候, 就可以一次返回一个 batch 大小的数据:
- dataset = dataset.shuffle(1000).batch(100)
- print(type(dataset))
- #output
- <class 'tensorflow.python.data.ops.dataset_ops.BatchDataset'>
可以看到, 我们在变成 batch 之前使用了一个 shuffle 对数据进行打乱, 100 表示 buffersize, 即每取 1000 个打乱一次.
此时 dataset 有两个属性, 分别是 output_shapes 和 output_types, 我们将根据这两个属性来构造迭代器, 用于迭代数据.
- print(dataset.output_shapes)
- print(dataset.output_types)
- #output
- {'user': TensorShape([Dimension(None)]), 'item': TensorShape([Dimension(None)]), 'label': TensorShape([Dimension(None)])}
- {'user': tf.int32, 'item': tf.int32, 'label': tf.int32}
构造迭代器
我们使用上面提到的两个 dataset 的属性, 并使用 tf.data.Iterator.from_structure 方法来构造一个迭代器:
- iterator = tf.data.Iterator.from_structure(dataset.output_types,
- dataset.output_shapes)
迭代器需要初始化:
sess.run(iterator.make_initializer(dataset))
此时, 就可以使用 get_next(), 方法来源源不断的读取 batch 大小的数据了
- def getBatch():
- sample = iterator.get_next()
- print(sample)
- user = sample['user']
- item = sample['item']
- return user,item
使用迭代器的正确姿势
我们这里来计算返回的每个 batch 中, user 和 item 的平均值:
- users,items = getBatch()
- usersum = tf.reduce_mean(users,axis=-1)
- itemsum = tf.reduce_mean(items,axis=-1)
迭代器 iterator 只能往前遍历, 如果遍历完之后还调用 get_next() 的话, 会报 tf.errors.OutOfRangeError 错误, 因此需要使用 try-catch.
- try:
- while True:
- print(sess.run([usersum,itemsum]))
- except tf.errors.OutOfRangeError:
- print("outOfRange")
如果想要多次遍历数据的话, 初始化外面包裹一层循环即可:
- for i in range(2):
- sess.run(iterator.make_initializer(dataset))
- try:
- while True:
- print(sess.run([usersum,itemsum]))
- except tf.errors.OutOfRangeError:
- print("outOfRange")
完整代码
- import numpy as np
- import tensorflow as tf
- data = np.load('data/test_data.npy').item()
- print(type(data))
- dataset = tf.data.Dataset.from_tensor_slices(data)
- print(type(dataset))
- dataset = dataset.shuffle(10000).batch(100)
- print(type(dataset))
- print(dataset.output_shapes)
- print(dataset.output_types)
- iterator = tf.data.Iterator.from_structure(dataset.output_types,
- dataset.output_shapes)
- print(type(iterator))
- def getBatch():
- sample = iterator.get_next()
- print(sample)
- user = sample['user']
- item = sample['item']
- return user,item
- users,items = getBatch()
- usersum = tf.reduce_mean(users,axis=-1)
- itemsum = tf.reduce_mean(items,axis=-1)
- with tf.Session() as sess:
- sess.run(tf.global_variables_initializer())
- for i in range(2):
- sess.run(iterator.make_initializer(dataset))
- try:
- while True:
- print(sess.run([usersum,itemsum]))
- except tf.errors.OutOfRangeError:
- print("outOfRange")
来源: https://juejin.im/entry/5b14b1cae51d4506b26e9141