tf.nn.embedding_lookup 记录
tf.nn.embedding_lookup 函数的用法主要是选取一个张量里面索引对应的元素. tf.nn.embedding_lookup(tensor, id):tensor 就是输入张量, id 就是张量对应的索引, 其他的参数不介绍.
例如:
- import tensorflow as tf;
- import numpy as np;
- c = np.random.random([10,1])
- b = tf.nn.embedding_lookup(c, [1, 3])
- with tf.Session() as sess:
- sess.run(tf.initialize_all_variables())
- print sess.run(b)
- print c
输出:
- [[ 0.77505197]
- [ 0.20635818]]
- [[ 0.23976515]
- [ 0.77505197]
- [ 0.08798201]
- [ 0.20635818]
- [ 0.37183035]
- [ 0.24753178]
- [ 0.17718483]
- [ 0.38533808]
- [ 0.93345168]
- [ 0.02634772]]
分析: 输出为张量的第一和第三个元素.
---------------------
作者: UESTC_C2_403
slade_sal https://www.jianshu.com/u/79b57248a6c3 关注
- import numpy as np
- import tensorflow as tf
- data = np.array([[[2],[1]],[[3],[4]],[[6],[7]]])
- data = tf.convert_to_tensor(data)
- lk = [[0,1],[1,0],[0,0]]
- lookup_data = tf.nn.embedding_lookup(data,lk)
- init = tf.global_variables_initializer()
- In [76]: data.shape
- Out[76]: (3, 2, 1)
- In [77]: np.array(lk).shape
- Out[77]: (3, 2)
- In [78]: lookup_data
- Out[78]: <tf.Tensor 'embedding_lookup_8:0' shape=(3, 2, 2, 1) dtype=int64>
- In [79]: data
- Out[79]:
- array([[[2],
- [1]],
- [[3],
- [4]],
- [[6],
- [7]]])
- In [80]: lk
- Out[80]: [[0, 1], [1, 0], [0, 0]]
- # lk[0]也就是 [0,1] 对应着下面 sess.run(lookup_data)的结果恰好是把 data 中的[[2],[1]],[[3],[4]]
- In [81]: sess.run(lookup_data)
- Out[81]:
- array([[[[2],
- [1]],
- [[3],
- [4]]],
- [[[3],
- [4]],
- [[2],
- [1]]],
- [[[2],
- [1]],
- [[2],
- [1]]]])
来源: http://www.bubuko.com/infodetail-2943181.html