當前位置:
首頁 > 最新 > IBM研究院提出Graph2Seq,基於注意力機制的圖到序列學習

IBM研究院提出Graph2Seq,基於注意力機制的圖到序列學習

作者:IBM Research

編譯:weakish

介紹

Seq2Seq(序列到序列)及其變體在機器翻譯、自然語言生成、語音識別、新葯發現之類的領域表現非常出色。大多數Seq2Seq模型都屬於編碼器-解碼器家族,其中編碼器將輸入序列編碼為固定維度的連續向量表示,而解碼器則解碼向量得到目標序列。

然而,Seq2Seq有一個限制,它只能應用於輸入表示為序列的問題。而在許多問題中,輸入為更複雜的結構,比如圖(graph)。對於這類圖到序列(graph-to-sequence)問題,如果要應用Seq2Seq,就需要將圖轉換為序列。然而,將圖精確地轉換為序列是一項艱巨的挑戰,因為在將圖這種比較複雜的結構數據轉換為序列時,難免會損失不少信息,特別是當輸入數據本身適合用圖表示的時候。最近的一些研究嘗試在輸入數據中提取句法特徵,例如句子的片語結構(Tree2Seq),或將注意力機制應用於輸入集(Set2Seq),或將句子遞歸地編碼為樹(Tree-LSTM)。在特定類別問題上,這類方法取得了充滿希望的結果,然而,這類方法大多難以推廣。

為此,IBM研究院的Kun Xu、Lingfei Wu等提出了Graph2Seq,一個端到端的處理圖到序列問題的模型。

Graph2Seq採用與Seq2Seq相似的編碼器-解碼器架構,包括一個圖編碼器和一個序列解碼器。圖編碼器部分,通過聚合有向圖和無向圖中的相鄰信息,學習節點嵌入。然後根據學習到的節點嵌入,構建圖嵌入。序列解碼器部分,論文作者設計了一個基於注意力機制的LSTM網路,使用圖嵌入作為初始隱藏狀態,輸出目標預測。注意力機制用於學習節點和序列元素的對齊,以更好地應對大型圖。整個Graph2Seq的設計是模塊化的,可擴展性很好。比如,編碼器可以換成圖卷積網路,解碼器可以換成普通的LSTM。

Graph2Seq模型

在上一節的末尾,我們已經簡單介紹了Graph2Seq的架構。這一節我們將具體介紹Graph2Seq模型。下面是Graph2Seq的整體架構示意圖。

節點嵌入生成

如前所述,節點嵌入中包含了節點的相鄰信息。具體的嵌入生成過程如下:

通過查詢嵌入矩陣We,將節點v的文本屬性轉換為一個特徵向量av。

根據邊的方向,將v的鄰居分類為前向鄰(forward neighbor)N|-(v)和反向鄰(backward neighbor)N-|(v)。

將v的前向鄰的前向表示

聚合為單個向量

其中k為迭代索引。注意,在迭代k時,聚合僅僅使用k-1時生成的表示。每個節點的初始化前向表示為其特徵向量。

我們將v的當前前向表示(k-1)和新生成的前向聚合向量(k)連接。連接所得的向量傳入一個帶非線性激活的全連接層,從而更新v的前向表示,在下一次迭代中使用。

將上述過程應用於反向表示。

重複前向表示聚合與反向表示聚合過程K次,連接最終的前向表示和反向表示,作為v的最終表示。

用偽代碼表示以上節點嵌入生成過程:

上面我們提到了聚合前向表示和反向表示,卻沒有提到具體的聚合方法。實際上,論文作者嘗試了3種不同的聚合方法。

均值這是最簡單直接的聚合方式,取分素均值(element-wise mean)。

LSTM使用LSTM處理節點鄰居的單個隨機排列(無序集)。

池化將每個鄰居向量傳入一個全連接網路,然後應用分素最大池化(element-wise max-pooling)。

其中,σ為非線性激活函數。

經論文作者試驗,總體而言,最簡單的均值聚合效果最好。

均值(MA)、LSTM(LA)、池化(PA)聚合在3個合成SDP數據集(有向無環圖、有向有環圖、序線圖)上的精確度

圖嵌入生成

論文作者引入了兩種基於節點嵌入構造圖嵌入的方法。

基於池化的圖嵌入。類似上面基於池化的聚合,論文作者將節點嵌入傳給一個全連接神經網路,然後分素應用池化方法。論文作者共試驗了三種池化方法,最大池化、最小池化、平均池化,最後發現三種池化方法沒有顯著差別。因此,論文作者最後選用了最大池化作為默認的池化方法。

基於節點的圖嵌入。這一方法加入了一個超(super)節點vs至輸入圖,使圖中的所有其他節點指向vs。我們使用之前提到的節點嵌入生成演算法生成vs嵌入,因而vs嵌入捕獲了所有節點的信息,可視為圖嵌入。

經論文作者試驗,總體而言,基於池化的圖嵌入表現較好。

基於注意力的解碼器

序列解碼器是一個基於注意力的LSTM網路,根據給定的y1,...,yi-1,隱藏狀態si(i表示時刻),以及上下文向量ci,預測下一個token,即yi。其中,上下文向量ci取決於前述圖編碼器根據輸入圖生成的節點表示集合(z1,...,zv)。具體而言,上下文向量ci通過節點表示的加權和計算得出:

相應的權重aij由下式計算得出:

其中,a為對齊模型(alignment model),為j處的輸入節點和i處的輸出的匹配程度評分。評分基於LSTM的隱藏狀態si-1和輸入圖的第j個節點表示。對齊模型a為前饋神經網路,和系統的其他部分一起訓練。

試驗

試驗設定

論文作者使用了Adam優化,mini-batch大小為30,學習率為0.001,解碼器層dropout率為0.5(避免過擬合)。norm大於20時裁剪梯度。圖編碼器部分,默認跳(hop)大小為6,節點初始特徵向量為40,非線性激活函數為ReLU,聚合器的參數隨機初始化。解碼器為單層,隱藏狀態大小為80. 如前所述,使用了表現最佳的均值聚合和基於池化的圖嵌入生成。

試驗結果

從上表可以看到,在bAbI Task 19上,LSTM失敗了,而Graph2Seq的表現是最好的,超過了GGS-NN和GCN。

而在最短路徑任務(Shortest Path Task)上,LSTM同樣失敗了。儘管GGS-NN、GCN、Graph2Seq在小數據集上(SP-S,節點尺寸=5)上都達到了100%的精確度,但在大數據上(SP-L,節點尺寸=100),得益於解碼器部分注意力機制的應用,Graph2Seq的表現超過了GGS-NN和GCN。

最後,論文作者在自然語言生成(Natural Language Generation)任務上評估了Graph2Seq的表現。具體而言,這一任務根據SQL查詢語句,生成描述其含義的自然語言。論文作者使用的是WikiSQL數據集,該數據集包含87726對手工標註的自然語言查詢問題,SQL查詢,以及相應的SQL表。WikiSQL原本是為評測問題回答任務而創建的,這裡論文作者逆向使用該數據集,將SQL請求視作輸入,將生成正確的英語問題視作目標。WikiSQL的SQL請求分割為訓練、驗證、測試集,分別包含61297、9145、17284個請求。

從上表可以看出,Graph2SQL的BLEU-4評分顯著高於Seq2Seq、Seq2Seq + Copy、Tree2Seq。


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 論智 的精彩文章:

FaceForensics:一個用於人臉偽造檢測的大型視頻數據集
從十年前的這篇論文看如今「中國芯」的舉步維艱

TAG:論智 |