當前位置:
首頁 > 新聞 > TensorFlow入門必看:Google AI實習生經驗談

TensorFlow入門必看:Google AI實習生經驗談

TensorFlow入門必看:Google AI實習生經驗談

TensorFlow入門必看:Google AI實習生經驗談

新智元推薦

來源:AI前線(ID:ai-front)

作者:Jacob Buckman 譯者:王強、無明

【新智元導讀】本文作者Jacob來自Google AI Resident項目,他在2017年夏天開啟了為期一年的Google研究型實習,在此之前他雖然有很多編程經驗和機器學習經驗,但沒有使用過TensorFlow。這篇文章是Jacob為TensorFlow寫的一個實用教程,作者表示,要是在開啟TensorFlow學習之前有人告訴他這些知識就好了。希望這篇文章也能為讀者提供幫助,少走彎路。

TensorFlow入門必看:Google AI實習生經驗談

前言:「我叫 Jacob,是谷歌 AI Residency 項目的學者。2017 年夏天我進入這個項目的時候,我自己的編程經驗很豐富,對機器學習理解也很深刻,但以前我從未使用過 Tensorflow。當時我認為憑自己的能力可以很快掌握 Tensorflow,但沒想到我學習它的過程竟然如此跌宕起伏。甚至加入項目幾個月後我還偶爾會感到困惑,不知道怎樣用 Tensorflow 代碼實現自己的新想法。

這篇博文就像是我給過去自己寫的瓶中信:回顧當初,我希望在開始學習的時候有這樣一篇入門介紹。我也希望本文能夠幫助同行,為他們提供參考。」

過去的教程缺少哪些內容?

Tensorflow 發布已經有三年,如今它已成為深度學習生態系統的基石。然而對於初學者來說它並不怎麼簡單易懂,與 PyTorch 或 DyNet 這樣的運行即定義的神經網路庫相比就更明顯了。

有很多 Tensorflow 的入門教程,內容涵蓋線性回歸、MNIST 分類乃至機器翻譯。這些內容具體、實用的指南能幫助人們快速啟動並運行 Tensorflow 項目,並且可以作為類似項目的切入點。但有的開發者開發的應用並沒有很好的教程參考,還有的項目在探索全新的路線(研究中很常見),對於這些開發者來說入門 Tensorflow 是非常容易感到困惑的。

我寫這篇文章就想彌補這一缺口。本文不會研究某個具體任務,而是提出更加通用的方法,並解析 Tensorflow 的基礎抽象概念。掌握好這些概念後,用 Tensorflow 進行深度學習就會更加直觀易懂。

目標受眾

本教程適用於在編程和機器學習方面有一定經驗,並想要入門 Tensorflow 的從業者。他們可以是:想在深度學習課程的最後一個項目中使用 Tensorflow 的 CS 專業學生;剛剛被調到涉及深度學習的項目的軟體工程師;或者是一位處於困惑之中的 Google AI 新手(向 Jacob 大聲打個招呼吧)。如果你需要基礎知識入門,請參閱以下資源。這些都了解的話,我們就開始吧!

理解 Tensorflow

Tensorflow 不是一個普通的 Python 庫。

大多數 Python 庫被編寫為 Python 的自然擴展形式。當你導入一個庫時,你得到的是一組變數、函數和類,它們補充並擴展了你的代碼「工具箱」。使用這些庫時,你知道它們將產生怎樣的結果。我認為談及 Tensorflow 時應該拋棄這些認識,這些認知從根本上就不符合 Tensorflow 的理念,無法反映 TF 與其它代碼交互的方式。

Python 和 Tensorflow 之間的聯繫,可以類比 Javascript 和 HTML 之間的關係。Javascript 是一種全功能的編程語言,可以實現各種出色的效果。HTML 是用於表示某種類型的實用計算抽象(這裡指的是可由 Web 瀏覽器呈現的內容)的框架。Javascript 在互動式網頁中的作用是組裝瀏覽器看到的 HTML 對象,然後在需要時通過將其更新為新的 HTML 來與其交互。

