TensorFlow 中層 API TFRecordDataset
本文作者YJango,本文首發於知乎專欄【超智能體】,AI 研習社獲其授權轉載。
半年沒有更新了, 由於抑鬱,我把 gitbook 上的《超智能體》電子書刪掉了,所有以 gitbook 作為資源所顯示的圖片以及所有引向 gitbook 的鏈接全部失效。CSDN 上是不能看了。
遇到圖片和鏈接失效的朋友到我知乎的專欄里找相應的文章,如果沒有了只能說聲抱歉。
很欣慰還有人喜歡我寫的文章,以及對超智能體專欄(http://t.cn/RJvw4iB)的支持。
有很多想說的,卻又不知道說什麼。只是,謝謝。以往 YJango 的文章都是以教學為主,並不覆蓋高效的實際應用。
這篇文章是專門寫給那些支持過我的讀者們,感謝你們。
完整代碼可以從下面的 github(http://t.cn/R8yKP99) 上找到。
目錄
前言
優勢
Dataset API
TFRecord
概念
數據說明
數據存儲
常用存儲
TFRecord 存儲
實現
生成數據
寫入 TFRecord file
存儲類型
如何存儲張量 feature
使用 Dataset
創建 dataset
操作 dataset
解析函數
迭代樣本
Shuffle
Batch
Batch padding
Epoch
幫助函數
前言
半年沒有更新了, 由於抑鬱,我把 gitbook 上的《超智能體》電子書刪掉了,所有以 gitbook 作為資源所顯示的圖片以及所有引向 gitbook 的鏈接全部失效。CSDN 上是不能看了。
遇到圖片和鏈接失效的朋友到我知乎的專欄里找相應的文章,如果沒有了只能說聲抱歉。
很欣慰還有人喜歡我寫的文章,以及對超智能體專欄(http://t.cn/RJvw4iB)的支持。
有很多想說的,卻又不知道說什麼。只是,謝謝。以往 YJango 的文章都是以教學為主,並不覆蓋高效的實際應用。
這篇文章是專門寫給那些支持過我的讀者們,感謝你們。
完整代碼可以從下面的 github(http://t.cn/R8yKP99) 上找到。
優勢
一、為什麼用 Dataset API?
1. 簡潔性:
常規方式:用 python 代碼來進行 batch,shuffle,padding 等 numpy 類型的數據處理,再用 placeholder + feed_dict 來將其導入到 graph 中變成 tensor 類型。因此在網路的訓練過程中,不得不在 tensorflow 的代碼中穿插 python 代碼來實現控制。
Dataset API:將數據直接放在 graph 中進行處理,整體對數據集進行上述數據操作,使代碼更加簡潔。
2. 對接性:TensorFlow 中也加入了高級 API (Estimator、Experiment,Dataset)幫助建立網路,和 Keras 等庫不一樣的是:這些 API 並不注重網路結構的搭建,而是將不同類型的操作分開,幫助周邊操作。可以在保證網路結構控制權的基礎上,節省工作量。若使用 Dataset API 導入數據,後續還可選擇與 Estimator 對接。
二、為什麼用 TFRecord?
在數據集較小時,我們會把數據全部載入到內存里方便快速導入,但當數據量超過內存大小時,就只能放在硬碟上來一點點讀取,這時就不得不考慮數據的移動、讀取、處理等速度。使用 TFRecord 就是為了提速和節約空間的。
概念
在進行代碼功能講解之前,先明確一下想要存儲和讀取的數據是什麼樣子(老手跳過)。
一、數據說明:
假設要學習判斷個人收入的模型。我們會事先搜集反映個人信息的輸入,用這些信息作為判斷個人收入的依據。同時也會把擁有的人的實際收入也搜集。這樣搜集n個人的後形成我們的數據集。
1. 訓練:在每一步訓練中,神經網路會把輸入和 正確的輸出送入中來更新一次神經網路f()中的參數 θ。用很多個不同的不斷更新 θ,最終希望當遇到新的時,可以用判斷出正確的。
2. 專有名詞:結合下圖說明名稱
樣本 (example)::輸入和正確的輸出一起叫做樣本。給網路展示了什麼輸入該產生什麼樣的輸出。這裡每個是五維向量,每個是一維向量。
表徵 (representation)::集合了代表個人的全部特徵。
特徵 (feature):中的某個維:如年齡,職業。是某人的一個特點。
標籤 (label)::正確的輸出。
一個樣本 (an example)
二、數據存儲
為達成上述的訓練,我們需要把所有的樣本存儲成合適的類型以供隨後的訓練。
1. 常用存儲:
輸入和標籤是分開存儲,若有 100 個樣本,所有的輸入存儲成一個 100×5 的 numpy 矩陣;所有的輸出則是100×1。
2. TFRecord 存儲:
TFRecord 是以字典的方式一次寫一個樣本,字典的 keys 可以不以輸入和標籤,而以不同的特徵(如學歷,年齡,職業,收入)區分,在隨後的讀取中再選擇哪些特徵形成輸入,哪些形成標籤。這樣的好處是,後續可以根據需要只挑選特定的特徵;也可以方便應對例如多任務學習這樣有多個輸入和標籤的機器學習任務。
註:一般而言,單數的 feature 是一個維度,即標量。所有的 features 組成 representation。但在 TFRecord 的存儲中,字典中 feature 的 value 可以不是標量。如:key 為學歷的 value 就可以是:[初中,高中,大學],3 個 features 所形成的向量。亦可是任何維度的張量。
實現
一、生成數據
除了標量和向量外,feature 有時會是矩陣(如段落),有時會還會是三維張量(如圖片)。
所以這裡展示如何寫入三個樣本,每個樣本有四個 feature,分別是標量,向量,矩陣,三維張量(圖片)。
1. 導入庫包
2. 生成數據
三個樣本的數值是遞增的,方便認清順序
顯示結果
二、寫入 TFRecord file
1. 打開 TFRecord file
2. 創建樣本寫入字典
這裡準備一個樣本一個樣本的寫入 TFRecord file 中。
先把每個樣本中所有 feature 的信息和值存到字典中,key 為 feature 名,value 為 feature 值。
feature 值需要轉變成 tensorflow 指定的 feature 類型中的一個:
2.1. 存儲類型
int64:
float32:
string:
註:輸入必須是 list(向量)
2.2. 如何處理類型是張量的 feature
tensorflow feature 類型只接受 list 數據,但如果數據類型是矩陣或者張量該如何處理?
兩種方式:
轉成 list 類型:將張量 fatten 成 list(也就是向量),再用寫入 list 的方式寫入。
轉成 string 類型:將張量用. tostring() 轉換成 string 類型,再用 tf.train.Feature(bytes_list=tf.train.BytesList(value=[input.tostring()])) 來存儲。
形狀信息:不管那種方式都會使數據丟失形狀信息,所以在向該樣本中寫入 feature 時應該額外加入 shape 信息作為額外 feature。shape 信息是 int 類型,這裡我是用原 feature 名字 +"_shape"來指定 shape 信息的 feature 名。
3. 轉成 tf_features
4. 轉成 tf_example
5. 序列化樣本
6. 寫入樣本
7. 關閉 TFRecord file
三、使用 Dataset
1. 創建 dataset
Dataset 是你的數據集,包含了某次將要使用的所有樣本,且所有樣本的結構需相同(在 tensorflow 官網介紹中,樣本 example 也被稱作 element)。樣本需從 source 導入到 dataset 中,導入的方式有很多中。隨後也可從已有的 dataset 中構建出新的 dataset。
1.1. 直接導入(非本文重點,隨後不再提)
1.2. 從 TFRecord 文件導入
2. 操作 dataset:
如優勢中所提到的,我們希望對 dataset 中的所有樣本進行統一的操作(batch,shuffle,padding 等)。接下來就是對 dataset 的操作。
2.1. dataset.map(func)
由於從 tfrecord 文件中導入的樣本是剛才寫入的 tf_serialized 序列化樣本,所以我們需要對每一個樣本進行解析。這裡就用 dataset.map(parse_function) 來對 dataset 里的每個樣本進行相同的解析操作。
註:dataset.map(輸入) 中的輸入是一個函數。
2.1.1. feature 信息
解析基本就是寫入時的逆過程,所以會需要寫入時的信息,這裡先列出剛才寫入時,所有 feature 的各項信息。
註:用到了 pandas,沒有的請 pip install pandas。
顯示結果
有 6 個信息,name, type, shape, isbyte, length_type, default。前 3 個好懂,這裡額外說明後 3 個:
isbyte:是用於記錄該 feature 是否字元化了。
default:是當所讀的樣本中該 feature 值缺失用什麼填補,這裡並沒有使用,所以全部都是 np.NaN
length_type:是指示讀取向量的方式是否定長,之後詳細說明。
註:這裡的信息都是在寫入時數據的原始信息。但是為了展示某些特性,這裡做了改動:
把 vector 的 shape 從 (3,) 改動成了 (1,3)
把 matrix 的 length_type 改成了 var(不定長)
2.1.2. 創建解析函數
接下就創建 parse function。
Step 1. 創建樣本解析字典
該字典存放著所有 feature 的解析方式,key 為 feature 名,value 為 feature 的解析方式。
解析方式有兩種:
定長特徵解析:tf.FixedLenFeature(shape, dtype, default_value)
shape:可當 reshape 來用,如 vector 的 shape 從 (3,) 改動成了 (1,3)。
註:如果寫入的 feature 使用了. tostring() 其 shape 就是 ()
dtype:必須是 tf.float32, tf.int64, tf.string 中的一種。
default_value:feature 值缺失時所指定的值。
不定長特徵解析:tf.VarLenFeature(dtype)
註:可以不明確指定 shape,但得到的 tensor 是 SparseTensor。
Step 2. 解析樣本
Step 3. 轉變特徵
得到的 parsed_example 也是一個字典,其中每個 key 是對應 feature 的名字,value 是相應的 feature 解析值。如果使用了下面兩種情況,則還需要對這些值進行轉變。其他情況則不用。
string 類型:tf.decode_raw(parsed_feature, type) 來解碼
註:這裡 type 必須要和當初. tostring() 化前的一致。如 tensor 轉變前是 tf.uint8,這裡就需是 tf.uint8;轉變前是 tf.float32,則 tf.float32
VarLen 解析:由於得到的是 SparseTensor,所以視情況需要用 tf.sparse_tensor_to_dense(SparseTensor) 來轉變成 DenseTensor
Step 4. 改變形狀
到此為止得到的特徵都是向量,需要根據之前存儲的 shape 信息對每個 feature 進行 reshape。
Step 5. 返回樣本
現在樣本中的所有 feature 都被正確設定了。可以根據需求將不同的 feature 進行拆分合併等處理,得到想要的輸入 和標籤 ,最終在 parse_function 末尾返回。這裡為了展示,我直接返回存有 4 個特徵的字典。
2.1.3. 執行解析函數
創建好解析函數後,將創建的 parse_function 送入 dataset.map() 得到新的數據集
2.2. 創建迭代器
有了解析過的數據集後,接下來就是獲取當中的樣本。
2.3. 獲取樣本
顯示結果,還會顯示先前保存的頭像
我們寫進 test.tfrecord 文件中了 3 個樣本,用
導入了兩次,所以有 6 個樣本。scalar 的值,也符合所寫入的數據。
2.4. Shuffle
可以輕鬆使用來打亂順序。設置成一個大於你數據集中樣本數量的值來確保其充分打亂。
註:對於數據集特別巨大的情況,請參考 YJango:tensorflow 中讀取大規模 tfrecord 如何充分 shuffle?(http://t.cn/R8y0jmQ)
順序打亂了,但 1,2,3 都出現過 2 次
2.5. Batch
再從亂序後的數據集上進行 batch。
6 個樣本,以 4 個進行 batch,第一個得到 4 個,第二個得到餘下的 2 個
2.6. Batch_padding
也可以在每個 batch 內進行 padding
padded_shapes 指定了內部數據是如何 pad 的。
rank 數要與元數據對應
rank 中的任何一維被設定成 None 或 - 1 時都表示將 pad 到該 batch 下的最大長度
顯示結果
2.7. Epoch
使用來指定要遍歷幾遍整個數據集
顯示結果
幫助函數
如果你不想設定,直接送入數據就可以讀寫。可以用下面的方法
複製 dataset_helper.py(http://t.cn/R8yOJFQ)
把所有樣本寫成一個 list
list 中的每個元素都是一個字典
字典的 key 是 feature 名
字典的 value 是 feature 值
feature 值的形狀和類型都指定好 (int64, float32)
導入幫助類
數據寫成 tfrecord
從 tfrecord 導入數據到 dataset
使用迭代器獲取樣本
YJango/TFRecord-Dataset-APIgithub.com(http://t.cn/R8yOmmX)
時間允許會結合剩下的兩個高層 API:Estimator、Experiment 作為下一篇介紹內容。
參考資料:
TensorFlow Importing Data(http://t.cn/Rl2azT8)
Tfrecords Guide(http://t.cn/R8yWPJo)
NLP 工程師入門實踐班:基於深度學習的自然語言處理
三大模塊,五大應用,手把手快速入門 NLP
海外博士講師,豐富項目經驗
演算法 + 實踐,搭配典型行業應用
隨到隨學,專業社群,講師在線答疑
AI 科技評論年度特輯


※43位頂級學術IP演講全收錄,最值得收藏的30萬字「全文+PPT」精華
TAG:AI研習社 |