當前位置:
首頁 > 最新 > 將 TensorFlow 訓練好的模型遷移到 Android APP上

將 TensorFlow 訓練好的模型遷移到 Android APP上

本文原載於作者天澤的 CSDN 博客,AI 研習社獲其授權轉載。


1.寫在前面

最近在做一個數字手勢識別的APP(關於這個項目,我會再寫一篇博客仔細介紹,博客地址:一步步做一個數字手勢識別APP,源代碼已經開源在github上,地址:Chinese-number-gestures-recognition),要把在PC端訓練好的模型放到Android APP上,調研了下,谷歌發布了TensorFlow Lite可以把TensorFlow訓練好的模型遷移到Android APP上,百度也發布了移動端深度學習框架mobile-deep-learning(MDL),這個框架應該是paddlepaddle的手機版,具體的細節沒有了解過。因為對TensorFlow稍微熟悉些,因此就決定用TensorFlow來做。

關於在PC端如何處理數據及訓練模型,請參見博客:一步步做一個數字手勢識別APP,代碼已經開源在github上,上面有代碼的說明和APP演示。這篇博客只介紹如何把TensorFlow訓練好的模型遷移到Android Studio上進行APP的開發。


第一步,首先在pc端訓練模型的時候要模型保存為.pb模型,在保存的時候有一點非常非常重要,就是你待會再Android studio是使用這個模型用到哪個參數,那麼你在保存pb模型的時候就把給哪個參數一個名字,再保存。

否則,你在Android studio中很難拿出這個參數,因為TensorFlow Lite的fetch()函數是根據保存在pb模型中的名字去尋找這個參數的。(如果你已經訓練好了模型,並且沒有給參數名字,且你不想再訓練模型了,那麼你可以嘗試下面的方法去找到你需要使用的變數的默認名字,見下面的代碼):

#輸出保存的模型中參數名字及對應的值withtf.gfile.GFile("model_50_200_c3//./digital_gesture.pb","rb")asf: #讀取模型數據

graph_def =tf.GraphDef()

graph_def.ParseFromString(f.read()) #得到模型中的計算圖和數據withtf.Graph().as_default()asgraph: # 這裡的Graph()要有括弧,不然會報TypeError

tf.import_graph_def(graph_def, name="") #導入模型中的圖到現在這個新的計算圖中,不指定名字的話默認是 import

forop in graph.get_operations(): # 列印出圖中的節點信息

print(op.name, op.values())

這段代碼打出的變數的名字以及對應的值。

言歸正傳,通常情況該你應該保存參數的時候都給參數一個指定的名字,如下面這樣(通過name參數給變數指定名字),關於訓練CNN的完整代碼請參見下一篇博客或者github:

X=tf.placeholder(tf.float32, [None,64,64,3], name="input_x")

y=tf.placeholder(tf.float32, [None,11], name="input_y")

kp =tf.placeholder_with_default(1.0, shape=(), name="keep_prob")

lam =tf.placeholder(tf.float32, name="lamda")#中間略過若干代碼z_fc2 =tf.add(tf.matmul(z_fc1_drop, W_fc2),b_fc2, name="outlayer")

prob =tf.nn.softmax(z_fc2, name="probability")

pred =tf.argmax(prob,1, output_type="int32", name="predict")

1


第二步,開始把pb模型移植到Android Studio上,網上絕大部分資料都是說用bazel重新編譯模型生成依賴,這種方法難度太大。其實沒必須這樣做,TensorFlow Lite官方的例子中已經給我們展示了,我們其實只需要兩個文件:

libandroid_tensorflow_inference_java.jar 和 libtensorflow_inference.so。

這兩個文件我已經放到github上了,大家可以自行下載使用,下載地址:libandroid_tensorflow_inference_java.jar、libtensorflow_inference.so。

註:檢神說,直接用aar依賴也可以,這個我沒試過。。有興趣的可以試一下。

準備工作已經完畢,下面正式開始Android Studio中的配置。

首先把訓練好的pb模型放到Android項目中app/src/main/assets下,若不存在assets目錄,則自己新建一個。如圖所示:

其次,把剛剛下載的 libandroid_tensorflow_inference_java.jar 文件放到 app/libs 目下,把libtensorflow_inference.so 放到 app/libs/armeabi-v7a 目錄下,如下圖所示:

然後在app/build.gradle里進行如下配置:

在defaultConfig里添加

multiDexEnabledtrue

ndk {

abiFilters"armeabi-v7a"

}

在android里添加

sourceSets {

main {

jni.srcDirs = []

jniLibs.srcDirs = ["libs"]

}

}

如圖所示:

在dependencies中添加libandroid_tensorflow_inference_java.jar,即:

implementationfiles("libs/libandroid_tensorflow_inference_java.jar")

如圖所示:

至此,所有配置已經完成,下面是模型調用。


在要用到模型的地方,首先要載入libtensorflow_inference.so庫和初始化TensorFlowInferenceInterface對象,代碼為:

TensorFlowInferenceInterface inferenceInterface;static{//載入libtensorflow_inference.so庫文件

System.loadLibrary("tensorflow_inference");

Log.e("tensorflow","libtensorflow_inference.so庫載入成功");

}

Classifier(AssetManager assetManager,StringmodePath) {//初始化TensorFlowInferenceInterface對象

inferenceInterface =newTensorFlowInferenceInterface(assetManager,modePath);

Log.e("tf","TensoFlow模型文件載入成功");

}

如圖所示:

下面來多看一點東西,看看TensorFlow Lite里提供了哪幾個介面,官網地址:Here』s what a typical Inference Library sequence looks like on Android.

// Load the model from disk.

TensorFlowInferenceInterface inferenceInterface =

newTensorFlowInferenceInterface(assetManager, modelFilename);

// Copy the input data into TensorFlow.

inferenceInterface.feed(inputName, floatValues,1, inputSize, inputSize,3);

// Run the inference call.

inferenceInterface.run(outputNames, logStats);

// Copy the output Tensor back into the output array.

inferenceInterface.fetch(outputName, outputs);

下面就可以愉快地使用模型了。放一段我調用模型的代碼,以供大家參考:

publicArrayListpredict(Bitmap bitmap)

{

ArrayListlist=newArrayList();float[] inputdata = getPixels(bitmap);for(inti =; i

{

Log.d("matrix",inputdata[i] +"");

}

inferenceInterface.feed(inputName, inputdata,1, IMAGE_SIZE, IMAGE_SIZE,3);//運行模型,run的參數必須是String[]類型

String[] outputNames =newString[];

inferenceInterface.run(outputNames);//獲取結果

int[] labels =newint[1];

inferenceInterface.fetch(outputName,labels);intlabel = labels[];float[] prob =newfloat[11];

inferenceInterface.fetch(probabilityName, prob);// float[] outlayer = new float[11];// inferenceInterface.fetch(outlayerName, outlayer);// for(int i = 0; i

for(inti =; i

{

Log.d("matrix",prob[i] +"");

}

DecimalFormat df =newDecimalFormat("0.000000");floatlabel_prob = prob[label];//返回值

最後放一張做的數字手勢識別APP的效果,全部代碼,將會開源在github上,歡迎star。

再放一張碰運氣的識別結果:

Github 鏈接:

https://github.com/tz28/Chinese-number-gestures-recognition


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 AI研習社 的精彩文章:

Databricks 開源 MLflow 平台,解決機器學習開發四大難點
模型不收斂,訓練速度慢,如何才能改善 GAN 的性能?

TAG:AI研習社 |