當前位置:
首頁 > 新聞 > 貓狗大戰識別準確率直衝 Kaggle Top 2%,手把手教你在 Keras 搭建深度 CNN

貓狗大戰識別準確率直衝 Kaggle Top 2%,手把手教你在 Keras 搭建深度 CNN

貓狗大戰識別準確率直衝 Kaggle Top 2%,手把手教你在 Keras 搭建深度 CNN

新智元推薦

本文涉及的所有代碼都可以在這裡下載:https://github.com/ypwhs/dogs_vs_cats

新智元327技術大會愛奇藝回播視頻鏈接,請點擊閱讀原文。

貓狗大戰識別準確率直衝 Kaggle Top 2%,手把手教你在 Keras 搭建深度 CNN

本文授權轉載自Udacity知乎機構號,作者楊培文 , Udacity 機器學習項目 reviewer,特此感謝!

貓狗大戰

數據集來自 kaggle 上的一個競賽:Dogs vs. Cats,訓練集有25000張,貓狗各佔一半。測試集12500張,沒有標定是貓還是狗。

貓狗大戰識別準確率直衝 Kaggle Top 2%,手把手教你在 Keras 搭建深度 CNN

下面是訓練集的一部分例子:

數據預處理

由於我們的數據集的文件名是以type.num.jpg這樣的方式命名的,比如cat.0.jpg,但是使用 Keras 的 ImageDataGenerator 需要將不同種類的圖片分在不同的文件夾中,因此我們需要對數據集進行預處理。這裡我們採取的思路是創建符號鏈接(symbol link),這樣的好處是不用複製一遍圖片,佔用不必要的空間。

貓狗大戰識別準確率直衝 Kaggle Top 2%,手把手教你在 Keras 搭建深度 CNN

我們可以從下面看到文件夾的結構,train2裡面有兩個文件夾,分別是貓和狗,每個文件夾里是12500張圖。

貓狗大戰識別準確率直衝 Kaggle Top 2%,手把手教你在 Keras 搭建深度 CNN

導出特徵向量

對於這個題目來說,使用預訓練的網路是最好不過的了,經過前期的測試,我們測試了 ResNet50 等不同的網路,但是排名都不高,現在看來只有一兩百名的樣子,所以我們需要提高我們的模型表現。那麼一種有效的方法是綜合各個不同的模型,從而得到不錯的效果,兼聽則明。如果是直接在一個巨大的網路後面加我們的全連接,那麼訓練10代就需要跑十次巨大的網路,而且我們的卷積層都是不可訓練的,那麼這個計算就是浪費的。所以我們可以將多個不同的網路輸出的特徵向量先保存下來,以便後續的訓練,這樣做的好處是我們一旦保存了特徵向量,即使是在普通筆記本上也能輕鬆訓練。

貓狗大戰識別準確率直衝 Kaggle Top 2%,手把手教你在 Keras 搭建深度 CNN

為了復用代碼,我覺得寫一個函數是非常有必要的,那麼我們的函數就需要輸入模型,輸入圖片的大小,以及預處理函數,因為 Xception 和 Inception V3 都需要將數據限定在 (-1, 1) 的範圍內,然後我們利用 GlobalAveragePooling2D 將卷積層輸出的每個激活圖直接求平均值,不然輸出的文件會非常大,且容易過擬合。然後我們定義了兩個 generator,利用 model.predict_generator 函數來導出特徵向量,最後我們選擇了 ResNet50, Xception, Inception V3 這三個模型(如果有興趣也可以導出 VGG 的特徵向量)。每個模型導出的時間都挺長的,在 aws p2.xlarge 上大概需要用十分鐘到二十分鐘。 這三個模型都是在ImageNet上面預訓練過的,所以每一個模型都可以說是身經百戰,通過這三個老司機導出的特徵向量,可以高度概括一張圖片有哪些內容。

最後導出的 h5 文件包括三個 numpy 數組:

  • train (25000, 2048)

  • test (12500, 2048)

  • label (25000,)