與 HTML 類似,Tensorflow 是用於表示某種類型的計算抽象(稱為「計算圖」)的框架。當我們用 Python 操作 Tensorflow 時,我們用 Python 代碼做的第一件事是組裝計算圖。之後我們的第二個任務就是與它進行交互(使用 Tensorflow 的「會話」)。但重要的是,要記住計算圖不在變數內部,它處在全局命名空間內。莎士比亞曾經說過:「所有的 RAM 都是一個階段,所有的變數都只不過是指針。」

第一個關鍵抽象:計算圖

我們在瀏覽 Tensorflow 文檔時,有時會發現內容提到「圖形」和「節點」。如果你仔細閱讀、深入挖掘,甚至可能已經發現了這個頁面,該頁面中涵蓋的內容我將以更精確和技術化的風格詳細解釋。本節將從頂層入手,把握關鍵的直覺概念,同時略過一些技術細節。

那麼什麼是計算圖?它實質上是一個全局數據結構:計算圖是一個有向圖,捕獲有關計算方法的指令。

我們來看看如何構建一個示例。下圖中,上半部分是我們運行的代碼和它的輸出,下半部分是結果計算圖。

TensorFlow入門必看:Google AI實習生經驗談

TensorFlow入門必看:Google AI實習生經驗談

顯然,僅僅導入 Tensorflow 並不會給我們生成一個有趣的計算圖,而只有一個孤獨的,空白的全局變數。但是當我們調用一個 Tensorflow 操作時會發生什麼呢?

TensorFlow入門必看:Google AI實習生經驗談

TensorFlow入門必看:Google AI實習生經驗談

快看!我們得到了一個節點,它包含常量:2。我知道你很驚訝,驚訝的是一個名為 tf.constant 的函數。當我們列印這個變數時,我們看到它返回一個 tf.Tensor 對象,它是一個指向我們剛創建的節點的指針。為了強調這一點,這裡是另一個例子:

TensorFlow入門必看:Google AI實習生經驗談

TensorFlow入門必看:Google AI實習生經驗談

每次我們調用 tf.constant 的時候,我們都會在圖中創建一個新節點。即使節點在功能上與現有節點完全相同,即使我們將節點重新分配給同一個變數,甚至我們根本沒有將其分配給變數,結果都一樣。

相反,如果創建一個新變數並將其設置為與現有節點相等,則只需將該指針複製到該節點,並且不會向該圖添加任何內容:

TensorFlow入門必看:Google AI實習生經驗談

好的,我們更進一步。

TensorFlow入門必看:Google AI實習生經驗談

現在我們來看——這才是我們要的真正的計算圖表!請注意,+ 操作在 Tensorflow 中過載,所以同時添加兩個張量會在圖中增加一個節點,儘管它看起來不像是 Tensorflow 操作。

好的,所以 two_node 指向包含 2 的節點,three_node 指向包含 3 的節點,而 sum_node 指向包含... + 的節點?什麼情況?它不是應該包含 5 嗎?

事實證明,沒有。計算圖只包含計算步驟,不包含結果。至少...... 還沒有!

第二個關鍵抽象:會話

如果錯誤地理解 TensorFlow 抽象也有個瘋狂三月競賽(美國大學籃球繁忙冠軍賽季),那麼「會話」將成為每年排名第一的種子選手。能獲此尷尬的榮譽,是因為會話的命名反直覺,應用卻如此廣泛——幾乎每個 Tensorflow 程序都至少會調用一次 tf.Session 。

會話的作用是處理內存分配和優化,使我們能夠實際執行由圖形指定的計算。可以將計算圖想像為我們想要執行的計算的「模板」:它列出了所有的步驟。為了使用這個圖表,我們還需要發起一個會話,它使我們能夠實際地完成任務。例如,遍歷模板的所有節點來分配一組用於存儲計算輸出的存儲器。為了使用 Tensorflow 進行各種計算,我們既需要圖也需要會話。

會話包含一個指向全局圖的指針,該指針通過指向所有節點的指針不斷更新。這意味著在創建節點之前還是之後創建會話都無所謂。

創建會話對象後,可以使用 sess.run(node) 返回節點的值,並且 Tensorflow 將執行確定該值所需的所有計算。

TensorFlow入門必看:Google AI實習生經驗談

精彩!我們還可以傳遞一個列表,sess.run([node1,node2,...]),並讓它返回多個輸出:

TensorFlow入門必看:Google AI實習生經驗談

