tensorflow 中的 tensor 值的获取:
- import tensorflow as tf
- # 定义变量 a
- a=tf.Variable([[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]])
- # 定义索引
- indics=[[0,0,0],[0,1,1],[0,1,2]]
- # 把 a 中索引为 indics 的值取出
- b=tf.gather_nd(a,indics)
- # 初始化
- init=tf.global_variables_initializer()
- with tf.Session() as sess:
- #执行初始化
- sess.run(init)
- #打印结果
- print(a.eval())
- print(b.eval())