當前位置:
首頁 > 知識 > 手把手教你用DGL框架進行批量圖分類

手把手教你用DGL框架進行批量圖分類

機器之心專欄

者:DGL團隊

圖分類(預測圖的標籤)是圖結構數據里一類重要的問題。它的應用廣泛,可見於生物信息學、化學信息學、社交網路分析、城市計算以及網路安全。隨著近來學界對於圖神經網路的熱情持續高漲,出現了一批用圖神經網路做圖分類的工作。比如訓練圖神經網路來預測蛋白質結構的性質,根據社交網路結構來預測用戶的所屬社區等(Ying et al., 2018, Cangea et al., 2018, Knyazev et al., 2018, Bianchi et al., 2019, Liao et al., 2019, Gao et al., 2019)。

在這個教程里,我們將一起學習:

如何使用 DGL 批量化處理大小各異的圖數據

訓練圖神經網路完成一個簡易的圖分類任務

簡易圖分類任務

這裡我們設計了一個簡單的圖分類任務。在 DGL 里我們實現了一個迷你圖分類數據集(MiniGCDataset)。它由以下 8 類圖結構數據組成。每一類圖包含同樣數量的隨機樣本。任務目標是訓練圖神經網路模型對這些樣本進行分類。

以下是使用 MiniGCDataset 的示例代碼。我們先創建了一個擁有 80 個樣本的數據集。數據集中每張圖隨機有 10 到 20 個節點。DGL 中所有的數據集類都符合 Sequence 的抽象結構——既可以使用 dataset[i] 來訪問第 i 個樣本。這裡每個樣本包含圖結構以及它對應的標籤。

運行以上代碼後可以畫出數據集中第一個樣本的圖結構以及它對應的標籤:

打包一個圖的小批量

為了更高效地訓練神經網路,一個常見的做法是將多個樣本打包成小批量(mini-batch)。打包尺寸相同的張量樣本非常簡單。比如說打包兩個尺寸為 2828 的圖片會得到一個 22828 的張量。相較之下,打包圖面臨兩個挑戰:

圖的邊比較稀疏

圖的大小、形狀各不相同

DGL 提供了名為 dgl.batch 的介面來實現打包一個圖批量的功能。其核心思路非常簡單。將 n 張小圖打包在一起的操作可以看成是生成一張含 n 個不相連小圖的大圖。下圖的可視化從直覺上解釋了 dgl.batch 的功能。

可以看到通過 dgl.batch 操作,我們生成了一張大圖,其中包含了一個環狀和一個星狀的連通分量。其鄰接矩陣表示則對應為在對角線上把兩張小圖的鄰接矩陣拼接在一起(其餘部分都為 0)。

以下是使用 dgl.batch 的一個實際例子。我們定義了一個 collate 函數來將 MiniGCDataset 里多個樣本打包成一個小批量。

正如打包 N 個張量得到的還是張量,dgl.batch 返回的也是一張圖。這樣的設計有兩點好處。首先,任何用於操作一張小圖的代碼可以被直接使用在一個圖批量上。其次,由於 DGL 能夠並行處理圖中節點和邊上的計算,因此同一批量內的圖樣本都可以被並行計算。

圖分類器

這裡使用的圖分類器和應用在圖像或者語音上的分類器類似——先通過多層神經網路計算每個樣本的表示(representation),再通過表示計算出每個類別的概率,最後通過向後傳播計算梯度。一個常見的圖分類器由以下幾個步驟構成:

通過圖卷積(Graph Convolution)層獲得圖中每個節點的表示。

使用「讀出」操作(Readout)獲得每張圖的表示。

使用 Softmax 計算每個類別的概率,使用向後傳播更新參數。

下圖展示了整個流程:

之後我們將分步講解每一個步驟。

圖卷積

我們的圖卷積操作基本類似圖卷積網路 GCN(具體可以參見我們的關於 GCN 的教程)。圖卷積模型可以用以下公式表示:

在這個例子中,我們對這個公式進行了微調:

我們將求和替換成求平均可用來平衡度數不同的節點,在實驗中這也帶來了模型表現的提升。

此外,在構建數據集時,我們給每個圖裡所有的節點都加上了和自己的邊(自環)。這保證節點在收集鄰居節點表示進行更新時也能考慮到自己原有的表示。以下是定義圖卷積模型的代碼。這裡我們使用PyTorch作為 DGL 的後端引擎(DGL 也支持 MXNet 作為後端)。

首先,我們使用 DGL 的內置函數定義消息傳遞:

其次,我們定義消息累和函數。這裡我們對收到的消息進行平均。

之後,我們對收到的消息應用線性變換和激活函數。

最後,我們把所有的小模塊串聯起來成為 GCNLayer。

讀出和分類

讀出(Readout)操作的輸入是圖中所有節點的表示,輸出則是整張圖的表示。在 Google 的 Neural Message Passing for Quantum Chemistry(Gilmer et al. 2017) 論文中總結過許多不同種類的讀出函數。在這個示例里,我們對圖中所有節點表示取平均以作為圖的表示:

DGL 提供了許多讀出函數介面,以上公式可以很方便地用 dgl.mean(g) 完成。最後我們將圖的表示輸入分類器。分類器對圖表示先做了一個線性變換然後得到每一類在 softmax 之前的 logits。具體代碼如下:

準備和訓練

閱讀到這邊的讀者可以長舒一口氣了。因為之後的訓練過程和其他經典的圖像,語音分類問題基本一致。首先我們創建了一個包含 400 張節點數量為 10~20 的合成數據集。其中 320 張圖作為訓練數據集,80 張圖作為測試集。

其次我們創建一個剛剛定義的圖神經網路模型對象。

訓練過程則是經典的反向傳播和梯度下降。

下圖是以上模型訓練的學習曲線:

在訓練完成後,我們在測試集上驗證模型的表現。出於部署教程的考量,我們限制了模型訓練的時間。如果你花更多時間訓練模型,應該能得到更好的表現(80%-90%)。

我們還製作了一個動畫來展示訓練好的模型預測每張圖真實標籤的概率。可以看到我們剛剛定義的圖神經網路能夠較為準確地預測出圖樣本的對應標籤:

為了更好地理解模型學到的節點和圖的表示,我們使用了 t-SNE 來進行降維和可視化。

頂部的兩張小圖分別可視化了做完 1 層和 2 層圖卷積後的節點表示。不同顏色代表屬於不同類別的圖的節點。可以看到,經過訓練後,屬於同一類別的節點表示更加接近。並且,經過兩層圖卷積後這一聚類效果更明顯。其原因是因為兩層卷積後每個節點能接收到 2 度範圍內的鄰居信息。

底部的大圖可視化了每張圖在做 softmax 前的 logits,也就是圖表示。可以看到通過讀出函數後,圖表示能非常好地各自區分開來。這一區分度比節點表示更加明顯。


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 機器之心 的精彩文章:

還在PS裏手動描邊?AI自動摳圖只需5秒
十八歲華裔天才攜手「量子計算先驅」再次顛覆量子計算

TAG:機器之心 |