Python+Android進行TensorFlow開發
- 2020 年 3 月 16 日
- 筆記
Tensorflow是Google開源的一套機器學習框架,支援GPU、CPU、Android等多種計算平台。本文將介紹在Tensorflow在Android上的使用。
Android使用Tensorflow框架需要引入兩個文件libtensorflow_inference.so、libandroid_tensorflow_inference_java.jar。這兩個文件可以使用官方預編譯的文件。如果預編譯的so不滿足要求(比如不支援訓練模型中的某些操作符運算),也可以自己通過bazel編譯生成這兩個文件。
將libandroid_tensorflow_inference_java.jar放在app下的libs目錄下,so文件命名為libtensorflow_jni.so放在src/main/jniLibs目錄下對應的ABI文件夾下。目錄結構如下:
Android目錄結構
同時在app的build.gradle中的dependencies模組下添加如下配置:
dependencies { ... compile files('libs/libandroid_tensorflow_inference_java.jar') ... }
使用tensorflow框架進行機器學習分為四個步驟:
-
構造神經網路
-
訓練神經網路模型
-
將訓練好的模型輸出為pb文件
-
ndroid上載入pb模型進行計算
前三步是模型的構造,我們通過python實現,下面給出了一個二分類的簡單模型的構造過程,首先是訓練過程:
# -*-coding:utf-8 -*- from __future__ import print_function import os import tensorflow as tf from numpy.random import RandomState os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' """ 訓練模型 """ def train(): # 定義訓練數據集batch大小為8 batch_size = 8 # 定義神經網路參數,參數體現出神經網路結構,一個輸入層,一個輸出層,一個隱藏層 w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1), name="w1_val") w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1), name="w2_val") # 定義輸入輸出格式 x = tf.placeholder(tf.float32, shape=(None, 2), name='x_input') y_ = tf.placeholder(tf.float32, shape=(None, 1)) # 定義神經網路前向傳播過程 a = tf.matmul(x, w1) y = tf.matmul(a, w2, name="cal_node") # 定義交叉熵和反向傳播演算法 cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))) train_step = tf.train.AdadeltaOptimizer(0.001).minimize(cross_entropy) # 生成隨機訓練集 rdm = RandomState(1) dataset_size = 128 # 定義映射關係 X = rdm.rand(dataset_size, 2) Y = [[int(x1 + x2 < 1)] for (x1, x2) in X] with tf.Session() as sess: # 初始化所有參數 init_op = tf.global_variables_initializer() sess.run(init_op) # print sess.run(w1) # print sess.run(w2) STEPS = 500 for i in range(STEPS): start = (i * batch_size) % dataset_size end = min(start + batch_size, dataset_size) # 訓練神經網路,更新神經網路參數 sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]}) if i % 100 == 0: total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y}) print("After %d training step(s), cross entropy on all data is %g" % (i, total_cross_entropy)) print(sess.run(w1)) print(sess.run(w2)) # 保存check point saver = tf.train.Saver(tf.trainable_variables()) saver.save(sess, './model/checpt')
上面的程式碼首先定義神經網路,初始化訓練數據,進行500次訓練過程,並將訓練結果checkpoints保存到model文件夾下,checkpoints包含了訓練模型得到的參數資訊,共生成四個相關的文件,如下圖:
由於checkpoint文件眾多,為了方便使用,我們通過下面的程式碼將它們生成一個pb文件,在android上只需要這個pb文件即可使用這個訓練好的模型:
""" 存儲pb模型 """ def dump_graph_to_pb(pb_path): with tf.Session() as sess: check_point = tf.train.get_checkpoint_state("./model/") if check_point: saver = tf.train.import_meta_graph(check_point.model_checkpoint_path + '.meta') saver.restore(sess, check_point.model_checkpoint_path) else: raise ValueError("Model load failed from {}".format(check_point.model_checkpoint_path)) graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), "cal_node".split(",")) with tf.gfile.GFile(pb_path, "wb") as f: f.write(graph_def.SerializeToString())
拿到生成的pb模型,我們可以在android上使用了。將pb文件在這main/assets下:
接下來就可以載入pb,進行計算了:
public class MainActivity extends AppCompatActivity { private Graph graph_; private Session session_; private AssetManager assetManager; private static ExecutorService executorService; private static Handler handler; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); executorService = Executors.newFixedThreadPool(5); // 初始化tensorflow initTensorFlow("outmodel.pb"); // 使用tensorflow進行計算 runTensorFlow(); } ... }
通過如下方式載入pb模型,初始化tensorflow:
private boolean initTensorFlow(String modelFile) { assetManager = getAssets(); // 新建Graph graph_ = new Graph(); InputStream is = null; try { // 讀取Assets pb文件 is = assetManager.open(modelFile); } catch (IOException e) { e.printStackTrace(); return false; } try { // 載入pb到Graph TensorUtil.loadGraph(is, graph_); is.close(); } catch (IOException e) { e.printStackTrace(); return false; } // 初始化session session_ = new Session(graph_); if (session_ == null) { return false; } return true; }
然後就可以使用tensorflow API進行運算了:
private void runTensorFlow() { executorService.execute(generatePredictRunnable(handler)); } private Runnable generatePredictRunnable(Handler handler) { return new Runnable() { @Override public void run() { float[][] input = new float[1][2]; input[0][0] = 1; input[0][1] = 2; // 定義輸入tensor Tensor inputTensor = Tensor.create(input); // 指定輸入,輸出節點,運行並得到結果 Tensor resultTensor = session_.runner() .feed("x_input", inputTensor) .fetch("cal_node") .run() .get(0); float[][] dst = new float[1][1]; resultTensor.copyTo(dst); // 處理結果 ArrayList<Float> resultList = new ArrayList<>(); for (float val : dst[0]) { if (val != 0) { resultList.add(val); } else { break; } } } }; }
上面就是通過python訓練機器學習模型,並在android平台進行調用的完整流程。
原創作者:JackMeGo,原文鏈接:https://www.jianshu.com/p/eef4ab014a12
歡迎關注我的微信公眾號「碼農突圍」,分享Python、Java、大數據、機器學習、人工智慧等技術,關注碼農技術提升•職場突圍•思維躍遷,20萬+碼農成長充電第一站,陪有夢想的你一起成長。