當前位置:
首頁 > 新聞 > 一圖勝千言: 解讀阿里的Deep Image CTR Model

一圖勝千言: 解讀阿里的Deep Image CTR Model

雷鋒網 AI 科技評論按:本文作者石塔西,原載於知乎,雷鋒網已獲授權。

本文是對阿里的論文《Image Matters: Visually modeling user behaviors using Advanced Model Server》的解讀。

初讀此文的標題和摘要,又有 image,又有 CTR,我以為是一種新型的 CNN+MLP 的聯合建模方法。讀下來才知道,本文的重點絕不在什麼圖像建模上,壓根就沒 CNN 什麼事。想像中的像素級別的建模根本沒有出現,商品的圖片利用網上可下載的預訓練好的 VGG16 模型的某個中間層壓縮成 4096 維向量,作為 CTR 模型的原始輸入。

而將圖片引入到推薦/搜索領域,也不是什麼新鮮事。不說論文 Related Works 中提到的工作,我自己就做過基於圖片的向量化召回,結構與論文圖 4 中的 Pre-Rank DICM 結構很相似,只不過用戶側不包含他之前點擊過的商品圖片罷了,在此略下不表。

沒有提出新的圖像建模方法,也並非第一次在推薦演算法中使用圖片信息,那麼此文的創新點到底在哪裡?我覺得,本文的創新點有兩個創新點:

  1. 之前的工作儘管也在推薦/搜索演算法中引入了圖片信息,可是那些圖片只用於物料側,用於豐富商品、文章的特徵表示。而阿里的這篇論文,是第一次將圖片用於用戶側建模,基於用戶歷史點擊過的圖片(user behavior images)來建模用戶的視覺偏好

  2. 接下來會看到,將圖片加入到用戶側建模,理論上並不複雜,理論上用傳統 PS 也可以實現,起碼跑個實驗,發篇論文應該不成問題。但是,如果應用到實際系統,圖片特徵引入的大數據量成為技術瓶頸。為此,阿里團隊為傳統 PS 的 server 也增加了「模型訓練」功能,並稱新結構為 Advanced Model Server(AMS)

基於歷史點擊圖片建模用戶視覺偏好

先談一下第一個「小創新」。之所以說其「小」,是因為通過預訓練的 CNN 模型提取特徵後,每張圖片用一個高維(比如 4096)稠密向量來表示。這些圖片向量,與常見的稀疏 ID 類特徵經過 embedding 得到的稠密向量,沒有質的區別(量的區別,下文會提到),完全可以復用以前處理 ID embedding 的方法(如 pooling, attention)來處理

Deep Image CTR Model(DICM)的具體結構如下所示:

一圖勝千言: 解讀阿里的Deep Image CTR Model

打開今日頭條,查看更多圖片

DICM 架構圖

  • 如果只看左邊,就是推薦/搜索中常見的 Embedding+MLP 結構。注意上圖中的 Embedding+MLP 結構只是實際系統的簡化版本,實際系統中可以替換成 Wide&Deep, DIN, DIEN 等這些「高大上」的東西。

  • 假設一個滿足要求的圖片 embedding model 已經 ready,即圖中的 embmodel。商品的縮略圖,經過 embmodel 壓縮,得到商品的圖片信息(圖中的粉紅色塊)

  • 右邊部分,負責利用圖片建模用戶。將每個用戶點擊過的圖片(user behavior image),經過 embmodel 進行壓縮(圖中的藍色塊)。它們與商品圖片(ad image)的 embedding 結果(粉紅色塊)經過attentive pooling合併成一個向量(桔色塊)表示用戶的視覺偏好

  • 將用戶點擊過的多張圖片的向量(藍色)合併成一個向量(桔色),其思路與 Deep Interest Network 基於 attention 的 pooling 機制大同小異,只不過要同時考慮「id 類特徵」與「商品圖片」對用戶歷史點擊圖片的 attention,稱為MultiQueryAttentivePooling

  • 第 1 步得到基於 id 特徵的 embedding 結果,與第 2 步得到的商品圖片 (ad image) 的 embedding 結果(粉紅色),與第 3 步得到的表示用戶興趣偏好的向量(桔紅色),拼接起來,傳入 MLP,進行充分的交互