一般來說,sess.run 調用往往是最大的 TensorFlow 瓶頸之一,所以調用它的次數越少越好。可以的話在一個 sess.run 調用中返回多個項目,而不是進行多個調用。

佔位符和 feed_dict

我們迄今為止所做的計算一直很乏味:沒有機會獲得輸入,所以它們總是輸出相同的東西。一個實用的應用可能涉及構建這樣一個計算圖:它接受輸入,以某種(一致)方式處理它,返回一個輸出

最直接的方法是使用佔位符。佔位符是一種用於接受外部輸入的節點。

TensorFlow入門必看:Google AI實習生經驗談

TensorFlow入門必看:Google AI實習生經驗談

……這是個糟糕的例子,因為它引發了一個異常。佔位符預計會被賦予一個值,但我們沒有提供,因此 Tensorflow 崩潰了。

為了提供一個值,我們使用 sess.run 的 feed_dict 屬性。

TensorFlow入門必看:Google AI實習生經驗談

好多了。注意傳遞給 feed_dict 的數值格式。這些鍵應該是與圖中佔位符節點相對應的變數(如前所述,它實際上意味著指向圖中佔位符節點的指針)。相應的值是要分配給每個佔位符的數據元素——通常是標量或 Numpy 數組。第三個關鍵抽象:計算路徑下面是另一個使用佔位符的例子:

TensorFlow入門必看:Google AI實習生經驗談

TensorFlow入門必看:Google AI實習生經驗談

為什麼第二次調用 sess.run 會失敗?我們並沒有在檢查 input_placeholder,為什麼會引發與 input_placeholder 相關的錯誤?答案在於最終的關鍵 Tensorflow 抽象:計算路徑。還好這個抽象非常直觀。

當我們在依賴於圖中其他節點的節點上調用 sess.run 時,我們也需要計算這些節點的值。如果這些節點有依賴關係,那麼我們需要計算這些值(依此類推......),直到達到計算圖的「頂端」,也就是所有的節點都沒有前置節點的情況。

考察 sum_node 的計算路徑:

TensorFlow入門必看:Google AI實習生經驗談

TensorFlow入門必看:Google AI實習生經驗談

所有三個節點都需要評估以計算 sum_node 的值。最重要的是,這裡面包含了我們未填充的佔位符,並解釋了例外情況!

相反,考察 three_node 的計算路徑:

TensorFlow入門必看:Google AI實習生經驗談

根據圖的結構,我們不需要計算所有的節點也可以評估我們想要的節點!因為我們不需要評估 placeholder_node 來評估 three_node,所以運行 sess.run(three_node) 不會引發異常。

Tensorflow僅通過必需的節點自動路由計算這一事實是它的巨大優勢。如果計算圖非常大並且有許多不必要的節點,它就能節約大量運行時間。它允許我們構建大型的「多用途」圖形,這些圖形使用單個共享的核心節點集合根據採取的計算路徑來做不同的任務。對於幾乎所有應用程序而言,根據所採用的計算路徑考慮 sess.run 的調用方法是很重要的。

變數和副作用

到目前為止,我們已經看到兩種類型的「無祖先」節點:tf.constant(每次運行都一樣)和 tf.placeholder(每次運行都不一樣)。還有第三種節點:通常情況下具有相同的值,但也可以更新成新值。這個時候就要用到變數

了解變數對於使用 Tensorflow 進行深度學習來說至關重要,因為模型的參數就是變數。在訓練期間,你希望通過梯度下降在每個步驟更新參數,但在計算過程中,你希望保持參數不變,並將大量不同的測試輸入集傳入到模型中。模型所有的可訓練參數很有可能都是變數。

要創建變數,請使用 tf.get_variable。tf.get_variable 的前兩個參數是必需的,其餘是可選的。它們是 tf.get_variable(name,shape)。name 是一個唯一標識這個變數對象的字元串。它在全局圖中必須是唯一的,所以要確保不會出現重複的名稱。shape 是一個與張量形狀相對應的整數數組,它的語法很直觀——每個維度對應一個整數,並按照排列。例如,一個 3×8 的矩陣可能具有形狀 [3,8]。要創建標量,請使用空列表作為形狀:。

TensorFlow入門必看:Google AI實習生經驗談

TensorFlow入門必看:Google AI實習生經驗談

