當前位置:
首頁 > 知識 > 機器學習中用來防止過擬合的方法有哪些?

機器學習中用來防止過擬合的方法有哪些?

機器學習中用來防止過擬合的方法有哪些?



給《機器視覺與應用》課程出大作業的時候,正好涉及到這方面內容,所以簡單整理了一下(參考 Hinton 的課程)。按照之前的套路寫:


是什麼


過擬合(overfitting)是指在模型參數擬合過程中的問題,由於訓練數據包含抽樣誤差,訓練時,複雜的模型將抽樣誤差也考慮在內,將抽樣誤差也進行了很好的擬合。

具體表現就是最終模型在訓練集上效果好;在測試集上效果差。模型泛化能力弱。

機器學習中用來防止過擬合的方法有哪些?



為什麼


為什麼要解決過擬合現象?這是因為我們擬合的模型一般是用來預測未知的結果(不在訓練集內),過擬合雖然在訓練集上效果好,但是在實際使用時(測試集)效果差。同時,在很多問題上,我們無法窮盡所有狀態,不可能將所有情況都包含在訓練集上。所以,必須要解決過擬合問題。


為什麼在機器學習中比較常見?這是因為機器學習演算法為了滿足儘可能複雜的任務,其模型的擬合能力一般遠遠高於問題複雜度,也就是說,機器學習演算法有「擬合出正確規則的前提下,進一步擬合雜訊」的能力。


而傳統的函數擬合問題(如機器人系統辨識),一般都是通過經驗、物理、數學等推導出一個含參模型,模型複雜度確定了,只需要調整個別參數即可。模型「無多餘能力」擬合雜訊。


怎麼樣


既然過擬合這麼討厭,我們應該怎麼防止過擬合呢?最近深度學習比較火,我就以神經網路為例吧:

機器學習中用來防止過擬合的方法有哪些?


1. 獲取更多數據


這是解決過擬合最有效的方法,只要給足夠多的數據,讓模型「看見」儘可能多的「例外情況」,它就會不斷修正自己,從而得到更好的結果:

機器學習中用來防止過擬合的方法有哪些?



如何獲取更多數據,可以有以下幾個方法:


從數據源頭獲取更多數據:這個是容易想到的,例如物體分類,我就再多拍幾張照片好了;但是,在很多情況下,大幅增加數據本身就不容易;另外,我們不清楚獲取多少數據才算夠;


根據當前數據集估計數據分布參數,使用該分布產生更多數據:這個一般不用,因為估計分布參數的過程也會代入抽樣誤差。


數據增強(Data Augmentation):通過一定規則擴充數據。如在物體分類問題里,物體在圖像中的位置、姿態、尺度,整體圖片明暗度等都不會影響分類結果。我們就可以通過圖像平移、翻轉、縮放、切割等手段將資料庫成倍擴充;

機器學習中用來防止過擬合的方法有哪些?


2. 使用合適的模型


前面說了,過擬合主要是有兩個原因造成的:數據太少 + 模型太複雜。所以,我們可以通過使用合適複雜度的模型來防止過擬合問題,讓其足夠擬合真正的規則,同時又不至於擬合太多抽樣誤差。


(PS:如果能通過物理、數學建模,確定模型複雜度,這是最好的方法,這也就是為什麼深度學習這麼火的現在,我還堅持說初學者要學掌握傳統的建模方法。)


對於神經網路而言,我們可以從以下四個方面來限制網路能力


2.1 網路結構 Architecture


這個很好理解,減少網路的層數、神經元個數等均可以限制網路的擬合能力;

機器學習中用來防止過擬合的方法有哪些?



2.2 訓練時間 Early stopping

對於每個神經元而言,其激活函數在不同區間的性能是不同的:

機器學習中用來防止過擬合的方法有哪些?



當網路權值較小時,神經元的激活函數工作在線性區,此時神經元的擬合能力較弱(類似線性神經元)。


有了上述共識之後,我們就可以解釋為什麼限制訓練時間(early stopping)有用:因為我們在初始化網路的時候一般都是初始為較小的權值。訓練時間越長,部分網路權值可能越大。如果我們在合適時間停止訓練,就可以將網路的能力限制在一定範圍內。


2.3 限制權值 Weight-decay,也叫正則化(regularization)


原理同上,但是這類方法直接將權值的大小加入到 Cost 里,在訓練的時候限制權值變大。以 L2 regularization 為例:


訓練過程需要降低整體的 Cost,這時候,一方面能降低實際輸出與樣本之間的誤差 C,也能降低權值大小。


2.4 增加雜訊 Noise


給網路加雜訊也有很多方法:

2.4.1 在輸入中加雜訊:


雜訊會隨著網路傳播,按照權值的平方放大,並傳播到輸出層,對誤差 Cost 產生影響。推導直接看 Hinton 的 PPT 吧:

