當前位置:
首頁 > 新聞 > 【李沐】十分鐘從 PyTorch 轉 MXNet

【李沐】十分鐘從 PyTorch 轉 MXNet

【李沐】十分鐘從 PyTorch 轉 MXNet

MXNet 作者 / 亞馬遜主任科學家 李沐

作者:MXNet 作者 / 亞馬遜主任科學家 李沐

【新智元導讀】PyTorch 是一個純命令式的深度學習框架。它因為提供簡單易懂的編程介面而廣受歡迎,而且正在快速的流行開來。MXNet通過ndarray和 gluon模塊提供了非常類似 PyTorch 的編程介面。本文將簡單對比如何用這兩個框架來實現同樣的演算法。

【李沐】十分鐘從 PyTorch 轉 MXNet

PyTorch 是一個純命令式的深度學習框架。它因為提供簡單易懂的編程介面而廣受歡迎,而且正在快速的流行開來。例如 Caffe2 最近就併入了 PyTorch。

可能大家不是特別知道的是,MXNet 通過 ndarray 和 gluon 模塊提供了非常類似 PyTorch 的編程介面。本文將簡單對比如何用這兩個框架來實現同樣的演算法。

【李沐】十分鐘從 PyTorch 轉 MXNet


安裝

PyTorch 默認使用 conda 來進行安裝,例如

【李沐】十分鐘從 PyTorch 轉 MXNet

而 MXNet 更常用的是使用 pip。我們這裡使用了 --pre來安裝 nightly 版本

【李沐】十分鐘從 PyTorch 轉 MXNet

多維矩陣

對於多維矩陣,PyTorch 沿用了 Torch 的風格稱之為 tensor,MXNet 則追隨了 NumPy 的稱呼 ndarray。下面我們創建一個兩維矩陣,其中每個元素初始化成 1。然後每個元素加 1 後列印。

  • PyTorch:

【李沐】十分鐘從 PyTorch 轉 MXNet

【李沐】十分鐘從 PyTorch 轉 MXNet

  • MXNet:

【李沐】十分鐘從 PyTorch 轉 MXNet

【李沐】十分鐘從 PyTorch 轉 MXNet

忽略包名的不一樣的話,這裡主要的區別是 MXNet 的形狀傳入參數跟 NumPy 一樣需要用括弧括起來。


模型訓練

下面我們看一個稍微複雜點的例子。這裡我們使用一個多層感知機(MLP)來在 MINST 這個數據集上訓練一個模型。我們將其分成 4 小塊來方便對比。

讀取數據

這裡我們下載 MNIST 數據集並載入到內存,這樣我們之後可以一個一個讀取批量。

  • PyTorch:

【李沐】十分鐘從 PyTorch 轉 MXNet

  • MXNet:

【李沐】十分鐘從 PyTorch 轉 MXNet

這裡的主要區別是 MXNet 使用 transform_first 來表明數據變化是作用在讀到的批量的第一個元素,既 MNIST 圖片,而不是第二個標號元素。

定義模型

下面我們定義一個只有一個單隱層的 MLP 。

  • PyTorch:

【李沐】十分鐘從 PyTorch 轉 MXNet

  • MXNet:

【李沐】十分鐘從 PyTorch 轉 MXNet

我們使用了 Sequential 容器來把層串起來構造神經網路。這裡 MXNet 跟 PyTorch 的主要區別是:

  • 不需要指定輸入大小,這個系統會在後面自動推理得到

  • 全連接和卷積層可以指定激活函數

  • 需要創建一個 name_scope的域來給每一層附上一個獨一無二的名字,這個在之後讀寫模型時需要

  • 我們需要顯示調用模型初始化函數。

大家知道 Sequential 下只能神經網路只能逐一執行每個層。PyTorch 可以繼承 nn.Module 來自定義 forward 如何執行。同樣,MXNet 可以繼承 nn.Block 來達到類似的效果。

損失函數和優化演算法

  • PyTorch:

【李沐】十分鐘從 PyTorch 轉 MXNet

  • MXNet:

【李沐】十分鐘從 PyTorch 轉 MXNet

這裡我們使用交叉熵函數和最簡單隨機梯度下降並使用固定學習率 0.1

訓練

最後我們實現訓練演算法,並附上了輸出結果。注意到每次我們會使用不同的權重和數據讀取順序,所以每次結果可能不一樣。

  • PyTorch

【李沐】十分鐘從 PyTorch 轉 MXNet

【李沐】十分鐘從 PyTorch 轉 MXNet

  • MXNet

【李沐】十分鐘從 PyTorch 轉 MXNet

【李沐】十分鐘從 PyTorch 轉 MXNet

MXNet 跟 PyTorch 的不同主要在下面這幾點:

  • 不需要將輸入放進 Variable, 但需要將計算放在 mx.autograd.record()里使得後面可以對其求導

  • 不需要每次梯度清 0,因為新梯度是寫進去,而不是累加step的時候 MXNet 需要給定批量大小

  • 需要調用asscalar() 來將多維數組變成標量。

  • 這個樣例里 MXNet 比 PyTorch 快兩倍。當然大家對待這樣的比較要謹慎。

下一步

  • 更詳細的 MXNet 的教程:http://zh.gluon.ai/

  • 歡迎給我們留言哪些 PyTorch 的方便之處你希望 MXNet 應該也可以有

【加入社群】

新智元 AI 技術 + 產業社群招募中,歡迎對 AI 技術 + 產業落地感興趣的同學,加小助手微信號: aiera2015_1 入群;通過審核後我們將邀請進群,加入社群後務必修改群備註(姓名 - 公司 - 職位;專業群審核較嚴,敬請諒解)。

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 新智元 的精彩文章:

【LeCun發聲】牛津大學專家:Facebook不算數據泄露,你們都錯了
阿里提出新圖像描述框架,解決梯度消失難題

TAG:新智元 |