當前位置:
首頁 > 知識 > 想要算一算Wasserstein距離?這裡有一份PyTorch實戰

想要算一算Wasserstein距離?這裡有一份PyTorch實戰

選自dfdazac

作者:Daniel Daza

機器之心編譯

最優傳輸理論及 Wasserstein 距離是很多讀者都希望了解的基礎,本文主要通過簡單案例展示了它們的基本思想,並通過 PyTorch 介紹如何實戰 W 距離。

機器學習中的許多問題都涉及到令兩個分布儘可能接近的思想,例如在 GAN 中令生成器分布接近判別器分布就能偽造出逼真的圖像。但是 KL 散度等分布的度量方法有很多局限性,本文則介紹了 Wasserstein 距離及 Sinkhorn 迭代方法,它們 GAN 及眾多任務上都展示了傑出的性能。

在簡單的情況下,我們假設從未知數據分布 p(x) 中觀測到一些隨機變數 x(例如,貓的圖片),我們想要找到一個模型 q(x|θ)(例如一個神經網路)能作為 p(x) 的一個很好的近似。如果 p 和 q 的分布很相近,那麼就表明我們的模型已經學習到如何識別貓。

因為 KL 散度可以度量兩個分布的距離,所以只需要最小化 KL(q‖p) 就可以了。可以證明,最小化 KL(q‖p) 等價於最小化一個負對數似然,這樣的做法在我們訓練一個分類器時很常見。例如,對於變分自編碼器來說,我們希望後驗分布能夠接近於某種先驗分布,這也是我們通過最小化它們之間的 KL 散度來實現的。

儘管 KL 散度有很廣泛的應用,在某些情況下,KL 散度則會失效。不妨考慮一下如下圖所示的離散分布:

KL 散度假設這兩個分布共享相同的支撐集(也就是說,它們被定義在同一個點集上)。因此,我們不能為上面的例子計算 KL 散度。由於這一個限制和其他計算方面的因素促使研究人員尋找一種更適合於計算兩個分布之間差異的方法。

在本文中,作者將:

簡單介紹最優傳輸問題

將 Sinkhorn 迭代描述為對解求近似

使用 PyTorch 計算 Sinkhorn 距離

描述用於計算 mini-batch 之間的距離的對該實現的擴展

移動概率質量函數

我們不妨把離散的概率分布想像成空間中分散的點的質量。我們可以觀測這些帶質量的點從一個分布移動到另一個分布需要做多少功,如下圖所示:

接著,我們可以定義另一個度量標準,用以衡量移動做所有點所需要做的功。要想將這個直觀的概念形式化定義下來,首先,我們可以通過引入一個耦合矩陣 P(coupling matrix),它表示要從 p(x) 支撐集中的一個點上到 q(x) 支撐集中的一個點需要分配多少概率質量。對於均勻分布,我們規定每個點都具有 1/4 的概率質量。如果我們將本例支撐集中的點從左到右排列,我們可以將上述的耦合矩陣寫作:

也就是說,p(x) 支撐集中點 1 的質量被分配給了 q(x) 支撐集中的點 4,p(x) 支撐集中點 2 的質量被分配給了 q(x) 支撐集中的點 3,以此類推,如上圖中的箭頭所示。

為了算出質量分配的過程需要做多少功,我們將引入第二個矩陣:距離矩陣。該矩陣中的每個元素 C_ij 表示將 p(x) 支撐集中的點移動到 q(x) 支撐集中的點上的成本。點與點之間的歐幾里得距離是定義這種成本的一種方式,它也被稱為「ground distance」。如果我們假設 p(x) 的支撐集和 q(x) 的支撐集分別為 和 ,成本矩陣即為:

根據上述定義,總的成本可以通過 P 和 C 之間的 Frobenius 內積來計算:

你可能已經注意到了,實際上有很多種方法可以把點從一個支撐集移動到另一個支撐集中,每一種方式都會得到不同的成本。上面給出的只是一個示例,但是我們感興趣的是最終能夠讓成本較小的分配方式。這就是兩個離散分布之間的「最優傳輸」問題,該問題的解是所有耦合矩陣上的最低成本 L_C。

由於不是所有矩陣都是有效的耦合矩陣,最後一個條件會引入了一個約束。對於一個耦合矩陣來說,其所有列都必須要加到帶有 q(x) 概率質量的向量中。在本例中,該向量包含 4 個值為 1/4 的元素。更一般地,我們可以將兩個向量分別記為 a 和 b,因此最有運輸問題可以被寫作:

當距離矩陣基於一個有效的距離函數構建時,最小成本即為我們所說的「Wasserstein 距離」。

關於該問題的解以及將其擴展到連續概率分布中還有大量問題需要解決。如果想要獲取更正式、更容易理解的解釋,讀者可以參閱 Gabriel Peyré 和 Marco Cuturi 編寫的「Computational Optimal Transport」一書,此書也是本文寫作的主要參考來源之一。

這裡的基本設定是,我們已經把求兩個分布之間距離的問題定義為求最優耦合矩陣的問題。事實證明,我們可以通過一個小的修改讓我們以迭代和可微分的方式解決這個問題,這將讓我們可以很好地使用深度學習自動微分機制完成該工作。

熵正則化和 Sinkhorn 迭代