發現另一個異常。一個變數節點在首次創建時,它的值基本上就是「」,任何嘗試對它進行計算的操作都會拋出這個異常。我們只能先給一個變數賦值後才能用它做計算。有兩種主要方法可以用於給變數賦值:初始化器和 tf.assign。我們先看看 tf.assign:

TensorFlow入門必看:Google AI實習生經驗談

TensorFlow入門必看:Google AI實習生經驗談

與我們迄今為止看到的節點相比,tf.assign(target,value) 有一些獨特的屬性:

  • 標識操作。tf.assign(target,value) 不做任何計算,它總是與 value 相等。

  • 副作用。當計算「流經」assign_node 時,就會給圖中的其他節點帶來副作用。在這種情況下,副作用就是用保存在 zero_node 中的值替換 count_variable 的值。

  • 非依賴邊。即使 count_variable 節點和 assign_node 在圖中是相連的,兩者都不依賴於其他節點。這意味著在計算任一節點時,計算不會通過該邊迴流。不過,assign_node 依賴 zero_node,它需要知道要分配什麼。

「副作用」節點充斥在大部分 Tensorflow 深度學習工作流中,因此,請確保你對它們了解得一清二楚。當我們調用 sess.run(assign_node) 時,計算路徑將經過 assign_node 和 zero_node。

TensorFlow入門必看:Google AI實習生經驗談

當計算流經圖中的任何節點時,它還會讓該節點控制的副作用(綠色所示)起效。由於 tf.assign 的特殊副作用,與 count_variable(之前為「」)關聯的內存現在被永久設置為 0。這意味著,當我們下一次調用 sess.run(count_variable) 時,不會拋出任何異常。相反,我們將得到 0。

接下來,讓我們來看看初始化器

TensorFlow入門必看:Google AI實習生經驗談

TensorFlow入門必看:Google AI實習生經驗談

這裡都發生了什麼?為什麼初始化器不起作用?

問題在於會話和圖之間的分隔。我們已經將 get_variable 的 initializer 屬性指向 const_init_node,但它只是在圖中的節點之間添加了一個新的連接。我們還沒有做任何與導致異常有關的事情:與變數節點(保存在會話中,而不是圖中)相關聯的內存仍然為「」。我們需要通過會話讓 const_init_node 更新變數。

TensorFlow入門必看:Google AI實習生經驗談

TensorFlow入門必看:Google AI實習生經驗談

為此,我們添加了另一個特殊節點:init = tf.global_variables_initializer。與 tf.assign 類似,這是一個帶有副作用的節點。與 tf.assign 不一樣的是,我們實際上並不需要指定它的輸入!tf.global_variables_initializer 將在其創建時查看全局圖,自動將依賴關係添加到圖中的每個 tf.initializer 上。當我們調用 sess.run(init) 時,它會告訴每個初始化器完成它們的任務,初始化變數,這樣在調用 sess.run(count_variable) 時就不會出錯。

變數共享

你可能會碰到帶有變數共享的 Tensorflow 代碼,代碼有它們的作用域,並設置「reuse=True」。我強烈建議你不要在代碼中使用變數共享。如果你想在多個地方使用單個變數,只需要使用指向該變數節點的指針,並在需要時使用它。換句話說,對於打算保存在內存中的每個參數,應該只調用一次 tf.get_variable。

優化器

最後:進行真正的深度學習!如果你還在狀態,那麼其餘的概念對於你來說應該是非常簡單的。

在深度學習中,典型的「內循環」訓練如下:

  • 獲取輸入和 true_output

  • 根據輸入和參數計算出一個「猜測」

  • 根據猜測和 true_output 之間的差異計算出一個「損失」

  • 根據損失的梯度更新參數

讓我們把所有東西放在一個腳本里,解決一個簡單的線性回歸問題:

TensorFlow入門必看:Google AI實習生經驗談

TensorFlow入門必看:Google AI實習生經驗談

TensorFlow入門必看:Google AI實習生經驗談

正如你所看到的,損失基本上沒有變化,而且我們對真實參數有了很好的估計。這部分代碼只有一兩行對你來說是新的:

TensorFlow入門必看:Google AI實習生經驗談

