如果您希望能有一種簡單、高效且靈活的方式把 TensorFlow 模型集成到 Flutter 應(yīng)用里,那請您一定不要錯(cuò)過我們今天介紹的這個(gè)全新插件tflite_flutter。這個(gè)插件的開發(fā)者是 Google Summer of Code (GSoC) 的一名實(shí)習(xí)生 Amish Garg。
tflite_flutter插件的核心特性:
插件提供了與 TFLite Java 和 Swift API 相似的 Dart API,所以其靈活性和在這些平臺(tái)上的效果是完全一樣的;
插件通過 dart:ffi 直接與 TensorFlow Lite C API 相綁定,所以它比其它平臺(tái)集成方式更加高效;
無需編寫特定平臺(tái)的代碼;
通過 NNAPI 提供加速支持,在 Android 上使用 GPU Delegate,在 iOS 上使用 Metal Delegate。
本文中,我們將使用 tflite_flutter 構(gòu)建一個(gè)文字分類 Flutter 應(yīng)用,帶您體驗(yàn) tflite_flutter 插件。首先從新建一個(gè) Flutter 項(xiàng)目text_classification_app開始。
初始化配置
Linux 和 Mac用戶
將 install.sh 拷貝到您應(yīng)用的根目錄,然后在根目錄執(zhí)行 sh install.sh,本例中就是目錄 text_classification_app/。
Windows 用戶
將 install.bat 文件拷貝到應(yīng)用根目錄,并在根目錄運(yùn)行批處理文件 install.bat,本例中就是目錄 text_classification_app/。
它會(huì)自動(dòng)從GitHub 倉庫的 Releases 里下載最新的二進(jìn)制資源,然后把它放到指定的目錄下。
請點(diǎn)擊到 README 文件里查看更多關(guān)于初始配置的信息。
tflite_flutter 的 GitHub 倉庫
https://github.com/am15h/tflite_flutter_plugin
獲取插件
在pubspec.yaml添加tflite_flutter: ^
最新版本情況參考插件的發(fā)布地址
https://pub.flutter-io.cn/packages/tflite_flutter
下載模型
要在移動(dòng)端上運(yùn)行 TensorFlow 訓(xùn)練模型,我們需要使用 .tflite 格式。如果需要了解如何將 TensorFlow 訓(xùn)練的模型轉(zhuǎn)換為 .tflite 格式,請參閱官方指南。
這里我們準(zhǔn)備使用 TensorFlow 官方站點(diǎn)上預(yù)訓(xùn)練的文字分類模型。
該預(yù)訓(xùn)練的模型可以預(yù)測當(dāng)前段落的情感是積極還是消極。它是基于來自 Mass 等人的 Large Movie Review Dataset v1.0數(shù)據(jù)集進(jìn)行訓(xùn)練的。數(shù)據(jù)集由基于 IMDB 電影評論所標(biāo)記的積極或消極標(biāo)簽組成,查看更多信息。
將 text_classification.tflite 和 text_classification_vocab.txt 文件拷貝到 text_classification_app/assets/ 目錄下。
在 pubspec.yaml 文件中添加 assets/。
assets: - assets/
現(xiàn)在萬事俱備,我們可以開始寫代碼了。
模型轉(zhuǎn)換器(Converter)的 Python API 指南
https://tensorflow.google.cn/lite/convert/python_api
預(yù)訓(xùn)練的文字分類模型(text_classification.tflite)
https://files.flutter-io.cn/posts/flutter-cn/2020/tensorflow-lite-plugin/text_classification.tflite
數(shù)據(jù)集(text_classification_vocab.txt)
https://files.flutter-io.cn/posts/flutter-cn/2020/tensorflow-lite-plugin/text_classification_vocab.txt
實(shí)現(xiàn)分類器
預(yù)處理
正如文字分類模型頁面里所提到的??梢园凑障旅娴牟襟E使用模型對段落進(jìn)行分類:
對段落文本進(jìn)行分詞,然后使用預(yù)定義的詞匯集將它轉(zhuǎn)換為一組詞匯 ID;
將生成的這組詞匯 ID 輸入 TensorFlow Lite 模型里;
從模型的輸出里獲取當(dāng)前段落是積極或者是消極的概率值。
我們首先寫一個(gè)方法對原始字符串進(jìn)行分詞,其中使用 text_classification_vocab.txt作為詞匯集。
在 lib/文件夾下創(chuàng)建一個(gè)新文件 classifier.dart。
這里先寫代碼加載 text_classification_vocab.txt 到字典里。
import 'package:flutter/services.dart'; class Classifier { final _vocabFile = 'text_classification_vocab.txt'; Map
△加載字典
現(xiàn)在我們來編寫一個(gè)函數(shù)對原始字符串進(jìn)行分詞。
import 'package:flutter/services.dart'; class Classifier { final _vocabFile = 'text_classification_vocab.txt'; // 單句的最大長度 final int _sentenceLen = 256; final String start = '> tokenizeInputText(String text) { // 使用空格進(jìn)行分詞 final toks = text.split(' '); // 創(chuàng)建一個(gè)列表,它的長度等于 _sentenceLen,并且使用
> return [vec]; } }
△分詞代碼
使用 tflite_flutter 進(jìn)行分析
這是本文的主體部分,這里我們會(huì)討論 tflite_flutter 插件的用途。
此處的分析指的是在設(shè)備上基于輸入的數(shù)據(jù),使用 TensorFlow Lite 模型的處理過程。要使用 TensorFlow Lite 模型進(jìn)行分析,需要通過解釋器來運(yùn)行它,了解更多。
創(chuàng)建解釋器,加載模型
tflite_flutter 提供了一個(gè)方法直接通過資源創(chuàng)建解釋器。
static Future
由于我們的模型在 assets/文件夾下,需要使用上面的方法來創(chuàng)建解析器。對于 InterpreterOptions 的相關(guān)說明,請參考這里。
import 'package:flutter/services.dart'; // 引入 tflite_flutter import 'package:tflite_flutter/tflite_flutter.dart'; class Classifier { // 模型文件的名稱 final _modelFile = 'text_classification.tflite'; // TensorFlow Lite 解釋器對象 Interpreter _interpreter; Classifier() { // 當(dāng)分類器初始化以后加載模型 _loadModel(); } void _loadModel() async { // 使用 Interpreter.fromAsset 創(chuàng)建解釋器 _interpreter = await Interpreter.fromAsset(_modelFile); print('Interpreter loaded successfully'); } }
△創(chuàng)建解釋器的代碼
如果您不希望將模型放在assets/目錄下,tflite_flutter 還提供了工廠構(gòu)造函數(shù)創(chuàng)建解釋器,更多信息。
我們開始進(jìn)行分析!
現(xiàn)在用下面方法啟動(dòng)分析:
void run(Object input, Object output);
注意這里的方法和 Java API 中的是一樣的。
Object input 和 Object output 必須是與 Input Tensor 和 Output Tensor 維度相同的列表。
要查看 input tensor 和 output tensor 的維度,可以使用如下代碼:
_interpreter.allocateTensors(); // 打印 input tensor 列表 print(_interpreter.getInputTensors()); // 打印 output tensor 列表 print(_interpreter.getOutputTensors());
在本例中 text_classification 模型的輸出如下:
InputTensorList: [Tensor{_tensor: Pointer
現(xiàn)在,我們實(shí)現(xiàn)分類方法,該方法返回值為 1 表示積極,返回值為 0 表示消極。
int classify(String rawText) { // tokenizeInputText 返回形狀為 [1, 256] 的 List> List
> input = tokenizeInputText(rawText); // [1,2] 形狀的輸出 var output = List
△用于分析的代碼
在 tflite_flutter 的 extension ListShape on List 下面定義了一些使用的擴(kuò)展:
// 將提供的列表進(jìn)行矩陣變形,輸入參數(shù)為元素總數(shù)并保持相等 // 用法:List(400).reshape([2,10,20]) // 返回 List
最終的 classifier.dart 應(yīng)該是這樣的:
import 'package:flutter/services.dart'; // 引入 tflite_flutter import 'package:tflite_flutter/tflite_flutter.dart'; class Classifier { // 模型文件的名稱 final _modelFile = 'text_classification.tflite'; final _vocabFile = 'text_classification_vocab.txt'; // 語句的最大長度 final int _sentenceLen = 256; final String start = '> List
> input = tokenizeInputText(rawText); //輸出形狀為 [1, 2] 的矩陣 var output = List
> tokenizeInputText(String text) { // 用空格分詞 final toks = text.split(' '); // 創(chuàng)建一個(gè)列表,它的長度等于 _sentenceLen,并且使用
> return [vec]; } }
現(xiàn)在,可以根據(jù)您的喜好實(shí)現(xiàn) UI 的代碼,分類器的用法比較簡單。
// 創(chuàng)建 Classifier 對象 Classifer _classifier = Classifier(); // 將目標(biāo)語句作為參數(shù),調(diào)用 classify 方法 _classifier.classify("I liked the movie"); // 返回 1 (積極的) _classifier.classify("I didn't liked the movie"); // 返回 0 (消極的)
△ 文字分類示例應(yīng)用
了解更多關(guān)于 tflite_flutter 插件的信息,請?jiān)L問 GitHub repo:am15h/tflite_flutter_plugin。
你問我答
問:tflite_flutter 和 tflite v1.0.5 有哪些區(qū)別?
tflite v1.0.5 側(cè)重于為特定用途的應(yīng)用場景提供高級特性,比如圖片分類、物體檢測等等。而新的 tflite_flutter 則提供了與 Java API 相同的特性和靈活性,而且可以用于任何 tflite 模型中,它還支持 delegate。
由于使用 dart:ffi (dart (ffi) C),tflite_flutter 非???(擁有低延時(shí))。而 tflite 使用平臺(tái)集成 (dart platform-channel (Java/Swift) JNI C)。
問:如何使用 tflite_flutter 創(chuàng)建圖片分類應(yīng)用?有沒有類似 TensorFlow Lite Android Support Library 的依賴包?
TensorFlow Lite Flutter Helper Library為處理和控制輸入及輸出的 TFLite 模型提供了易用的架構(gòu)。它的 API 設(shè)計(jì)和文檔與 TensorFlow Lite Android Support Library 是一樣的。更多信息請參考 TFLite Flutter Helper 的 GitHub 。
TFLite Flutter Helper 開發(fā)庫 GitHub 倉庫地址
https://github.com/am15h/tflite_flutter_helper
以上是本文的全部內(nèi)容,歡迎大家對 tflite_flutter 插件進(jìn)行反饋,請?jiān)?GitHub報(bào) bug 或提出功能需求。謝謝關(guān)注,感謝 Flutter 團(tuán)隊(duì)的 Michael Thomsen。
向 tflite_flutter 插件提出建議和反饋
https://github.com/am15h/tflite_flutter_plugin/issues
-
iOS
+關(guān)注
關(guān)注
8文章
3384瀏覽量
150306 -
插件
+關(guān)注
關(guān)注
0文章
319瀏覽量
22377 -
tensorflow
+關(guān)注
關(guān)注
13文章
328瀏覽量
60445
原文標(biāo)題:社區(qū)分享 | 在 Flutter 中使用 TensorFlow Lite 插件實(shí)現(xiàn)文字分類
文章出處:【微信號(hào):tensorflowers,微信公眾號(hào):Tensorflowers】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關(guān)推薦
評論