當前位置:
首頁 > 最新 > 論文筆記:批規範化

論文筆記:批規範化

一、Introduction

隨機梯度下降(SGD) 是一個訓練深度神經網路的有效方法。通常其通過優化如下loss來優化網路的參數Θ:

其中,x1…N為訓練集。在訓練時,每次使用一個大小為m的mini-batch x1…m ,該mini-batch通過如下計算來近似loss的梯度:

使用mini-batch通常有以下幾個好處:

一個mini-batch的梯度可以近似看成整個數據集的梯度,尤其是size比較大時;

並行的計算一個batch比計算m個單獨的樣本會更高效。

儘管SGD簡單高效,但是,每一層的參數更新會導致上層的輸入數據分布發生變化,隨著網路層的加深,高層的輸入分布變化會非常大,使得高層需要不斷地重新適應底層的參數更新,這就是通常所說的covariate shift。這使得我們在煉丹時需要非常小心地設置學習率以及其他初始化參數。

另外,考慮一個網路計算如下loss:

F1,F2為任意變換,Θ1,Θ2是最小化loss需要學習的網路參數。這時學習Θ2可以看作將輸入x=F1(μ,Θ1)送進一個子網路:

例如我們用如下梯度下降策略更新網路權值:

其中batch size為m,學習率為α. 此時網路的訓練可以看成在訓練一個單獨的F2,其輸入是x。這樣,隨著網路的訓練,x可能陷入一個固定的分布,Θ2不需要再進行調整來補償輸入x的分布的變化。

考慮一個使用Sigmoid激活函數的網路層z=g(Wu+b),u是輸入,W是權值矩陣,g(x)=1/(1+exp(-x))。隨著 |x|增加,g"(x)趨於0,這意味著,除了|x|很小時,梯度很容易陷入非線性飽和區。這使得網路在深層很難收斂。


Independent and Identically Distribution

獨立同分布(i.i.d.) 是指數據樣本服從同一分布並且相互獨立。這並不是所有機器學習模型的前提假設條件,但是獨立同分布的數據可以簡化模型訓練,並且預測能力更強。

Whitening

白化(Whitening)是數據預處理的一種手段,其目的是降低數據之間的冗餘性,主要包括如下兩個步驟:

去除(或降低)數據之間的相關性——>去相關性——>獨立

使所有數據具有相同的均值和方差——>e.g. 均值0,方差1——>同分布

Internal Covariate Shift (ICS)

ICS是指網路節點輸入的分布隨著訓練發生變化,因為每一層的參數更新都會導致上層的輸入數據分布發生變化。知乎回答對此做了一個很有意思的解釋:

大家都知道在統計機器學習中的一個經典假設是「源空間(source domain)和目標空間(target domain)的數據分布(distribution)是一致的」。如果不一致,那麼就出現了新的機器學習問題,如 transfer learning / domain adaptation 等。而 covariate shift 就是分布不一致假設之下的一個分支問題,它是指源空間和目標空間的條件概率是一致的,但是其邊緣概率不同,即:對所有x∈χ

但是

大家細想便會發現,的確,對於神經網路的各層輸出,由於它們經過了層內操作作用,其分布顯然與各層對應的輸入信號分布不同,而且差異會隨著網路深度增大而增大,可是它們所能「指示」的樣本標記(label)仍然是不變的,這便符合了covariate shift的定義。由於是對層間信號的分析,也即是「internal」的來由。

ICS帶來的問題,知乎回答歸納得很好:

上層參數需要不斷適應新的輸入數據分布的變化,降低學習速率;

下層輸入的變化可能趨向於過大或過下,落入飽和區,使得學習過早停止;

每層的更新都影響到其他層,因此每層的參數更新策略需要儘可能的謹慎。


以神經網路中一個神經元為例,其輸入為:

輸出為:

x的分布可能相差很大(因為ICS問題),要解決獨立同分布的問題,最好的的方法是對每一層的所有數據都進行白化操作。但是標準的白化操作代價很大,並且我們還希望白化操作是可微的,保證白化操作後可以通過反向傳播來更新梯度。

通用的Normalization方法可以看成簡化版的白化操作:

(1) 在x輸入之前先對其進行shift(平移)和scale(伸縮)變換,將x的分布規範化為在固定範圍內的標準分布。

其中μ是平移參數,δ是縮放參數。這一步使得所有數據符合均值為0,方差為1的標準分布。

(2) 但是這樣變換之後的數據很可能降低網路原始的表達能力,因為第一步中會將幾乎所有數據映射到激活函數的非飽和區(線性區)。因此進一步對數據進行re-shift(再平移)和re-scale(再縮放)變換,使數據重新獲得非線性的表達能力:

其中b是再平移參數,g是再縮放參數。這一步使得所有數據的分布確定在均值為b方差為g2的區間。

因此總體的變換如下:

這裡可以看出,Normalization只是將數據映射到一個確定的分布區間,並且未考慮去相關性的操作,因此距離標準的白化操作還很遠。

批規範化(batch normalization)顧名思義就是針對一個batch進行的normalization。如下圖所示,計算一個mini-batch的均值和方差來估計神經元輸入的均值和方差。和是需要學習的再平移和再縮放參數。

在Batch Normalization中,將每一個 mini-batch 的統計量看成是對整體統計量的近似估計,或者說認為每一個 mini-batch 彼此之間,以及和整體數據,都應該是近似同分布的。分布差距較小的 mini-batch 可以看做是為規範化操作和模型訓練引入了雜訊,可以增加模型的魯棒性;但如果每個 mini-batch的原始分布差別很大,那麼不同 mini-batch 的數據將會進行不一樣的數據變換,這就增加了模型訓練的難度。因此,BN的適用場景是:每個 mini-batch 比較大,數據分布比較接近。在進行訓練之前,要做好充分的 shuffle. 否則效果會差很多。


1. [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf)

2. https://zhuanlan.zhihu.com/p/33173246


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 文本智能 的精彩文章:

論文筆記:Attention is All You Need

TAG:文本智能 |