1. tf.train.Saver()
tf.train.Saver()是一个类, 提供了变量, 模型 (也称图 Graph) 的保存和恢复模型方法.
TensorFlow 是通过构造 Graph 的方式进行深度学习, 任何操作 (如卷积, 池化等) 都需要 operator, 保存和恢复操作也不例外.
在 tf.train.Saver()类初始化时, 用于保存和恢复的 save 和 restore operator 会被加入 Graph. 所以, 下列类初始化操作应在搭建 Graph 时完成.
saver = tf.train.Saver()
TensorFlow 的保存和恢复分为两种:
保存和恢复变量
保存和恢复模型
saver.save()保存模型
# 举例:
保存一个训练好的手写数据集识别模型
保存在当前路径的 net 文件夹中
- import os
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
- import tensorflow as tf
- from tensorflow.examples.tutorials.mnist import input_data
- #载入数据集
- mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
- #每个批次 100 张照片
- batch_size = 100
- #计算一个需要多少个批次
- n_batch = mnist.train.num_examples // batch_size
- #定义两个 placeholder
- x = tf.placeholder(tf.float32, [None, 784])
- y = tf.placeholder(tf.float32, [None, 10])
- #创建一个简单的神经网络, 输入层 784 个神经元, 输出层 10 个神经元
- W = tf.Variable(tf.zeros([784, 10]))
- b = tf.Variable(tf.zeros([10]))
- prediction = tf.nn.softmax(tf.matmul(x, W) + b)
- #代价函数
- loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
- #使用梯度下降法
- train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
- #初始化变量
- init = tf.global_variables_initializer()
- #结果存放在一个布尔型列表中
- correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
- accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
- saver = tf.train.Saver()
- with tf.Session() as sess:
- sess.run(init)
- for epoch in range(11):
- for batch in range(n_batch):
- batch_xs, batch_ys = mnist.train.next_batch(batch_size)
- sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})
- acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})
- print('Iter =' + str(epoch) +', Testing Accuracy =' + str(acc))
- #保存模型
- saver.save(sess, 'net/my_net.ckpt')
- View Code
- # 保存路径中的文件为:
checkpoint: 保存当前网络状态的文件
- my_net.ckpt.data-00000-of-00001
- my_net.ckpt.index
my_net.ckpt.meta: 保存 Graph 结构的文件
- # 关于函数 saver.save(), 常用的参数就是前三个:
- save(
- sess, # 必需参数, Session 对象
- save_path, # 必需参数, 存储路径
- global_step=None, # 可以是 Tensor, Tensor name, 整型数
- latest_filename=None, # 协议缓冲文件名, 默认为'checkpoint', 不用管
- meta_graph_suffix='meta', # 图文件的后缀, 默认为'.meta', 不用管
- write_meta_graph=True, # 是否保存 Graph
- write_state=True, # 建议选择默认值 True
- strip_default_attrs=False # 是否跳过具有默认值的节点
saver.restore()加载已经训练好的模型
# 举例:
通过加载刚才保存的训练好的手写数据集识别模型进行手写数据集的识别
- import os
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
- import tensorflow as tf
- from tensorflow.examples.tutorials.mnist import input_data
- mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
- batch_size = 100
- n_batch = mnist.train.num_examples // batch_size
- x = tf.placeholder(tf.float32, [None, 784])
- y = tf.placeholder(tf.float32, [None, 10])
- W = tf.Variable(tf.zeros([784, 10]))
- b = tf.Variable(tf.zeros([10]))
- prediction = tf.nn.softmax(tf.matmul(x, W) + b)
- loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
- train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
- init = tf.global_variables_initializer()
- correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
- accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
- saver = tf.train.Saver()
- with tf.Session() as sess:
- sess.run(init)
- print(sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}))
- saver.restore(sess, 'net/my_net.ckpt')
- print(sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}))
- View Code
- # 执行结果:
- 0.098
- 0.9178
- # 直接得到的准确率相当低, 通过加载训练好的模型, 识别准确率大大提升.
2. 下载 google 图像识别网络 inception-v3 并查看结构
模型背景:
Inception(v3) 模型是 Google 训练好的最新一个图像识别模型, 我们可以利用它来对我们的图像进行识别.
下载地址:
https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip
文件描述:
classify_image_graph_def.pb 文件就是训练好的 Inception-v3 模型.
imagenet_synset_to_human_label_map.txt 是类别文件, 包含人类标签和 uid 之间的映射的文件.
imagenet_2012_challenge_label_map_proto.pbtxt 是包含类号和 uid 之间的映射的文件.
代码实现
- import os
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
- import tensorflow as tf
- import tarfile
- import requests
- #inception 模型下载地址
- inception_pretrain_model_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
- #inception 模型存放地址
- inception_pretrain_model_dir = 'inception_model'
- if not os.path.exists(inception_pretrain_model_dir):
- os.makedirs(inception_pretrain_model_dir)
- #获取文件名, 以及文件路径
- filename = inception_pretrain_model_url.split('/')[-1]
- filepath = os.path.join(inception_pretrain_model_dir, filename)
- #下载模型
- if not os.path.exists(filepath):
- print('download:', filename)
- r = requests.get(inception_pretrain_model_url, stream=True)
- with open(filepath, 'wb') as f:
- for chunk in r.iter_content(chunk_size=1024):
- if chunk:
- f.write(chunk)
- print('finish:', filename)
- #解压文件
- tarfile.open(filepath, 'r:gz').extractall(inception_pretrain_model_dir)
- #模型结构存放文件
- log_dir = 'inception_log'
- if not os.path.exists(log_dir):
- os.makedirs(log_dir)
- #classify_image_graph_def.pb 为 google 训练好的模型
- inception_graph_def_file = os.path.join(inception_pretrain_model_dir, 'classify_image_graph_def.pb')
- with tf.Session() as sess:
- #创建一个图来存放 google 训练好的模型
- with tf.gfile.FastGFile(inception_graph_def_file, 'rb') as f:
- graph_def = tf.GraphDef()
- graph_def.ParseFromString(f.read())
- tf.import_graph_def(graph_def, name='')
- #保存图的结构
- writer = tf.summary.FileWriter(log_dir, sess.graph)
- writer.close()
- View Code
- # 在下载过程中, 下的特别慢, 不知道是网络原因还是什么
- # 程序总卡着不动
- # 所以我就手动下载压缩包并进行解压
下载结果
3. 使用 inception-v3 做各种图像的识别
# 代码实现:
- import os
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
- import tensorflow as tf
- import numpy as np
- import re
- from PIL import Image
- import matplotlib.pyplot as plt
- #这部分是对标签号和类别号文件进行一个预处理
- class NodeLookup(object):
- def __init__(self):
- label_lookup_path = 'inception_model/imagenet_2012_challenge_label_map_proto.pbtxt'
- uid_lookup_path = 'inception_model/imagenet_synset_to_human_label_map.txt'
- self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
- def load(self, label_lookup_path, uid_lookup_path):
- #加载分类字符串 n******** 对应分类名称的文件
- proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
- uid_to_human={}
- #一行一行读取数据
- for line in proto_as_ascii_lines:
- #去掉换行符
- line = line.strip('\n')
- #按照'\t'进行分割
- parsed_items = line.split('\t')
- #获取分类编号
- uid = parsed_items[0]
- #获取分类名称
- human_string = parsed_items[1]
- #保存编号字符串 n******** 与分类名称的映射关系
- uid_to_human[uid] = human_string
- #加载分类字符串 n******** 对应分类编号 1-1000 的文件
- proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
- node_id_to_uid = {}
- for line in proto_as_ascii:
- if line.startswith('target_class:'):
- #获取分类编号 1-1000
- target_class = int(line.split(':')[1])
- if line.startswith('target_class_string:'):
- #获取编号字符串 nn********
- target_class_string = line.split(':')[1]
- # 保存分类编号 1-1000 与编号字符串 n******** 映射关系
- node_id_to_uid[target_class] = target_class_string[1:-2]
- # 建立分类编号 1-1000 对应分类名称的映射关系
- node_id_to_name = {}
- for key, val in node_id_to_uid.items():
- #获取分类名称
- name = uid_to_human[val]
- # 建立分类编号 1-1000 到分类名称的映射关系
- node_id_to_name[key] = name
- return node_id_to_name
- # 传入分类编号 1-1000 返回分类名称
- def id_to_string(self, node_id):
- if node_id not in self.node_lookup:
- return ''
- return self.node_lookup[node_id]
- #创建一个图来存放 google 训练好的模型
- with tf.gfile.FastGFile('inception_model/classify_image_graph_def.pb', 'rb') as f:
- graph_def = tf.GraphDef()
- graph_def.ParseFromString(f.read())
- tf.import_graph_def(graph_def, name='')
- with tf.Session() as sess:
- softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
- #遍历目录
- for root, dirs, files in os.walk('images/'):
- for file in files:
- #载入图片
- image_data = tf.gfile.FastGFile(os.path.join(root, file), 'rb').read()
- predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})# 图片格式是 jpg 格式
- predictions = np.squeeze(predictions)# 把结果转为 1 维数据
- #打印图片路径及名称
- image_path = os.path.join(root, file)
- print(image_path)
- # 显示图片
- img = Image.open(image_path)
- plt.imshow(img)
- plt.axis('off')
- plt.show()
- #排序
- top_k = predictions.argsort()[-5:][::-1]
- node_lookup = NodeLookup()
- for node_id in top_k:
- # 获取分类名称
- human_string = node_lookup.id_to_string(node_id)
- # 获取该分类的置信度
- score = predictions[node_id]
- print('%s(score = %.5f)' % (human_string, score))
- print()
- View Code
- # 执行结果:
- images/1.jpg
- giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca(score = 0.87265)
- badger(score = 0.00260)
- lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens(score = 0.00205)
- brown bear, bruin, Ursus arctos(score = 0.00102)
- ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus(score = 0.00099)
- images/2.jpg
- French bulldog(score = 0.94474)
- bull mastiff(score = 0.00559)
- pug, pug-dog(score = 0.00352)
- Staffordshire bullterrier, Staffordshire bull terrier(score = 0.00165)
- boxer(score = 0.00116)
- images/3.jpg
- zebra(score = 0.94011)
- tiger, Panthera tigris(score = 0.00080)
- pencil box, pencil case(score = 0.00066)
- hartebeest(score = 0.00059)
- tiger cat(score = 0.00042)
- images/4.jpg
- hare(score = 0.87019)
- wood rabbit, cottontail, cottontail rabbit(score = 0.04802)
- Angora, Angora rabbit(score = 0.00612)
- wallaby, brush kangaroo(score = 0.00181)
- fox squirrel, eastern fox squirrel, Sciurus niger(score = 0.00056)
- images/5.jpg
- fox squirrel, eastern fox squirrel, Sciurus niger(score = 0.95047)
- marmot(score = 0.00265)
- mongoose(score = 0.00217)
- weasel(score = 0.00201)
- mink(score = 0.00199)
来源: https://www.cnblogs.com/guoruxin/p/10238018.html