Tensorflow 是 Google 开源的一套机器学习框架, 支持 GPU,CPU,Android 等多种计算平台. 本文将介绍在 Tensorflow 在 Android 上的使用.
Android 使用 Tensorflow 框架需要引入两个文件 libtensorflow_inference.so,libandroid_tensorflow_inference_java.jar. 这两个文件可以使用官方预编译的文件. 如果预编译的 so 不满足要求 (比如不支持训练模型中的某些操作符运算), 也可以自己通过 bazel 编译生成这两个文件.
将 libandroid_tensorflow_inference_java.jar 放在 App 下的 libs 目录下, so 文件命名为 libtensorflow_jni.so 放在 src/main/jniLibs 目录下对应的 ABI 文件夹下. 目录结构如下:
Android 目录结构
同时在 App 的 build.gradle 中的 dependencies 模块下添加如下配置:
- dependencies {
- ...
- compile files('libs/libandroid_tensorflow_inference_java.jar')
- ...
- }
使用 tensorflow 框架进行机器学习分为四个步骤:
构造神经网络
训练神经网络模型
将训练好的模型输出为 pb 文件
ndroid 上加载 pb 模型进行计算
前三步是模型的构造, 我们通过 python 实现, 下面给出了一个二分类的简单模型的构造过程, 首先是训练过程:
- # -*-coding:utf-8 -*-
- from __future__ import print_function
- import os
- import tensorflow as tf
- from numpy.random import RandomState
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
- """
- 训练模型
- """
- def train():
- # 定义训练数据集 batch 大小为 8
- batch_size = 8
- # 定义神经网络参数, 参数体现出神经网络结构, 一个输入层, 一个输出层, 一个隐藏层
- w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1), name="w1_val")
- w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1), name="w2_val")
- # 定义输入输出格式
- x = tf.placeholder(tf.float32, shape=(None, 2), name='x_input')
- y_ = tf.placeholder(tf.float32, shape=(None, 1))
- # 定义神经网络前向传播过程
- a = tf.matmul(x, w1)
- y = tf.matmul(a, w2, name="cal_node")
- # 定义交叉熵和反向传播算法
- cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
- train_step = tf.train.AdadeltaOptimizer(0.001).minimize(cross_entropy)
- # 生成随机训练集
- rdm = RandomState(1)
- dataset_size = 128
- # 定义映射关系
- X = rdm.rand(dataset_size, 2)
- Y = [[int(x1 + x2 <1)] for (x1, x2) in X]
- with tf.Session() as sess:
- # 初始化所有参数
- init_op = tf.global_variables_initializer()
- sess.run(init_op)
- # print sess.run(w1)
- # print sess.run(w2)
- STEPS = 500
- for i in range(STEPS):
- start = (i * batch_size) % dataset_size
- end = min(start + batch_size, dataset_size)
- # 训练神经网络, 更新神经网络参数
- sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
- if i % 100 == 0:
- total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})
- print("After %d training step(s), cross entropy on all data is %g" % (i, total_cross_entropy))
- print(sess.run(w1))
- print(sess.run(w2))
- # 保存 check point
- saver = tf.train.Saver(tf.trainable_variables())
- saver.save(sess, './model/checpt')
上面的代码首先定义神经网络, 初始化训练数据, 进行 500 次训练过程, 并将训练结果 checkpoints 保存到 model 文件夹下, checkpoints 包含了训练模型得到的参数信息, 共生成四个相关的文件, 如下图:
由于 checkpoint 文件众多, 为了方便使用, 我们通过下面的代码将它们生成一个 pb 文件, 在 Android 上只需要这个 pb 文件即可使用这个训练好的模型:
- """
- 存储 pb 模型
- """
- def dump_graph_to_pb(pb_path):
- with tf.Session() as sess:
- check_point = tf.train.get_checkpoint_state("./model/")
- if check_point:
- saver = tf.train.import_meta_graph(check_point.model_checkpoint_path + '.meta')
- saver.restore(sess, check_point.model_checkpoint_path)
- else:
- raise ValueError("Model load failed from {}".format(check_point.model_checkpoint_path))
- graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), "cal_node".split(","))
- with tf.gfile.GFile(pb_path, "wb") as f:
- f.write(graph_def.SerializeToString())
拿到生成的 pb 模型, 我们可以在 Android 上使用了. 将 pb 文件在这 main/assets 下:
接下来就可以载入 pb, 进行计算了:
- public class MainActivity extends AppCompatActivity {
- private Graph graph_;
- private Session session_;
- private AssetManager assetManager;
- private static ExecutorService executorService;
- private static Handler handler;
- @Override
- protected void onCreate(Bundle savedInstanceState) {
- super.onCreate(savedInstanceState);
- setContentView(R.layout.activity_main);
- executorService = Executors.newFixedThreadPool(5);
- // 初始化 tensorflow
- initTensorFlow("outmodel.pb");
- // 使用 tensorflow 进行计算
- runTensorFlow();
- }
- ...
- }
通过如下方式载入 pb 模型, 初始化 tensorflow:
- private boolean initTensorFlow(String modelFile) {
- assetManager = getAssets();
- // 新建 Graph
- graph_ = new Graph();
- InputStream is = null;
- try {
- // 读取 Assets pb 文件
- is = assetManager.open(modelFile);
- } catch (IOException e) {
- e.printStackTrace();
- return false;
- }
- try {
- // 加载 pb 到 Graph
- TensorUtil.loadGraph(is, graph_);
- is.close();
- } catch (IOException e) {
- e.printStackTrace();
- return false;
- }
- // 初始化 session
- session_ = new Session(graph_);
- if (session_ == null) {
- return false;
- }
- return true;
- }
然后就可以使用 tensorflow API 进行运算了:
- private void runTensorFlow() {
- executorService.execute(generatePredictRunnable(handler));
- }
- private Runnable generatePredictRunnable(Handler handler) {
- return new Runnable() {
- @Override
- public void run() {
- float[][] input = new float[1][2];
- input[0][0] = 1;
- input[0][1] = 2;
- // 定义输入 tensor
- Tensor inputTensor = Tensor.create(input);
- // 指定输入, 输出节点, 运行并得到结果
- Tensor resultTensor = session_.runner()
- .feed("x_input", inputTensor)
- .fetch("cal_node")
- .run()
- .get(0);
- float[][] dst = new float[1][1];
- resultTensor.copyTo(dst);
- // 处理结果
- ArrayList<Float> resultList = new ArrayList<>();
- for (float val : dst[0]) {
- if (val != 0) {
- resultList.add(val);
- } else {
- break;
- }
- }
- }
- };
- }
上面就是通过 python 训练机器学习模型, 并在 Android 平台进行调用的完整流程.
原创作者: JackMeGo, 原文链接: https://www.jianshu.com/p/eef4ab014a12
来源: https://www.cnblogs.com/hejunlin/p/12507132.html