當前位置:
首頁 > 最新 > 聊一聊TensorFlow的數據導入機制

聊一聊TensorFlow的數據導入機制

聊一聊TensorFlow的數據導入機制

今天我們要講的是TensorFlow中的數據導入機制,傳統的做法是習慣於先構建好TF圖模型,然後開啟一個會話(Session),在運行圖模型之前將數據feed到圖中,這種做法的缺點是數據IO帶來的時間消耗很大,那麼在訓練非常龐大的數據集的時候,不提倡採用這種做法,TensorFlow中取而代之的是tf.data.Dataset模塊,今天我們重點介紹這個。

tf.data是一個十分強大的可以用於構建複雜的數據導入機制的API,例如,如果你要處理的是圖像,那麼tf.data可以幫助你把分布在不同位置的文件整合到一起,並且對每幅圖片添加微小的隨機雜訊,以及隨機選取一部分圖片作為一個batch進行訓練;又或者是你要處理文本,那麼tf.data可以幫助從文本中解析符號並且轉換成embedding矩陣,然後將不同長度的序列變成一個個batch。

我們可以用tf.data.Dataset來構建一個數據集,數據集的來源可以有多種方式,例如如果你的數據集是預先以TFRecord格式寫在硬碟上的,那麼你可以用tf.data.TFRecordDataset來構建;如果你的數據集是內存中的tensor變數,那麼可以用tf.data.Dataset.from_tensors() 或 tf.data.Dataset.from_tensor_slices()來構建。下面我將通過代碼來演示它們。

首先,我們來看從內存中的tensor變數來構建數據集,如下代碼所示,首先構建了一個0~10的數據集,然後構建迭代器,迭代器可以每次從數據集中提取一個元素:

如上代碼所示,range()是tf.data.Dataset類的一個靜態函數,用於產生一段序列。需要注意的是,構建的數據集需要是同一種數據類型以及內部結構。除此之外,由於range(10)代表0~9一共十個數,因此,這裡的iterator只能運行10次,超過以後將會拋出tf.errors.OutOfRangeError異常。如果希望不拋出異常,則可以調用dataset.repeat(count)即可實現count次自動重複的迭代器。

range的範圍我們也可以在運行時才確定,即定義max_range為placeholder變數,這個時候需要調用Dataset的make_initializable_iterator方法來構建迭代器,並且這個迭代器的operation需要在迭代之前被運行,代碼如下所示:

也可以為不同的數據集創建同一個迭代器,為了使得這個迭代器可以被重複使用,需要保證不同數據集的類型和維度是一致的。例如,下面的代碼演示了如何使用同一個迭代器來構建訓練集和驗證集,可以看到,當我們開始訓練訓練集的時候,就需要先執行training_init_op,目的是使得迭代器開始載入訓練數據;而當進行驗證的時候,則需要先執行validation_init_op,道理一樣。

也可以通過Tensor變數構建tf.data.Dataset,如下代碼所示,需要注意的是,這裡的Tensor的維度是4×10,因此,傳入到迭代器中就是可以運行4次,每次運行生成一個長度為10的向量。

首先是將音頻特徵寫入到TFRecord文件之中,在語音識別中,我們最常用的兩個特徵就是MFCC和LogFBank,要寫入文件中的不僅僅是這兩個變數,還要有文本標籤Label以及特徵序列的長度sequence_legnth,這四個變數中,只有sequence_length是整數標量,其他三個都是列表格式,所以這裡對於列表使用位元組來保存,而對於標量,使用整型來保存。

寫好TFRecord以後,在讀取的時候首先需要對TFRecord格式文件進行解析,解析函數如下:

然後我們可以直接通過調用tf.data.TFRecordDataset來導入TFRecord文件列表,以及對每個文件調用parse函數進行解析,並且由於每個文件的特徵矩陣長度不一,所以需要對齊進行padding操作,最終可以獲得迭代器,代碼如下:

於是,關於TFRecord文件的讀寫就介紹完了,並且,基於TensorFlow的數據導入機制也介紹完了。

題圖:艾德華馬奈《草地上的午餐

深度學習每日摘要|堅持技術,追求原創

微信ID:deeplearningdigest

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 深度學習每日摘要 的精彩文章:

TAG:深度學習每日摘要 |