這個模型的優勢在於:

  • 之前的模型只考慮了傳統的 ID 類特徵和物料的圖像信息,這次加入了用戶的視覺偏好,補齊了一塊信息短板

  • 不僅如此,通過 MLP,將傳統的 ID 類特徵、物料的圖像信息、用戶的視覺偏好進行充分交互,能夠發現更多的 pattern。

  • 基於用戶歷史訪問的 item id 來建模用戶的興趣,始終有「冷啟動」問題。如果用戶訪問過一個 embedding matrix 中不存在的 item,這部分信息只能損失掉。而基於用戶歷史訪問的圖片來建模,類似於 content-based modeling,商品雖然是新的,但是其使用的圖片與模型之前見過的圖片卻很相似,從而減輕了「冷啟動」問題。

綜上可見,DICM 的思路、結構都很簡單。但是,上面的描述埋了個大伏筆:那個圖片嵌入模型 embmodel 如何設計?沒有加入圖片、只有稀疏的 ID 類特徵時,Embedding+MLP 可以通過 Parameter Server 來分散式訓練。現在這個 embmodel,是否還可以在 PS 上訓練?在回答這個問題之前,讓我們先看看稀疏 ID 特徵 Embedding+MLP 在傳統的 PS 上是如何訓練的?

稀疏 ID 特徵 Embedding+MLP 在傳統的 PS 上是如何訓練的?

介紹 PS 的論文、博客汗牛充棟,實在論不上我在這裡炒冷飯,但是,我還是要將我實踐過的「基於 PS 訓練的 DNN 推薦演算法」,在這裡簡單介紹一下,因為我覺得它與《Scaling Distributed Machine Learning with the Parameter Server》所介紹的「經典」PS 還是稍稍有所不同,與同行們探討。

基於 PS 的分散式訓練的思想還是很簡單的:

1.一開始是 data parallelism。每台 worker 只利用本地的訓練數據前代、回代,計算 gradient,並發往 server。Server 匯總(平均)各 worker 發來的 gradient,更新模型,並把更新過的模型同步給各 worker。這裡有一個前提,就是數據量超大,但是模型足夠小,單台 server 的內存足以容納

2.但是,推薦/搜索系統使用超大規模的 LR 模型,模型參數之多,已經是單台 server 無法容納的了。這時 Parameter Server 才應運而生,它同時結合了 data parallelism 與 model parallelism

  • Data parallelism:訓練數據依然分布地存儲在各台 worker node 上,各 worker node 也只用本地數據進行計算。

  • Model parallelism:一來模型之大,單台 server 已經無法容納,所以多台 server 組成一個分散式的 key-value 資料庫,共同容納、更新模型參數;二來,由於推薦/搜索的特徵超級稀疏,各 worker 上的訓練數據只涵蓋了一部分特徵,因此每個 worker 與 server 之間也沒有必要同步完整模型,而只需要同步該 worker 的本地訓練數據所能夠涵蓋的那一部分模型

所以按照我的理解,PS 最擅長的是訓練稀疏數據集上的演算法,比如超大規模 LR 的 CTR 預估。但是,基於 DNN 的推薦/搜索演算法,常見模式是稀疏 ID 特徵 Embedding+MLP,稍稍有所不同

1.稀疏 ID 特徵 Embedding,是使用 PS 的理想對象:超大的 embedding 矩陣已經無法容納於單台機器中,需要分散式的 key-value 資料庫共同存儲;數據稀疏,各 worker 上的訓練數據只涵蓋一部分 ID 特徵,自然也只需要和 server 同步這一部分 ID 的 embedding 向量。

2.MLP 部分,稍稍不同

  • 和計算機視覺中動輒幾百層的深網路相比,根據我的經驗,縱使工業級別的推薦/搜索演算法,MLP 也就是 3~4 層而已,否則就有過擬合的風險。這等「小淺網路」可以容納於單台機器的內存中,不需要分散式存儲

  • 與每台 worker 只需要與 server 同步本地所需要的部分 embedding 不同,MLP 是一個整體,每台 worker 都需要與 server 同步完整 MLP 的全部參數,不會只同步局部模型。