首先,我們將一個矩陣的熵定義如下:

正如資訊理論中概率分布的熵一樣,一個熵較低的矩陣將會更稀疏,它的大部分非零值集中在幾個點周圍。相反,一個具有高熵的矩陣將會更平滑,其最大熵是在均勻分布的情況下獲得的。我們可以將正則化係數 ε 引入最優傳輸問題,從而得到更平滑的耦合矩陣:

通過增大 ε,最終得到的耦合矩陣將會變得更加平滑;而當 ε 趨近於零時,耦合矩陣會更加稀疏,同時最終的解會更加趨近於原始最優運輸問題。

通過引入這種熵正則化,該問題變成了一個凸優化問題,並且可 以通過使用「Sinkhorn iteration」求解。解可以被寫作 P=diag(u)Kdiag(v),在迭代過程中交替更新 u 和 v:

其中 K 是一個用 C 計算的核矩陣(kernel matrix)。由於這些迭代過程是在對原始問題的正則化版本求解,因此對應產生的 Wasserstein 距離有時被稱為 Sinkhorn 距離。該迭代過程會形成一個線性操作的序列,因此對於深度學習模型,通過這些迭代進行反向傳播是非常簡單的。

通過 PyTorch 實現 Sinkhorn 迭代

為了提升 Sinkhorn 迭代的收斂性和穩定性,還可以加入其它的步驟。我們可以在 GitHub 上找到 Gabriel Peyre 完成的詳細實現。

項目鏈接:https://github.com/gpeyre/SinkhornAutoDiff。

讓我們先用一個簡單的例子來測試一下,現在我們將研究二維空間(而不是上面的一維空間)中的離散均勻分布。在這種情況下,我們將在平面上移動概率質量。讓我們首先定義兩個簡單的分布:

我們很容易看出,最優傳輸對應於將 p(x) 支撐集中的每個點分配到 q(x) 支撐集上的點。對於所有的點來說,距離都是 1,同時由於分布是均勻的,每點移動的概率質量是 1/5。因此,Wasserstein 距離是 5×1/5= 1。現在我們用 Sinkhorn 迭代來計算這個距離:

結果正如我們所計算的那樣,距離為 1。現在,讓我們查看一下「Sinkhorn( )」方法返回的矩陣,其中 P 是計算出的耦合矩陣,C 是距離矩陣。距離矩陣如下圖所示:

元素「C[0, 0]」說明了將(0,0)點的質量移動到(0,1)所需要的成本 1 是如何產生的。在該行的另一端,元素「C[0, 4]」包含了將點(0,0)的質量移動到點(4,1)所需要的成本,這個成本是整個矩陣中最大的:

由於我們為距離矩陣使用的是平方後的 ?2 範數,計算結果如上所示。現在,讓我們看看計算出的耦合矩陣吧:

該圖很好地向我們展示了演算法是如何有效地發現最優耦合,它與我們前面確定的耦合矩陣是相同的。到目前為止,我們使用了 0.1 的正則化係數。如果將該值增加到 1 會怎樣?

正如我們前面討論過的,加大 ε 有增大耦合矩陣熵的作用。接下來,我們看看 P 是如何變得更加平滑的。但是,這樣做也會為計算出的距離帶來一個不好的影響,導致對 Wasserstein 距離的近似效果變差。

可視化支撐集的空間分配也很有意思:

讓我們在一個更有趣的分布(Moons 數據集)上完成這項工作。

Mini-batch 上的 Sinkhorn 距離

在深度學習中,我們通常對使用 mini-batch 來加速計算十分感興趣。我們也可以通過使用額外的批處理維度修改 Sinkhorn 迭代來滿足該設定。將此更改添加到具體實現中後,我們可以在一個 mini-batch 中計算多個分布的 Sinkhorn 距離。下面我們將通過另一個容易被驗證的例子說明這一點。

代碼:https://github.com/dfdazac/wassdistance/blob/master/layers.py

我們將計算包含 5 個支撐點的 4 對均勻分布的 Sinkhorn 距離,它們垂直地被 1(如上所示)、2、3 和 4 個單元分隔開。這樣,它們之間的 Wasserstein 距離將分別為 1、4、9 和 16。

這樣做確實有效!同時,也請注意,現在 P 和 C 為 3 維張量,它包含 mini-batch 中每對分布的耦合矩陣和距離矩陣:

結語

分布之間的 Wasserstein 距離及其通過 Sinkhorn 迭代實現的計算方法為我們帶來了許多可能性。該框架不僅提供了對 KL 散度等距離的替代方法,而且在建模過程中提供了更大的靈活性,我們不再被迫要選擇特定的參數分布。這些迭代過程可以在 GPU 上高效地執行,並且是完全可微分的,這使得它對於深度學習來說是一個很好的選擇。這些優點在機器學習領域的最新研究中得到了充分的利用(如自編碼器和距離嵌入),使其在該領域的應用前景更加廣闊。

本文為機器之心編譯,轉載請聯繫本公眾號獲得授權。

------------------------------------------------


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 機器之心 的精彩文章:

機器學習的七大謠傳,這都是根深蒂固的執念吧
Python與PHP的對決:誰是工程師最喜歡和最討厭的語言?

TAG:機器之心 |