機器學習中用來防止過擬合的方法有哪些?



在輸入中加高斯雜訊,會在輸出中生成


的干擾項。訓練時,減小誤差,同時也會對雜訊產生的干擾項進行懲罰,達到減小權值的平方的目的,達到與 L2 regularization 類似的效果(對比公式)。


2.4.2 在權值上加雜訊


在初始化網路的時候,用 0 均值的高斯分布作為初始化。Alex Graves 的手寫識別 RNN 就是用了這個方法


Graves, Alex, et al. "A novel connectionist system for unconstrained handwriting recognition." IEEE transactions on pattern analysis and machine intelligence 31.5 (2009): 855-868.


- It may work better, especially in recurrent networks (Hinton)

2.4.3 對網路的響應加雜訊


如在前向傳播過程中,讓默寫神經元的輸出變為 binary 或 random。顯然,這種有點亂來的做法會打亂網路的訓練過程,讓訓練更慢,但據 Hinton 說,在測試集上效果會有顯著提升 (But it does significantly better on the test set!)。


3. 結合多種模型


簡而言之,訓練多個模型,以每個模型的平均輸出作為結果。


從 N 個模型里隨機選擇一個作為輸出的期望誤差


,會比所有模型的平均輸出的誤差


大(我不知道公式里的圓括弧為什麼顯示不了):

機器學習中用來防止過擬合的方法有哪些?



大概基於這個原理,就可以有很多方法了:

3.1 Bagging


簡單理解,就是分段函數的概念:用不同的模型擬合不同部分的訓練集。以隨機森林(Rand Forests)為例,就是訓練了一堆互不關聯的決策樹。但由於訓練神經網路本身就需要耗費較多自由,所以一般不單獨使用神經網路做 Bagging。


3.2 Boosting


既然訓練複雜神經網路比較慢,那我們就可以只使用簡單的神經網路(層數、神經元數限制等)。通過訓練一系列簡單的神經網路,加權平均其輸出。

機器學習中用來防止過擬合的方法有哪些?



3.3 Dropout


這是一個很高效的方法。

機器學習中用來防止過擬合的方法有哪些?


在訓練時,每次隨機(如 50% 概率)忽略隱層的某些節點;這樣,我們相當於隨機從 2^H 個模型中採樣選擇模型;同時,由於每個網路只見過一個訓練數據(每次都是隨機的新網路),所以類似bagging的做法,這就是我為什麼將它分類到「結合多種模型」中;


此外,而不同模型之間權值共享(共同使用這 H 個神經元的連接權值),相當於一種權值正則方法,實際效果比 L2 regularization 更好。


4. 貝葉斯方法


這部分我還沒有想好怎麼才能講得清楚,為了不誤導初學者,我就先空著,以後如果想清楚了再更新。當然,這也是防止過擬合的一類重要方法。

機器學習中用來防止過擬合的方法有哪些?



綜上:

機器學習中用來防止過擬合的方法有哪些?



喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 機器學習 的精彩文章:

亞馬遜:機器學習對Kiva機器人的價值非常大!
美媒發布重磅白皮書「機器學習與深度學習的浪潮」
戴爾為女性創業者發布「Hello Alice」機器學習平台
谷歌宣布正式推出基於雲GPU的雲機器學習引擎
機器學習和深度學習引用量最高的20篇論文

TAG:機器學習 |

您可能感興趣

機器學習中「正則化來防止過擬合」到底是一個什麼原理?
怎樣才能有效的防止腦梗死?我這裡有最全的預防方法!
防止腦萎縮的有效鍛煉方法
我們為了防止羊口炎都有哪些方法?
門牙突出快矯正方法有哪些 怎樣防止牙齒受損
防止貓咪嘔吐的方法都有哪些?
健身房的器材不會用?我來給你普及普及用法,防止尷尬
為防止兵馬俑受到破壞,人們想了哪些方法?這法子雖然丑,但有用
家庭教育:用這1個方法,能有效防止孩子「沉迷手機」!
防止被下屬架空的四個經典方法!用過的人都說管用!
乾貨:用兩種方法備份你的蘋果手機,防止手機丟失或損壞!
為防止戰機飛行員叛逃,我國竟然用這種做法,也真是絕了
防止走光的好方法
怎麼防止中風複發,有3種好方法
領導防止自己被下屬架空的解決方案,用過的人都說好用!
全世界各國的尖端戰機,幾乎全部運用這種方法,成功防止了飛行員駕機出逃
使用天然方法防止螞蟻來襲
專家:掌握這個方法,能有效防止癌症複發
來學習一下概率論基本知識,它能讓防止你的模型過擬合
民間驅邪的方式有哪些?為了防止中邪,果斷收藏了!