所以,在我的實踐中

  • 稀疏 ID 特徵 Embedding,就是標準的 PS 做法,用 key-value 來存儲。Key 就是 id feature,value 就是其對應的 embedding 向量;

  • MLP 部分,我用一個 KEY_FOR_ALL_MLP 在 server 中存儲 MLP 的所有參數(一個很大,但單機足以容納的向量),以完成 worker 之間對 MLP 參數的同步

實際上,對 Embedding 和 MLP 不同特性的論述,在《Deep Interest Network for Click-Through Rate Prediction》中也有所論述。阿里的 X-DeepLearning 平台

  • 用 Distributed Embedding Layer 實現了分散式的 key-value 資料庫來存儲 embedding。應該是標準的 PS 做法。

  • 用 Local Backend 在單機上訓練 MLP。如何實現各 worker(i.e., local backend)的 MLP 的同步?是否和我的做法類似,用一個 key 在 server 上存儲 MLP 的所有參數?目前尚不得而知,還需要繼續研究

加入圖片特徵後,能否繼續在 PS 上訓練?

按原論文的說法,自然是不能,所以才提出了 AMS。一開始,我以為」PS 不支持圖片」是「質」的不同,即 PS 主要針對稀疏特徵,而圖片是稠密數據。但是,讀完文章之後,發現之前的想法是錯誤的,稀疏 ID 特徵與圖片特徵在稀疏性是統一的。

  • 某個 worker node 上訓練樣本集,所涵蓋的 item id 與 item image,只是所有 item ids/images 的一部分,從這個角度來說,item id/image 都是稀疏的,為使用 PS 架構提供了可能性

  • item image 經過 pre-trained CNN model 預處理,參與 DICM 訓練時,已經是固定長度的稠密向量。Item id 也需要 embedding 成稠密向量。從這個角度來說,item id/image 又都是稠密的

正因為稀疏 ID 特徵與圖片特徵,本質上沒有什麼不同,因此 PS 無須修改,就可以用於訓練包含圖片特徵的 CTR 模型(起碼理論上行得通),就是文中所謂的 store-in-server 模式。

  • 圖片特徵存入 PS 中的 server,key 是 image index,value 是經過 VGG16 提取出來的稠密向量

  • 訓練數據存放在各 worker 上,其中圖片部分只存儲 image index

  • 訓練中,每個 worker 根據各自本地的訓練集所包含的 image index,向 server 請求各自所需的 image 的 embedding,訓練自己的 MLP

一切看上去很美好,直到我們審視 VGG16 提取出來的 image embedding 到底有多長

  • 原論文中提到,經過試驗,阿里團隊最終選擇了 FC6 的輸出,是一個 4096 長的浮點數向量。而這僅僅是一張圖片,每次迭代中,worker/server 需要通信的數據量是 mini-batch size * 單用戶歷史點擊圖片數 (i.e., 通常是幾十到上百) * 4096 個浮點數。按照原論文中 table 2 的統計,那是 5G 的通訊量。

  • 而一個 ID 特徵的 embedding 才用 12 維的向量來表示。也就是說,引入 image 後,通訊量增長了 4096/12=341 倍

(或許有心的讀者問,既然 4096 的 image embedding 會造成如此大的通訊壓力,那為什麼不選擇 vgg16 中小一些層的輸出呢?因為 vgg16 是針對 ImageNet 訓練好的,而 ImageNet 中的圖片與淘寶的商品圖片還是有不小的差距(淘寶的商品圖片應該很少會出現海象與鴨嘴獸吧),因此需要提取出來的 image embedding 足夠長,以更好地保留一些原始信息。原論文中也嘗試過提取 1000 維的向量,性能上有較大損失。)

正是因為原始圖片 embedding 太大了,給通信造成巨大壓力,才促使阿里團隊在 server 上也增加了一個「壓縮」模型,從而將 PS 升級為 AMS。

AMS 的技術細節,將在下一節詳細說明。這裡,我覺得需要強調一下,由於加入圖片而需要在 AMS,而不是 PS 上訓練,這個變化是「量」變引起的,而不是因為原來的 ID 特徵與圖片這樣的多媒體特徵在「質」上有什麼不同。比如,在這個例子中,

  • 使用 AMS 是因為 image 的原始 embedding 由 4096 個浮點數組成,太大了

  • 之所以需要 4096 個浮點數,是因為 vgg16 是針對 ImageNet 訓練的,與淘寶圖片相差較大,所以需要保留較多的原始信息

  • 如果淘寶專門訓練一個針對商品圖片的分類模型,那麼就有可能拿某個更接近 loss 層、更小的中間層的輸出作為 image embedding

  • 這樣一來,也就沒有通信壓力了,也就無需 server 上的「壓縮」模型了,傳統的 PS 也就完全可以勝任了。

