當前位置:
首頁 > 新聞 > 要替代 TensorFlow?谷歌開源機器學習庫 JAX

要替代 TensorFlow?谷歌開源機器學習庫 JAX

新智元推薦

來源:AI前線(ID: ai-front)

原文: Reddit 策劃編輯: Natalie

整理: Vincent 編輯: Debra

【新智元導讀】TensorFlow有了替代品,竟然還是谷歌自己做出來的?這其實是TensorFlow的一個簡化庫,名為JAX,可以支持部分TensorFlow的功能,但是比TensorFlow更加簡潔易用。

什麼?TensorFlow 有了替代品?什麼?竟然還是谷歌自己做出來的?先別慌,從各種意義上來說,這個所謂的 「替代品」 其實是 TensorFlow 的一個簡化庫,名為JAX,結合 Autograd 和 XLA,可以支持部分 TensorFlow 的功能,但是比 TensorFlow 更加簡潔易用。

雖然還不至於替代 TensorFlow,但已經有 Reddit 網友對 JAX 寄予厚望,並表示「早就期待能有一個可以直接調用 Numpy API 介面的庫了!」,「希望它可以取代 TensorFlow!」。

JAX 結合了 Autograd 和 XLA,是專為高性能機器學習研究打造的產品。

有了新版本的Autograd,JAX 能夠自動對 Python 和 NumPy 的自帶函數求導,支持循環、分支、遞歸、閉包函數求導,而且可以求三階導數。它支持自動模式反向求導(也就是反向傳播)和正向求導,且二者可以任意組合成任何順序。

JAX 的創新之處在於,它基於XLA在 GPU 和 TPU 上編譯和運行 NumPy 程序。默認情況下,編譯是在底層進行的,庫調用能夠及時編譯和執行。但是 JAX 還允許使用單一函數 API jit將自己的 Python 函數及時編譯成經過 XLA 優化的內核。編譯和自動求導可以任意組合,因此可以在不脫離 Python 環境的情況下實現複雜演算法並獲得最優性能。

JAX 最初由 Matt Johnson、Roy Frostig、Dougal Maclaurin 和 Chris Leary 發起,他們均任職於谷歌大腦團隊。在 GitHub 的說明文檔中,作者明確表示:JAX 目前還只是一個研究項目,不是谷歌的官方產品,因此可能會有一些 bug。從作者的 GitHub 簡介來看,這應該是谷歌大腦正在嘗試的新項目,在同一個 GitHub 目錄下的開源項目還包括 8 月份在業內引起熱議的強化學習框架 Dopamine。

以下是 JAX 的簡單使用示例。

GitHub 項目傳送門:https://github.com/google/JAX

有關具體的安裝和簡單的入門指導大家可以在 GitHub 中自行查看,在此不做過多贅述。

JAX 庫的實現原理

機器學習中的編程是關於函數的表達和轉換。轉換包括自動微分、加速器編譯和自動批處理。像 Python 這樣的高級語言非常適合表達函數,但是通常使用者只能應用它們。我們無法訪問它們的內部結構,因此無法執行轉換。

JAX 可以用於專門化高級Python+NumPy函數,並將其轉換為可轉換的表示形式,然後再提升為 Python 函數。

JAX 通過跟蹤專門處理 Python 函數。跟蹤一個函數意味著:監視應用於其輸入,以產生其輸出的所有基本操作,並在有向無環圖 (DAG) 中記錄這些操作及其之間的數據流。為了執行跟蹤,JAX 包裝了基本的操作,就像基本的數字內核一樣,這樣一來,當調用它們時,它們就會將自己添加到執行的操作列表以及輸入和輸出中。為了跟蹤這些原語之間的數據流,跟蹤的值被包裝在 Tracer 類的實例中。

當 Python 函數被提供給 grad 或 jit 時,它被包裝起來以便跟蹤並返回。當調用包裝的函數時,我們將提供的具體參數抽象到 AbstractValue 類的實例中,將它們框起來用於跟蹤跟蹤器類的實例,並對它們調用函數。

抽象參數表示一組可能的值,而不是特定的值:例如,jit 將 ndarray 參數抽象為抽象值,這些值表示具有相同形狀和數據類型的所有 ndarray。相反,grad 抽象 ndarray 參數來表示底層值的無窮小鄰域。通過在這些抽象值上跟蹤 Python 函數,我們確保它足夠專門化,以便轉換是可處理的,並且它仍然足夠通用,以便轉換後的結果是有用的,並且可能是可重用的。然後將這些轉換後的函數提升回 Python 可調用函數,這樣就可以根據需要跟蹤並再次轉換它們。

JAX 跟蹤的基本函數大多與 XLA HLO 1:1 對應,並在 lax.py 中定義。這種 1:1 的對應關係使得到 XLA 的大多數轉換基本上都很簡單,並且確保我們只有一小組原語來覆蓋其他轉換,比如自動微分。 jax.numpy 層是用純 Python 編寫的,它只是用 LAX 函數 (以及我們已經編寫的其他 numpy 函數) 表示 numpy 函數。這使得 jax.numpy 易於延展。

當你使用 jax.numpy 時,底層 LAX 原語是在後台進行 jit 編譯的,允許你在加速器上執行每個原語操作的同時編寫不受限制的 Python+ numpy 代碼。

但是 JAX 可以做更多的事情:你可以在越來越大的函數上使用jit來進行端到端編譯和優化,而不僅僅是編譯和調度到一組固定的單個原語。例如,可以編譯整個網路,或者編譯整個梯度計算和優化器更新步驟,而不僅僅是編譯和調度卷積運算。

折衷之處是,jit 函數必須滿足一些額外的專門化需求:因為我們希望編譯專門針對形狀和數據類型的跟蹤,但不是專門針對具體值的跟蹤,所以 jit 裝飾器下的 Python 代碼必須適用於抽象值。如果我們嘗試在一個抽象的 x 上求 x >0 的值,結果是一個抽象的值,表示集合 ,所以 Python 分支就像 if x > 0 會引起報錯。

有關使用 jit 的更多要求,請參見:https://github.com/google/jax#whats-supported

好消息是,jit 是可選的:JAX 庫在後台對單個操作和函數使用 jit,允許編寫不受限制的 Python+Numpy,同時仍然使用硬體加速器。但是,當你希望最大化性能時,通常可以在自己的代碼中使用 jit 編譯和端到端優化更大的函數。

後續計劃

目前項目小組還將對以下幾項做更多嘗試和更新:

完善說明文檔

支持 Cloud TPU

支持多 GPU 和多 TPU

支持完整的 NumPy 功能和部分 SciPy 功能

全面支持 vmap

加速

降低 XLA 函數調度開銷

線性代數常式(CPU 上的 MKL 和 GPU 上的 MAGMA)

高效自動微分原語cond和while

有關 JAX 庫的介紹大致如此。

再次附上 GitHub 鏈接:https://github.com/google/jax

相關資源:

JAX 論文鏈接:https://www.sysml.cc/doc/146.pdf

【加入社群】

新智元 AI 技術 + 產業社群招募中,歡迎對 AI 技術 + 產業落地感興趣的同學,加小助手微信號:aiera2015_2入群;通過審核後我們將邀請進群,加入社群後務必修改群備註(姓名 - 公司 - 職位;專業群審核較嚴,敬請諒解)。

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 新智元 的精彩文章:

禁售iPhone再升級!高通尋求美國禁止進口蘋果,5G大戰英特爾躺槍
Tensorflow 2.0的這些新設計,你適應好了嗎?

TAG:新智元 |