當前位置:
首頁 > 新聞 > 遷移學習不好懂?這裡有一個PyTorch項目幫你理解

遷移學習不好懂?這裡有一個PyTorch項目幫你理解

新智元報道

來源:Medium

編輯:元子

【新智元導讀】遷移學習是一個非常重要的機器學習技術,已被廣泛應用於機器學習的許多應用中。本文的目標是讓讀者理解遷移學習的意義,了解轉學習的重要性,並學會使用PyTorch進行實踐。

前幾天新智元介紹了在線元學習,以及元獎勵學習。元學習有一個非常重要的理念是在較少樣本量的情況下,讓機器能夠自己學會學習。

這一點和遷移學習非常相似。吳恩達曾經說過"遷移學習將會是繼監督學習之後的下一個機器學習商業成功的驅動力"。

相比而言,依賴大量數據進行訓練的其他機器學習手段(例如昨天新智元報道的GPipe),對數據和算力的依賴有點過於嚴重。況且,數據和算力那麼貴!

遷移學習的一大特色,就是「將一個任務環境中學到的東西用來提升在另一個任務環境中模型的泛化能力」。

沒有GPU也沒關係,可以使用谷歌的免費GPU服務,通過谷歌Colab來訓練模型。

在TensorFlow 2.0即將發布之際,就讓我們一起通過PyTorch來更直觀、更深入的了解遷移學習。

前期準備

本次旅程,我們將使用預先訓練的網路,來構建用於瘧疾檢測的圖像分類器,這個分類器只需要將得到的數據,分為「感染」「未感染」兩類。

我們將要用到的圖像數據集可以在這裡下載

https://drive.google.com/open?id=16DbIOMCtCuRuMdYF64MPv3iLqpSG6tfv

經過預先訓練的網路在ImageNet上進行了訓練,其中包含120萬張1000個類別的圖像,

用到的模型是torchvision.models,它有6種不同的架構我們可以使用。

torchvision.models具有模型性能的細分以及可以使用的層數(由模型附帶的數字表示)。

載入所有必需的包和庫:

將數據進行可視化:

下圖是感染的圖

定義轉換並載入進數據

轉換是將一個圖形、表達式或函數轉換為另一個圖形、表達式或函數的過程。

我們需要為訓練、測試以及驗證數據定義一些轉換。值得注意的,可能有的類別圖像太少,不夠進行轉換,為了增加網路識別的圖像數量,我們執行所謂的數據增強。

在訓練期間,我們隨機裁剪、縮放和旋轉圖像,以便在每個時期,網路會看到同一圖像的不同變化,提高實驗的準確性。

接下來載入數據集。最簡單的方法是用torchvision的dataset.ImageFolder。

載入imageFolder後,我們將數據拆分為20%驗證集和10%測試集; 然後將它傳遞給DataLoader。

它接收一個類似從ImageFolder獲得的數據集,並返回批量圖像及其相應的標籤(可以將改組設置為true以在時期內引入變化)。

模型訓練流程

1. 載入預先訓練的模型

PyTorch以及幾乎所有其他深度學習框架,都使用CUDA來有效地計算GPU上的前向和後向傳遞。

在PyTorch中,我們使用model.cuda()將模型參數和其他張量移動到GPU內存,或者從GPU移回,

2. 凍結卷積層並使用自定義分類器替換完全連接的層

凍結模型參數允許我們為早期卷積層保留預訓練模型的權重,其目的是用於特徵提取。

然後我們定義我們的全連接網路,他將作為輸入神經元,示例代碼中是1024,這個數字取決於預訓練模型的輸入神經元,和自定義隱藏層。

我們還定義了要使用的激活函數,和有助於通過隨機關閉層中的神經元,以強制在剩餘節點之間共享信息,來避免過度擬合的丟失。

在我們定義了自定義全連接網路之後,我們將其連接到預先訓練好的模型的完全連接網路。

接下來我們定義損失函數,優化器,並通過將模型移動到GPU來準備訓練模型。

3. 為特定任務訓練自定義分類器

在訓練期間,我們遍歷每個時期的DataLoader。 對於每個batch,使用標準函數計算損失。使用loss.backward()方法計算相對於模型參數的損失梯度。

optimizer.zero_grad()負責清除任何累積的梯度,因為我們會一遍又一遍地計算梯度。

optimizer.step()使用具有動量的隨機梯度下降(Adam)更新模型參數。

為了防止過度擬合,我們使用一種稱為早期停止的強大技術。背後的想法很簡單,當驗證數據集上的性能開始降低時停止訓練。

在耐心地等待訓練過程完成並保存最佳模型參數的檢查點之後,讓我們載入檢查點並在看不見的數據(測試數據)上測試模型的性能。

從磁碟載入已保存的模型

在看不見的數據上測試載入的模型。 我們對看不見的數據有90%的準確率,這在第一次嘗試時非常令人印象深刻。

現在我們對模型有了信心,現在是時候進行一些預測並將結果可視化了。

好。教程到此就結束了。我們使用PyTorch,利用遷移學習建立了一個瘧疾分類器的應用。

接下來,我們可以繼續的完善代碼,或者可以再做幾個其他同類型的應用。

參考鏈接:

https://heartbeat.fritz.ai/transfer-learning-with-pytorch-cfcb69016c72

【加入社群】

新智元AI技術 產業社群招募中,歡迎對AI技術 產業落地感興趣的同學,加小助手微信號:aiera2015_2入群;通過審核後我們將邀請進群,加入社群後務必修改群備註(姓名 - 公司 - 職位;專業群審核較嚴,敬請諒解)。


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 新智元 的精彩文章:

谷歌、DeepMind重磅推出PlaNet,數據效率提升50倍
49必須了解的機器學習開源項目,Github上平均3600星

TAG:新智元 |