前端AI之路:KerasJS初探
前端也可以搞 AI?今天就來看看 @周林 給我們帶來的深度神經網路的JS框架 Keras.js 的初體驗。
簡介
Keras是一款非常流行的深度學習模型開發框架,基於python,語法簡潔,封裝程度高,只需十幾行代碼就可以構建一個深度神經網路。
Keras.js是一個可以在瀏覽器中運行深度神經網路的JS框架,支持CPU,GPU計算。區別於Keras,Keras.js只能運行已經調試好的模型,無法進行模型訓練。
KerasJS開發流程如下,首先使用Keras開發訓練神經網路,將神經網路模型和參數導出為文件,KerasJS在瀏覽器端載入此文件,這樣才能進行預測。
模型
借鑒這篇文章,開發一個識別聖誕老人的神經網路。本文不涉及Keras的開發細節,感興趣的同學可以去原文查看。這裡直接給出python代碼。
數據
標註數據是AI模型的原料,數據搜集特別是圖片搜集是前端可以介入的一個環節。筆者基於React,開發了一款Chrome圖片批量下載插件GetThemAll,方便我們進行標記圖片搜集。
安裝好插件後,去谷歌圖片搜索「santa」, 使用插件標記不需要的圖片,然後下載到本地的santa文件夾,通過谷歌圖片可以搜集到400張聖誕老人的圖片。
接著我們再下載一些非聖誕老人的圖片,搜索「object」,同樣的使用GetThemAll插件下載大約400張圖片到本地的non_santa文件夾中。
除了訓練數據集,我們還需要一個測試數據集用來衡量模型的泛化能力。在本地新建一個test文件夾,把剛剛準備好的訓練集裡面的最後100張聖誕老人圖片移到test文件夾下的santa文件中,同樣的,移動100張非聖誕老人圖片到non_stanta文件中。這樣,你可以得到如下的本地圖片集:
有了標記數據,我們就可以進行模型訓練啦。具體的訓練過程請見pyton代碼,這裡直接給出訓練的結果,藍點表示訓練數據集準確率,藍線表示測試數據集準備率,模型有著明顯的High Variance問題,不過這個bug留給深度學習的專家們解決吧,這裡就假設這個模型可用。
遷移
上一步訓練出的模型keras_santa.h5(h5是文件後綴,和HTML5沒啥關係)不能直接給KerasJS使用,需要通過KerasJS提供的轉換工具轉換後,方可被KerasJS載入解析。
轉換後,得到了keras_santa.bin文件,20M左右,這個文件包含了神經網路結構和所有參數,可以被KerasJS載入。
KerasJS
通過上面的步驟,我們得到了一個訓練完成的CNN神經網路以及全部參數,這個網路結構和參數全部保存在keras_santa.bin文件中。接下來,我們只需要在瀏覽器中復原上面的神經網路,然後就可以開始做預測啦。
使用webpack配合React,搭建一套簡單的開發環境。做好了基礎工作,就可以開始第一步開發,載入神經網路模型文件keras_santa.bin:
使用上面的模型做預測前,需要將數據轉化成模型能夠接受的數據格式。這個聖誕老人網路需要數據輸入格式為(128,128,3),也即是圖片需要為128x128解析度,只能包含RGB三個分量。
藉助canvas,可以實現圖片解析度轉換:
注意preprocess方法,通過canvas獲取到的圖片資源包含了rgba四個維度,prepross返回這4個維度中的前3個維度,也即rgb,同時將數據標準化:
最後,使用上面返回的數據做預測
GIF
思考
可以看到,KerasJS在預測過程中,整個頁面無法響應用戶操作。這是因為神經網路計算過程中佔用了大量CPU資源,從而致使頁面卡頓。下一篇文章中,我們將介紹如何使用WebGL,將計算過程轉移到GPU,達到實現前端高性能計算的目的。
同時,模型參數文件體積超過20M。如何對模型文件進行壓縮,滿足生產級別可用的要求,也是前端同學可以深挖的一個方向。
相關資源
Image classification with Keras and deep learning, Adrain Rosebrock
GetThemAll, eeandrew
React Keras,eeandrew
※使用 Rust 加速前端監控
※Web 前端中的增強現實開發技術
TAG:前端外刊評論 |