既然你對 Tensorflow 的基本概念已經有了很好的理解,這段代碼應該很容易解釋!第一行,optimizer = tf.train.GradientDescentOptimizer(1e-3) 不會向圖中添加節點。它只是創建了一個 Python 對象,包含了一些有用的函數。第二行 train_op = optimizer.minimize(loss),將一個節點添加到圖中,並將一個指針賦給 train_op。train_op 節點沒有輸出,但有一個非常複雜的副作用:

train_op 回溯其輸入的計算路徑,尋找變數節點。對於找到的每個變數節點,它計算與損失相關的變數梯度。然後,它為該變數計算新值:當前值減去梯度乘以學習率。最後,它執行一個賦值操作來更新變數的值。

基本上,當我們調用 sess.run(train_op) 時,它為我們對所有的變數做了一個梯度下降的操作。當然,我們還需要使用 feed_dict 來填充輸入和輸出佔位符,並且我們還希望列印這些損失,因為這樣方便調試。

用 tf.Print 進行調試

當你開始使用 Tensorflow 做更複雜的事情時,你需要進行調。一般來說,檢查計算圖中發生了什麼是很困難的。你不能使用常規的 Python 列印語句,因為你永遠無法訪問到要列印的值——它們被鎖定在 sess.run 調用中。舉個例子,假設你想檢查一個計算的中間值,在調用 sess.run 之前,中間值還不存在。但是,當 sess.run 調用返回時,中間值不見了!

我們來看一個簡單的例子。

TensorFlow入門必看:Google AI實習生經驗談

我們看到了結果是 5。但是,如果我們想檢查中間值 two_node 和 three_node,該怎麼辦?檢查中間值的一種方法是向 sess.run 添加一個返回參數,該參數指向要檢查的每個中間節點,然後在返回後列印它。

TensorFlow入門必看:Google AI實習生經驗談

這樣做通常沒有問題,但當代碼變得越來越複雜時,這可能有點尷尬。更方便的方法是使用 tf.Print 語句。令人困惑的是,tf.Print 實際上是 Tensorflow 的一種節點,它有輸出和副作用!它有兩個必需的參數:一個要複製的節點和一個要列印的內容列表。「要複製的節點」可以是圖中的任何節點,tf.Print 是與「要複製的節點」相關的標識操作,也就是說,它將輸出其輸入的副本。不過,它有個副作用,就是會列印「列印清單」里所有的值。

TensorFlow入門必看:Google AI實習生經驗談

TensorFlow入門必看:Google AI實習生經驗談

有關 tf.Print 的一個重要卻有些微妙的點:列印其實只是它的一個副作用。與所有其他副作用一樣,只有在計算流經 tf.Print 節點時才會進行列印。如果 tf.Print 節點不在計算路徑中,則不會列印任何內容。即使 tf.Print 節點正在複製的原始節點位於計算路徑上,但 tf.Print 節點本身可能不是。這個問題要注意!當這種情況發生時,它會讓你感到非常沮喪,你需要費力地找出問題所在。一般來說,最好在創建要複製的節點後立即創建 tf.Print 節點。

TensorFlow入門必看:Google AI實習生經驗談

TensorFlow入門必看:Google AI實習生經驗談

這裡(https://wookayin.github.io/tensorflow-talk-debugging/#1)有一個很好的資源,提供了更多實用的調試建議。

結 論

希望這篇文章能夠幫助你更好地理解 Tensorflow,了解它的工作原理以及如何使用它。畢竟,這裡介紹的概念對所有 Tensorflow 程序來說都很重要,但這些還都只是表面上的東西。在你的 Tensorflow 探險之旅中,你可能會遇到各種你想要使用的其他有趣的東西:條件、迭代、分散式 Tensorflow、變數作用域、保存和載入模型、多圖、多會話和多核數據載入器隊列等。

原文鏈接:

https://jacobbuckman.com/post/tensorflow-the-confusing-parts-1/#understanding-tensorflow

【加入社群】

新智元 AI 技術 + 產業社群招募中,歡迎對 AI 技術 + 產業落地感興趣的同學,加小助手微信號: aiera2015_3入群;通過審核後我們將邀請進群,加入社群後務必修改群備註(姓名 - 公司 - 職位;專業群審核較嚴,敬請諒解)。

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 新智元 的精彩文章:

傳李飛飛下半年將從谷歌離職,谷歌官方回應
「獨家」寒武紀B輪估值25億美元,領跑AI晶元賽道

TAG:新智元 |