JAXnet:一行代碼定義計算圖,兼容三大主流框架,可GPU加速
機器之心整理
參與:思源、一鳴
一行代碼定義計算圖,So Easy,媽媽再也不用擔心我的機器學習。
項目地址:https://github.com/JuliusKunze/jaxnet
JAXnet 是一個基於 JAX 的深度學習庫,它的 API 提供了便利的模型搭建體驗。相比 TensorFlow 2.0 或 PyTorch 等主流框架,JAXnet 擁有獨特的優勢。舉個栗子,不論是 Keras 還是 PyTorch,它們建模就像搭積木一樣。
然而,還有一種比搭積木更簡單的方法,這就是 JAXnet 的模塊化:
創建一個全連接網路可以直接用預定義的模塊,可以說 JAXnet 定義計算圖,只需一行代碼就可以了。寫一個神經網路,原來 So easy。
總體來說,JAXnet 主要關注的是模塊化、可擴展性和易用性等幾個方面:
採用了不可變權重,而不是全局計算圖,從而獲得更強的穩健性;
用於構建神經網路、訓練循環、預處理、後處理等過程的 NumPy 代碼經過 GPU 編譯;
任意模塊或整個網路的正則化、重參數化都只需要一行代碼;
取消了全局隨機狀態,採用了更便捷的 Key 控制。
可擴展性
你可以使用 @parametrized 定義自己的模塊,並復用其它的模塊:
所有的模塊都是用這樣的方法組合在一起的。jax.numpy (https://github.com/google/jax#whats-supported) 是 numpy 的鏡像。只要你知道怎麼使用 numpy,那麼你就可以知道 JAXnet 大部分的用法了。
以下是 TensorFlow2/Keras 的代碼,JAXnet 相比之下更為簡潔:
需要注意的是,Lambda 函數在 JAXnet 中不是必要的。而 relu 和 logsoftmax 函數都是 Python 寫的函數。
非可變權重
和 TensorFlow 或者 Keras 不同,JAXnet 沒有全局計算圖。net 和 loss 這樣的模塊不保存可變權重。權重則是保存在分開的不可變類中。這些權重由 init_parameters 函數初始化,用於提供隨機的鍵和樣本輸入:
目標函數不會在線變更權重,而是不斷更新權重的下一個版本。它們會以新的優化狀態返回,並由 get_parameters 取回。
當需要對網路進行評價時:
JAXnet 的正則化也十分簡單:
其他特性
除了簡潔的代碼,JAXnet 還支持在 GPU 上進行計算。而且還可以用 jit 進行編譯,擺脫 Python 運行緩慢的問題。同時,JAXnet 是單步調試的,和 Python 代碼一樣。
安裝也十分簡單,使用 pip 安裝即可。如果需要使用 GPU,則需要先安裝 jaxlib。
其他具體的 API 可參考:https://github.com/JuliusKunze/jaxnet/blob/master/API.md
本文為機器之心整理,轉載請聯繫本公眾號獲得授權。
------------------------------------------------


※CV困境如何破:訓練樣本有限、2D視覺平面 VS 3D真實場景……
※Jupyter Notebook界面也可以如此炫酷?有人把Notebook玩出新花樣
TAG:機器之心 |