如果你不想自己計算特徵向量,可以直接在這裡下載導出的文件:GitHub releases(http://t.cn/R6xSIoG)

參考資料:

  • ResNet 15.12

  • Inception v3 15.12

  • Xception 16.10

載入特徵向量

經過上面的代碼以後,我們獲得了三個特徵向量文件,分別是:

  • gap_ResNet50.h5

  • gap_InceptionV3.h5

  • gap_Xception.h5

我們需要載入這些特徵向量,並且將它們合成一條特徵向量,然後記得把 X 和 y 打亂,不然之後我們設置validation_split的時候會出問題。這裡設置了 numpy 的隨機數種子為2017,這樣可以確保每個人跑這個代碼,輸出都能是一樣的結果。

貓狗大戰識別準確率直衝 Kaggle Top 2%,手把手教你在 Keras 搭建深度 CNN

構建模型

模型的構建很簡單,直接 dropout 然後分類就好了。

貓狗大戰識別準確率直衝 Kaggle Top 2%,手把手教你在 Keras 搭建深度 CNN

我們還可以對模型進行可視化:

貓狗大戰識別準確率直衝 Kaggle Top 2%,手把手教你在 Keras 搭建深度 CNN

貓狗大戰識別準確率直衝 Kaggle Top 2%,手把手教你在 Keras 搭建深度 CNN

訓練模型

模型構件好了以後,我們就可以進行訓練了,這裡我們設置驗證集大小為 20% ,也就是說訓練集是20000張圖,驗證集是5000張圖。

貓狗大戰識別準確率直衝 Kaggle Top 2%,手把手教你在 Keras 搭建深度 CNN

我們可以看到,訓練的過程很快,十秒以內就能訓練完,準確率也很高,在驗證集上最高達到了99.6%的準確率,這相當於一千張圖只錯了4張,可以說比我還厲害。

預測測試集

模型訓練好以後,我們就可以對測試集進行預測,然後提交到 kaggle 上看看最終成績了。

貓狗大戰識別準確率直衝 Kaggle Top 2%,手把手教你在 Keras 搭建深度 CNN

預測這裡我們用到了一個小技巧,我們將每個預測值限制到了 [0.005, 0.995] 個區間內,這個原因很簡單,kaggle 官方的評估標準是 LogLoss,對於預測正確的樣本,0.995 和 1 相差無幾,但是對於預測錯誤的樣本,0 和 0.005 的差距非常大,是 15 和 2 的差別。參考 LogLoss 如何處理無窮大問題,下面的表達式就是二分類問題的 LogLoss 定義。

$$ extrm{LogLoss} = - frac{1}{n} sum_{i=1}^n left[ y_i log(hat{y}_i) + (1 - y_i) log(1 - hat{y}_i)
ight]$$

還有一個值得一提的地方就是測試集的文件名不是按 1, 2, 3 這樣排的,而是按下面的順序排列的:

貓狗大戰識別準確率直衝 Kaggle Top 2%,手把手教你在 Keras 搭建深度 CNN

因此我們需要對每個文件名進行處理,然後賦值到 df 里,最後導出為 csv 文件。

貓狗大戰識別準確率直衝 Kaggle Top 2%,手把手教你在 Keras 搭建深度 CNN

總結

我們可以從上圖中看到,模型對於前十個樣本都給出了很肯定的預測,提交到 kaggle 以後,得分也是很棒,0.04141,在全球排名中可以排到20/1314。我們如果要繼續優化模型表現,可以使用更棒的預訓練模型來導出特徵向量,或者對預訓練模型進行微調(fine-tune),或者進行數據增強(data augmentation)等。

長按二維碼關注 Udacity,獲取更多技術學習資訊

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 新智元 的精彩文章:

「乾貨」蘋果 AI 負責人 Russ Salakhutdinov 最新演講:深度生成模型定量評估(56 PPT)
博鰲論壇 2045 請回答!關於人工智慧,科技大咖們都說了些什麼?
ImageNet 2017啟幕,海康威視浦世亮談2016奪冠絕技及深度學習+安防?|新智元AI 領軍人物專訪
引入 Mobileye 技術到國產客車,硬蛋打造智能硬體全產業鏈
人機對決!騰訊圍棋AI「絕藝」電聖戰奪冠(附獲勝棋譜)

TAG:新智元 |

您可能感興趣

干货 | 手把手教你如何使用TensorFlow实现深度强化学习玩转Flappy Bird
Intel+Cloudera,用BigDL玩轉深度學習
小米 Mi A1 Android One 手機深度動手玩
文本直送科技新聞:Google IO 2017:讓行動終端也具備強大的深度學習能力,Google 推出 TensorFlow Lite 學習框架
深度學習框架TensorFlow、Caffe、MXNet、PyTorch如何抉擇?6 位大咖現身說法
MESA/Boogie Stowaway&Highwire buffer單塊深度對比解析
阿根廷Satellogic公司創始人兼CEO Emiliano Kargieman——期待與中國深度合作
Apple Watch 3深度體驗:iPhone,該你放假了!
MIT提出mNeuron:一個可視化深度模型神經元的Matlab插件
HyperX Alloy FPS Pro機械鍵盤深度體驗
Visual Studio Code 現支持深度學習/AI 應用程序
英偉達Volta架構:為深度學習而生的Tensor Core
設計神器Affinity Designer要秒殺PS?與PS、AI深度對比測評竟然是這結果!
資深演算法工程師眼中的深度學習:Ian Goodfellow 和Yoshua Bengio的「Deep Learning」讀書分享
87鍵哪兒好?HyperX Alloy FPS Pro機械鍵盤深度體驗
簡單有深度益智新游WiiU/3DS《ShootTheBall》上線
ACL 第一天:Tutorial鍾愛深度學習,唯一一個workshop關注女性群體
1小時上手 TensorFlow 深度學習應用
Angular2 VS Angular4深度對比:特性、性能