在用 PMML 实现机器学习模型的跨平台上线中, 我们讨论了使用 PMML 文件来实现跨平台模型上线的方法, 这个方法当然也适用于 tensorflow 生成的模型, 但是由于 tensorflow 模型往往较大, 使用无法优化的 PMML 文件大多数时候很笨拙, 因此本文我们专门讨论下 tensorflow 机器学习模型的跨平台上线的方法.
1. tensorflow 模型的跨平台上线的备选方案
tensorflow 模型的跨平台上线的备选方案一般有三种: 即 PMML 方式, tensorflow serving 方式, 以及跨语言 API 方式.
PMML 方式的主要思路在上一篇以及讲过. 这里唯一的区别是转化生成 PMML 文件需要用一个 Java 库 https://github.com/jpmml/jpmml-tensorflow 来完成, 生成 PMML 文件后, 跨语言加载模型和其他 PMML 模型文件基本类似.
tensorflow serving 是 tensorflow 官方推荐的模型上线预测方式, 它需要一个专门的 tensorflow 服务器, 用来提供预测的 API 服务. 如果你的模型和对应的应用是比较大规模的, 那么使用 tensorflow serving 是比较好的使用方式. 但是它也有一个缺点, 就是比较笨重, 如果你要使用 tensorflow serving, 那么需要自己搭建 serving 集群并维护这个集群. 所以为了一个小的应用去做这个工作, 有时候会觉得麻烦.
跨语言 API 方式是本文要讨论的方式, 它会用 tensorflow 自己的 Python API 生成模型文件, 然后用 tensorflow 的客户端库比如 Java 或 C++ 库来做模型的在线预测. 下面我们会给一个生成生成模型文件并用 tensorflow Java API 来做在线预测的例子.
2. 训练模型并生成模型文件
我们这里给一个简单的逻辑回归并生成逻辑回归 tensorflow 模型文件的例子.
首先, 我们生成了一个 6 特征, 3 分类输出的 4000 个样本数据.
- import numpy as np
- import matplotlib.pyplot as plt
- %matplotlib inline
- from sklearn.datasets.samples_generator import make_classification
- import tensorflow as tf
- X1, y1 = make_classification(n_samples=4000, n_features=6, n_redundant=0,
- n_clusters_per_class=1, n_classes=3)
接着我们构建 tensorflow 的数据流图, 这里要注意里面的两个名字, 第一个是输入 x 的名字 input, 第二个是输出 prediction_labels 的名字 output, 这里的这两个名字可以自己取, 但是后面会用到, 所以要保持一致.
- learning_rate = 0.01
- training_epochs = 600
- batch_size = 100
- x = tf.placeholder(tf.float32, [None, 6],name='input') # 6 features
- y = tf.placeholder(tf.float32, [None, 3]) # 3 classes
- W = tf.Variable(tf.zeros([6, 3]))
- b = tf.Variable(tf.zeros([3]))
- # softmax 回归
- pred = tf.nn.softmax(tf.matmul(x, W) + b, name="softmax")
- cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))
- optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
- prediction_labels = tf.argmax(pred, axis=1, name="output")
- init = tf.global_variables_initializer()
接着就是训练模型了, 代码比较简单, 毕竟只是一个演示:
- sess = tf.Session()
- sess.run(init)
- y2 = tf.one_hot(y1, 3)
- y2 = sess.run(y2)
- for epoch in range(training_epochs):
- _, c = sess.run([optimizer, cost], feed_dict={x: X1, y: y2})
- if (epoch+1) % 10 == 0:
- print ("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(c))
- print ("优化完毕!")
- correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y2, 1))
- accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
- acc = sess.run(accuracy, feed_dict={x: X1, y: y2})
- print (acc)
打印输出我这里就不写了, 大家可以自己去试一试. 接着就是关键的一步, 存模型文件了, 注意要用 convert_variables_to_constants 这个 API 来保存模型, 否则模型参数不会随着模型图一起存下来.
- graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
- tf.train.write_graph(graph, '.', 'rf.pb', as_text=False)
至此, 我们的模型文件 rf.pb 已经被保存下来了, 下面就是要跨平台上线了.
3. 模型文件在 Java 平台上线
这里我们以 Java 平台的模型上线为例, C++ 的 API 上线我没有用过, 这里就不写了. 我们需要引入 tensorflow 的 java 库到我们工程的 maven 或者 gradle 文件. 这里给出 maven 的依赖如下, 版本可以根据实际情况选择一个较新的版本.
- <dependency>
- <groupId>org.tensorflow</groupId>
- <artifactId>tensorflow</artifactId>
- <version>1.7.0</version>
- </dependency>
接着就是代码了, 这个代码会比 JPMML 的要简单, 我给出了 4 个测试样本的预测例子如下, 一定要注意的是里面的 input 和 output 要和训练模型的时候对应的节点名字一致.
- import org.tensorflow.*;
- import org.tensorflow.Graph;
- import java.io.IOException;
- import java.nio.file.Files;
- import java.nio.file.Paths;
- /**
- * Created by 刘建平 pinard on 2018/7/1.
- */
- public class TFjavaDemo {
- public static void main(String args[]){
- byte[] graphDef = loadTensorflowModel("D:/rf.pb");
- float inputs[][] = new float[4][6];
- for(int i = 0; i<4; i++){
- for(int j =0; j< 6;j++){
- if(i<2) {
- inputs[i][j] = 2 * i - 5 * j - 6;
- }
- else{
- inputs[i][j] = 2 * i + 5 * j - 6;
- }
- }
- }
- Tensor<Float> input = covertArrayToTensor(inputs);
- Graph g = new Graph();
- g.importGraphDef(graphDef);
- Session s = new Session(g);
- Tensor result = s.runner().feed("input", input).fetch("output").run().get(0);
- long[] rshape = result.shape();
- int rs = (int) rshape[0];
- long realResult[] = new long[rs];
- result.copyTo(realResult);
- for(long a: realResult ) {
- System.out.println(a);
- }
- }
- static private byte[] loadTensorflowModel(String path){
- try {
- return Files.readAllBytes(Paths.get(path));
- } catch (IOException e) {
- e.printStackTrace();
- }
- return null;
- }
- static private Tensor<Float> covertArrayToTensor(float inputs[][]){
- return Tensors.create(inputs);
- }
- }
我的预测输出是 1,1,0,0, 供大家参考.
4. 一点小结
对于 tensorflow 来说, 模型上线一般选择 tensorflow serving 或者 client API 库来上线, 前者适合于较大的模型和应用场景, 后者则适合中小型的模型和应用场景. 因此算法工程师使用在产品之前需要做好选择和评估.
来源: https://www.cnblogs.com/pinard/p/9251296.html