所以,AMS 並不應該是接入多媒體特徵後的唯一選擇,而 AMS 也不僅僅是針對多媒體特徵才有用。應該說,AMS 應該是針對「embedding 過大、佔有過多帶寬」的解決方案之一

Advanced Model Server(AMS)架構

上一節講清楚了,AMS 是為了解決「image 的原始 embedding 過大,造成太大通信壓力」的問題而提出的。在這一節里,我們來看看 AMS 是如何解決這一問題的。

AMS 的解決方案也很簡單:

  • 為每個 server 增加一個可學習的「壓縮」模型(論文中的 sub-model,其實就是一個 4096-256-64-12 的金字塔型的 MLP

  • 當 worker 向 server 請求 image embedding 時,server 上的「壓縮」模型先將原始的 4096 維的 image embedding 壓縮成 12 維,再傳遞給 worker,從而將通訊量減少到原來的 1/340

  • 該「壓縮」模型的參數,由每個 server 根據存在本地的圖片數據學習得到,並且在一輪迭代結束時,各 server 上的「壓縮」模型需要同步

每個 server 上都有這樣一個這個可學習的「壓縮」模型,要能夠利用存放在本地的數據(這裡就是 4096 長的 image 原始 embedding)前代、回代、更新權重,並且各 server 的模型還需要同步,簡直就是 worker 上模型的翻版。將 worker 的「訓練模型」的功能複製到 server,這也就是 Advanced Model Server 相比於傳統 Parameter Server 的改進之處。

AMS 是本文最大的創新點。本來還想再費些筆墨詳細描述,最後發現不過是對原論文 4.2 節的翻譯,白白浪費篇幅罷了,請讀者移步原論文。其實,當你明白了 AMS 要解決什麼樣的問題,那麼原論文中的解決方案,也就是一層窗戶紙罷了,簡單來說,就是將 worker 上的模型前代、回代、更新、同步代碼移植到 server 端罷了。最後加上原論文中的圖 2,以做備忘。

一圖勝千言: 解讀阿里的Deep Image CTR Model

AMS 交互流程

總結

以上就是我對 Deep Image CTR Model(DICM)兩個創新點的理解。根據原論文,無論是離線實驗還是線上 AB 測試,DICM 的表現都比不考慮用戶視覺偏好的老模型要更加優異。DICM 開啟了在推薦系統中引入多媒體特徵的新篇章

小結一下 DICM 的成就與思路:

  • DICM,第一次將圖片信息引入到用戶側建模,通過用戶歷史上點擊過的圖片(user behavior images)建模用戶的視覺偏好,而且將傳統的 ID 類特徵、物料的圖像信息、用戶的視覺偏好進行充分交互,能夠發現更多的 pattern,也解決了只使用 ID 特徵而帶來的冷啟動問題。

  • 但是,引入 user behavior images 後,由於 image 原始 embedding 太大,給分散式訓練時的通信造成了巨大壓力。為此,阿里團隊通過給每個 server 增加一個可學習的「壓縮」模型,先壓縮 image embedding 再傳遞給 worker,大大降低了 worker/server 之間的通信量,使 DICM 的效率能夠滿足線上系統的要求。這種為 server 增加「模型訓練」功能的 PS,被稱為 AMS。

最後,還應該強調,引發 PS 升級到 AMS 的驅動力,是「量變」而不是「質變」。圖片之類的多媒體特徵,既不是 AMS 的唯一用武之地,也不應是 AMS 壟斷的專利。選擇哪種訓練架構,需要我們根據業務、數據的特點做出判斷,切忌迷信「銀彈」。

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 雷鋒網 的精彩文章:

回家吧 回到最初的美好 網易新聞打造「極致」聯動直播溫暖回家路
鬧元宵對成語搭上區塊鏈?!

TAG:雷鋒網 |