Tensorflow 官方提供的 Tensorboard 可以可视化神经网络结构图, 但是说实话, 我几乎从来不用. 主要是因为 Tensorboard 中查看到的图结构太混乱了, 包含了网络中所有的计算节点 (读取数据节点, 网络节点, loss 计算节点等等). 更可怕的是, 如果一个计算节点是由多个基础计算(如加减乘除等) 构成, 那么在 Tensorboard 中会将基础计算节点显示而不是作为一个整体显示(典型的如 Squeeze 计算节点). 最近为了排查网络结构 BUG 花费一周时间, 因此, 狠下心来决定自己写一个工具, 将 Tensorflow 中的图以最简单的方式显示最关键的网络结构.
1 Tensor 对象与 Operation 对象
Tensorflow 中, Tensor 对象主要用于存储数据如常量和变量(训练参数),Operation 对象是计算节点, 如卷积计算, 反卷积计算, ReLU 等等. 每一个 Operation 对象均有输入和输出 Tensor, 同理, 每个 Tensor 对象均有对应生成该 Tensor 的 Operation 对象和使用该 Tensor 对象作为输入的 Operation 对象. Tensor 和 Operation 对象内均有相关属性和函数来获取其关联的 Operation 和 Tensor 对象, 相关属性如下所示.
Tensor 对象的 op 属性指向生成该 Tensor 的 Operation 对象.
Tensor 对象的 consumers()函数获取使用该 Tensor 对象作为输入的 Operation 对象.
Operation 对象的 inputs 属性指向该计算节点的输入 Tensor 对象.
Operation 对象的 outputs 属性执行该计算节点的输出 Tensor 对象.
如下图所示的网络结构中, 调用 Tensor_2 对象的 consumers()函数, 返回的是[op_1,op_2].Tensor_3 的 op 属性指向的是 op_1.op_1 的 inputs 属性指向的是[Tensor_1,Tensor_2],op_1 的 output 属性指向的是[Tensor_3].
Tensor 与 Operation
有了 Tensor 与 Operation 对应在图中的关联关系, 就可以将网络结构给画出来.
2 提取 pb 文件中的网络结构图
pb 文件是将模型参数固化到图文件中, 并合并了一些基础计算和删除了反向传播相关计算得到的 protobuf 协议文件. 如果读者还不懂如何将 CKPT 模型文件转 pb 文件, 请参考我另一篇文章《 Tensorflow MobileNet 移植到 Android》的第 1 节部分. 有了 pb 模型文件后, 接下来是加载模型, 加载 pb 模型示例代码如下所示.
- def read_graph_from_pb(tf_model_path ,input_names,output_name):
- with open(tf_model_path, 'rb') as f:
- serialized = f.read()
- tf.reset_default_graph()
- gdef = tf.GraphDef()
- gdef.ParseFromString(serialized)
- with tf.Graph().as_default() as g:
- tf.import_graph_def(gdef, name='')
- with tf.Session(graph=g) as sess:
- OPS=get_ops_from_pb(g,input_names,output_name)
- return OPS
其中, 倒数第 2 行调用到的函数 get_ops_from_pb()用于获取网络结构图中指定输入节点和指定输出节点之间的计算节点. 之所以要指定输入和输出, 是为了将输入之前的计算节点 (如加载数据队列等相关计算节点) 和输出之后的计算节点 (如计算 loss 等相关计算节点) 去除, 免得碍眼. 函数 get_ops_from_pb()实现代码如下.
- def get_ops_from_pb(graph,input_names,output_name,save_ori_network=True):
- if save_ori_network:
- with open('ori_network.txt','w+') as w:
- OPS=graph.get_operations()
- for op in OPS:
- txt = str([v.name for v in op.inputs])+'---->'+op.type+'--->'+str([v.name for v in op.outputs])
- w.write(txt+'\n')
- inputs_tf = [graph.get_tensor_by_name(input_name) for input_name in input_names]
- output_tf =graph.get_tensor_by_name(output_name)
- OPS =get_ops_from_inputs_outputs(graph, inputs_tf,[output_tf] )
- with open('network.txt','w+') as w:
- for op in OPS:
- txt = str([v.name for v in op.inputs])+'---->'+op.type+'--->'+str([v.name for v in op.outputs])
- w.write(txt+'\n')
- OPS = sort_ops(OPS)
- OPS = merge_layers(OPS)
- return OPS
在裁剪网络结构 (即只保留 input_names 和 output_name 之间节点) 之前, 先将原始的网络结构写入到 ori_network.txt 中, 文件中, 每一行写入: 输入 Tensor---->op---->输出 Tensor. 接下来调用函数 get_ops_from_inputs_outputs 获取指定节点之间的节点. 并调用 sort_ops 函数对所有的节点排序, 以保证被依赖的节点总是出现在相关节点之前. 最后调用 merge_layers 函数, 将一些可以合并的计算合并成一个独立的节点, 例如, Squeeze 计算相关节点合并成一个单独的 Squeeze 节点, 又如 const-->identity 两个计算节点可以直接忽略(即删除).
注意: 篇幅有限, 这里不再将函数 get_ops_from_inputs_outputs,sort_ops,merge_layers 贴出, 相关代码请前往文尾提供的源码地址中阅读.
3 绘制网络结构
考虑到 SVG 绘制图形的简单易用优点, 将排好序的网络计算节点和相关 Tensor 对象数据以 JavaScript 字符串的形式写入到 html 中, 使用 < line > 标签绘制箭头, 使用 < rect > 标签绘制矩形, 使用 < ellipse > 标签绘制椭圆, 使用 < text > 标签显示文字. 绘制类似于如下所示图像
绘制网络结构示例
注意: 篇幅有限, 这里不再介绍 JavaScript 代码解析模型结构和 SVG 显示相关的原理, 相关代码请前往文尾提供的源码地址中阅读.
4 测试模型显示
以《MobileNet V1 官方预训练模型的使用》文中介绍的 MobileNet V1 网络结构为例, 下载 MobileNet_v1_1.0_192 文件并压缩后, 得到 mobilenet_v1_1.0_192_frozen.pb 文件. 我们还需要知道 mobilenet_v1_1.0_192_frozen.pb 模型对应的输入和输出 Tensor 对象的名称, 好在 MobileNet_v1_1.0_192 压缩包中包含文件 mobilenet_v1_1.0_192_info.txt. 通过该文件可知, 输入 Tensor 的名称为: input:0, 输出 Tensor 名称为: MobilenetV1/Predictions/Reshape_1:0. 有了这些信息后, 调用函数 read_graph_from_pb 得到静态图的节点列表对象 ops, 调用函数 gen_graph(ops,"save/path/graph.html")后, 在目录 save/path 中得到 graph.HTML 文件, 打开 graph.HTML 后, 显示结果如下.
显示网络结构分两种模式: 合并模式和展开模式, 分别如下图所示.
合并模式网络结构
截取的展开模式网络结构
5 源码地址
https://github.com/huachao1001/CNNGraph
来源: https://www.qcloud.com/developer/article/1361316