當前位置:
首頁 > 知識 > 如何利用TensorFlow.js部署簡單的AI版「你畫我猜」圖像識別應用

如何利用TensorFlow.js部署簡單的AI版「你畫我猜」圖像識別應用

選自Medium

作者:Zaid Alyafeai

機器之心編譯

參與:Geek AI、路

本文創建了一個簡單的工具來識別手繪圖像,並且輸出當前圖像的名稱。該應用無需安裝任何額外的插件,可直接在瀏覽器上運行。作者使用谷歌 Colab 來訓練模型,並使用 TensorFlow.js 將它部署到瀏覽器上。

代碼和 demo

demo 地址:https://zaidalyafeai.github.io/sketcher/

代碼地址:https://github.com/zaidalyafeai/zaidalyafeai.github.io/tree/master/sketcher

請通過以下鏈接在谷歌 Colab 上測試自己的 notebook:https://colab.research.google.com/github/zaidalyafeai/zaidalyafeai.github.io/blob/master/sketcher/Sketcher.ipynb

數據集

我們將使用卷積神經網路(CNN)來識別不同類型的手繪圖像。這個卷積神經網路將在 Quick Draw 數據集(https://github.com/googlecreativelab/quickdraw-dataset)上接受訓練。該數據集包含 345 個類別的大約 5 千萬張手繪圖像。

部分圖像類別

流程

我們將使用 Keras 框架在谷歌 Colab 免費提供的 GPU 上訓練模型,然後使用 TensorFlow.js 直接在瀏覽器上運行模型。我在 TensorFlow.js 上創建了一個教程(https://medium.com/tensorflow/a-gentle-introduction-to-tensorflow-js-dba2e5257702)。在繼續下面的工作之前,請務必先閱讀一下這個教程。下圖為該項目的處理流程:

流程

在 Colab 上進行訓練

谷歌 Colab 為我們提供了免費的 GPU 處理能力。你可以閱讀下面的教程(https://medium.com/deep-learning-turkey/google-colab-free-gpu-tutorial-e113627b9f5d)了解如何創建 notebook 和開始進行 GPU 編程。

導入

我們將使用以 TensorFlow 作為後端、Keras 作為前端的編程框架

載入數據

由於內存容量有限,我們不會使用所有類別的圖像進行訓練。我們僅使用數據集中的 100 個類別(https://raw.githubusercontent.com/zaidalyafeai/zaidalyafeai.github.io/master/sketcher/mini_classes.txt)。每個類別的數據可以在谷歌 Colab(https://console.cloud.google.com/storage/browser/quickdrawdataset/full/numpybitmap?pli=1)上以 NumPy 數組的形式獲得,數組的大小為 [N, 784],其中 N 為某類圖像的數量。我們首先下載這個數據集:

由於內存限制,我們在這裡將每類圖像僅僅載入 5000 張。我們還將留出其中的 20% 作為測試數據。

數據預處理

我們對數據進行預處理操作,為訓練模型做準備。該模型將使用規模為 [N, 28, 28, 1] 的批處理,並且輸出規模為 [N, 100] 的概率。

創建模型

我們將創建一個簡單的卷積神經網路。請注意,模型越簡單、參數越少越好。實際上,我們將把模型轉換到瀏覽器上然後再運行,並希望模型能在預測任務中快速運行。下面的模型包含 3 個卷積層和 2 個全連接層:

擬合、驗證及測試

在這之後我們對模型進行了 5 輪訓練,將訓練數據分成了 256 批輸入模型,並且分離出 10% 作為驗證集。

訓練結果如下圖所示:

測試準確率達到了 92.20% 的 top 5 準確率。

準備 WEB 格式的模型

在我們得到滿意的模型準確率後,我們將模型保存下來,以便進行下一步的轉換。

為轉換安裝 tensorflow.js:

接著我們對模型進行轉換:

這個步驟將創建一些權重文件和包含模型架構的 json 文件。

通過 zip 將模型進行壓縮,以便將其下載到本地機器上:

最後下載模型:

在瀏覽器上進行推斷

本節中,我們將展示如何載入模型並且進行推斷。假設我們有一個尺寸為 300*300 的畫布。在這裡,我們不會詳細介紹函數介面,而是將重點放在 TensorFlow.js 的部分。

載入模型

為了使用 TensorFlow.js,我們首先使用下面的腳本:

你的本地機器上需要有一台運行中的伺服器來託管權重文件。你可以在 GitHub 上創建一個 apache 伺服器或者託管網頁,就像我在我的項目中所做的那樣(https://github.com/zaidalyafeai/zaidalyafeai.github.io/tree/master/sketcher)。

接著,通過下面的代碼將模型載入到瀏覽器:

關鍵字 await 的意思是等待模型被瀏覽器載入。

預處理

在進行預測前,我們需要對數據進行預處理。首先從畫布中獲取圖像數據:

文章稍後將介紹 getMinBox()。dpi 變數被用於根據屏幕像素的密度對裁剪出的畫布進行拉伸。

我們將畫布當前的圖像數據轉化為一個張量,調整大小並進行歸一化處理:

我們使用 model.predict 進行預測,這將返回一個規模為「N, 100」的概率。

我們可以使用簡單的函數找到 top 5 概率。

提升準確率

請記住,我們的模型接受的輸入數據是規模為 [N, 28, 28, 1] 的張量。我們繪圖畫布的尺寸為 300*300,這可能是兩個手繪圖像的大小,或者用戶可以在上面繪製一個小圖像。最好只裁剪包含當前手繪圖像的方框。為了做到這一點,我們通過找到左上方和右下方的點來提取圍繞圖像的最小邊界框。

用手繪圖像進行測試

下圖顯示了一些第一次繪製的圖像以及準確率最高的類別。所有的手繪圖像都是我用滑鼠畫的,用筆繪製的話應該會得到更高的準確率。

本文為機器之心編譯,轉載請聯繫本公眾號獲得授權。

------------------------------------------------

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 機器之心 的精彩文章:

弱監督學習下的商品識別:CVPR 2018細粒度識別挑戰賽獲勝方案簡介
自動「腦補」3D環境!DeepMind最新Science論文生成查詢網路GQN

TAG:機器之心 |