前言
保存模型有 2 种方法。
方法
1. 使用 TensorFlow 模型保存函数
# Create the Saver under the same name that is used to call save() below.
saver = tf.train.Saver()
# ... (training code elided in the original; `sess` and `step` come from it) ...
# Writes checkpoint files named "checkpoint/model.ckpt-<step>.*".
saver.save(sess, "checkpoint/model.ckpt", global_step=step)
得到 3 个结果
- model.ckpt-129220.data-00000-of-00001 # 保存了模型的所有变量的值.
- model.ckpt-129220.index
- model.ckpt-129220.meta # 保存了 graph 结构, 包括 GraphDef, SaverDef 等. 存在时, 可以不在文件中定义模型, 也可以运行
再将这 3 个文件保存为 .pb 文件
- import tensorflow as tf
- import deeplab_model
def export_graph(model, checkpoint_dir, model_name, output_dir='pretrained'):
    """Freeze a trained DeepLab checkpoint into a single binary GraphDef file.

    Args:
        model: module (or object) exposing ``deeplabv3_plus_model_fn``.
        checkpoint_dir: either a checkpoint prefix such as
            ``'model/model.ckpt-133958'`` or a directory holding the
            checkpoint files, in which case the latest checkpoint is used.
        model_name: file name of the frozen graph to write.
        output_dir: directory the frozen graph is written to.

    Raises:
        ValueError: if ``checkpoint_dir`` is a directory that contains no
            checkpoint.
    """
    # Accept a directory as well as a checkpoint prefix, as the name suggests.
    checkpoint_path = checkpoint_dir
    if tf.gfile.IsDirectory(checkpoint_dir):
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
        if checkpoint_path is None:
            raise ValueError('no checkpoint found in %r' % (checkpoint_dir,))

    graph = tf.Graph()
    with graph.as_default():
        # Input placeholder; the frozen graph is fed through 'input_image'.
        input_img = tf.placeholder(
            tf.float32, shape=[None, None, None, 3], name='input_image')
        # model_fn takes a labels tensor, so an all-zero dummy is supplied.
        # NOTE(review): presumably unused for predictions -- confirm.
        labels = tf.zeros([1, 512, 512, 1])
        labels = tf.to_int32(
            tf.image.convert_image_dtype(labels, dtype=tf.uint8))

        # The tensor to export.
        output = model.deeplabv3_plus_model_fn(
            input_img,
            labels,
            tf.estimator.ModeKeys.EVAL,
            params={
                'output_stride': 16,
                # Batch size must be 1 because the images' size may differ
                'batch_size': 1,
                'base_architecture': 'resnet_v2_50',
                'pre_trained_model': None,
                'batch_norm_decay': None,
                'num_classes': 2,
                'freeze_batch_norm': True,
            }).predictions['classes']

        # Give the output a stable name so it can be fetched after import.
        output = tf.identity(output, name='output_label')
        restore_saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        # Initialize variables; restore() below then loads the trained values.
        sess.run(tf.global_variables_initializer())
        restore_saver.restore(sess, checkpoint_path)
        # Replace every variable with a constant holding its current value.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), [output.op.name])
        # Serialize the frozen graph as a binary protobuf.
        tf.train.write_graph(
            output_graph_def, output_dir, model_name, as_text=False)


# Call the function to generate the frozen graph file. The 'model.pd' name is
# kept as-is because the loading snippet reads 'pretrained/model.pd'.
export_graph(deeplab_model, 'model/model.ckpt-133958', 'model.pd')
- ### 读取
- import tensorflow as tf
- import os
def inference(model_path='pretrained/model.pd'):
    """Run the frozen segmentation graph and return the predicted labels.

    The graph's 'input_image' node is remapped to ``images``, which is read
    from the enclosing (module) scope. It is not defined in this snippet and
    must exist before this function is called.

    Args:
        model_path: path of the frozen binary GraphDef file.

    Returns:
        The evaluated value of the graph's 'output_label:0' tensor.
    """
    graph_def = tf.GraphDef()
    # The context manager closes the file as soon as it has been parsed.
    with tf.gfile.GFile(model_path, 'rb') as model_file:
        graph_def.ParseFromString(model_file.read())

    # Import into the default graph, wiring `images` in as the input.
    [output_image] = tf.import_graph_def(
        graph_def,
        input_map={'input_image': images},
        return_elements=['output_label:0'],
        name='output')

    # Close the session once the result has been fetched.
    with tf.Session() as sess:
        return sess.run(output_image)


labels = inference()
2. 直接保存
- import tensorflow as tf
- from tensorflow.python.framework import graph_util
# Build a tiny graph: three variables, two placeholders and three add ops.
var1 = tf.Variable(1.0, dtype=tf.float32, name='v1')
var2 = tf.Variable(2.0, dtype=tf.float32, name='v2')
var3 = tf.Variable(2.0, dtype=tf.float32, name='v3')
x = tf.placeholder(dtype=tf.float32, shape=None, name='x')
x2 = tf.placeholder(dtype=tf.float32, shape=None, name='x2')
addop = tf.add(x, x2, name='add')
addop2 = tf.add(var1, var2, name='add2')
addop3 = tf.add(var3, var2, name='add3')
initop = tf.global_variables_initializer()

model_dir = './Test'
model_path = model_dir + '/model.pb'

with tf.Session() as sess:
    sess.run(initop)
    print(sess.run(addop, feed_dict={x: 12, x2: 23}))
    # Freeze the graph: variables become constants with their current values.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['add', 'add2', 'add3'])
    # Make sure the target directory exists before writing into it.
    tf.gfile.MakeDirs(model_dir)
    # Write the frozen graph; the context manager flushes and closes the file.
    with tf.gfile.GFile(model_path, mode="wb") as model_f:
        model_f.write(output_graph_def.SerializeToString())
- #### 读取代码:
- import tensorflow as tf
# Parse the frozen graph; the context manager closes the file afterwards.
graph_def = tf.GraphDef()
with tf.gfile.GFile("./Test/model.pb", mode='rb') as model_f:
    graph_def.ParseFromString(model_f.read())

# Import the graph once and fetch every tensor of interest in a single call.
x, x2, c3, add2_out, add3_out = tf.import_graph_def(
    graph_def,
    return_elements=["x:0", "x2:0", "add:0", "add2:0", "add3:0"])
# Keep `c` and `c2` as one-element lists so the printed results stay lists.
c = [add2_out]
c2 = [add3_out]

with tf.Session() as sess:
    print(sess.run(c))
    print(sess.run(c2))
    print(sess.run(c3, feed_dict={x: 23, x2: 2}))
来源: http://www.bubuko.com/infodetail-3364858.html