當前位置:
首頁 > 知識 > JAXnet:一行代碼定義計算圖,兼容三大主流框架,可GPU加速

JAXnet:一行代碼定義計算圖,兼容三大主流框架,可GPU加速

機器之心整理

參與:思源、一鳴

一行代碼定義計算圖,So Easy,媽媽再也不用擔心我的機器學習。

項目地址:https://github.com/JuliusKunze/jaxnet

JAXnet 是一個基於 JAX 的深度學習庫,它的 API 提供了便利的模型搭建體驗。相比 TensorFlow 2.0 或 PyTorch 等主流框架,JAXnet 擁有獨特的優勢。舉個栗子,不論是 Keras 還是 PyTorch,它們建模就像搭積木一樣。

然而,還有一種比搭積木更簡單的方法,這就是 JAXnet 的模塊化:

創建一個全連接網路可以直接用預定義的模塊,可以說 JAXnet 定義計算圖,只需一行代碼就可以了。寫一個神經網路,原來 So easy。

總體來說,JAXnet 主要關注的是模塊化、可擴展性和易用性等幾個方面:

採用了不可變權重,而不是全局計算圖,從而獲得更強的穩健性;

用於構建神經網路、訓練循環、預處理、後處理等過程的 NumPy 代碼經過 GPU 編譯;

任意模塊或整個網路的正則化、重參數化都只需要一行代碼;

取消了全局隨機狀態,採用了更便捷的 Key 控制。

可擴展性

你可以使用 @parametrized 定義自己的模塊,並復用其它的模塊:

所有的模塊都是用這樣的方法組合在一起的。jax.numpy (https://github.com/google/jax#whats-supported) 是 numpy 的鏡像。只要你知道怎麼使用 numpy,那麼你就可以知道 JAXnet 大部分的用法了。

以下是 TensorFlow2/Keras 的代碼,JAXnet 相比之下更為簡潔:

需要注意的是,Lambda 函數在 JAXnet 中不是必要的。而 relu 和 logsoftmax 函數都是 Python 寫的函數。

非可變權重

和 TensorFlow 或者 Keras 不同,JAXnet 沒有全局計算圖。net 和 loss 這樣的模塊不保存可變權重。權重則是保存在分開的不可變類中。這些權重由 init_parameters 函數初始化,用於提供隨機的鍵和樣本輸入:

目標函數不會在線變更權重,而是不斷更新權重的下一個版本。它們會以新的優化狀態返回,並由 get_parameters 取回。

當需要對網路進行評價時:

JAXnet 的正則化也十分簡單:

其他特性

除了簡潔的代碼,JAXnet 還支持在 GPU 上進行計算。而且還可以用 jit 進行編譯,擺脫 Python 運行緩慢的問題。同時,JAXnet 是單步調試的,和 Python 代碼一樣。

安裝也十分簡單,使用 pip 安裝即可。如果需要使用 GPU,則需要先安裝 jaxlib。

其他具體的 API 可參考:https://github.com/JuliusKunze/jaxnet/blob/master/API.md

本文為機器之心整理,轉載請聯繫本公眾號獲得授權。

------------------------------------------------

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 機器之心 的精彩文章:

CV困境如何破:訓練樣本有限、2D視覺平面 VS 3D真實場景……
Jupyter Notebook界面也可以如此炫酷?有人把Notebook玩出新花樣

TAG:機器之心 |