當前位置:
首頁 > 新聞 > 快速開啟你的第一個項目:TensorFlow項目架構模板

快速開啟你的第一個項目:TensorFlow項目架構模板

作為最為流行的深度學習資源庫,TensorFlow 是幫助深度學習新方法走向實現的強大工具。它為大多數深度學習領域中使用的常用語言提供了大量應用程序介面。對於開發者和研究人員來說,在開啟新的項目前首先面臨的問題是:如何構建一個簡單明了的結構,本文或許可以為你帶來幫助。

項目鏈接:https://github.com/Mrgemy95/Tensorflow-Project-Template

TensorFlow 項目模板

簡潔而精密的結構對於深度學習項目來說是必不可少的,在經過多次練習和 TensorFlow 項目開發之後,本文作者提出了一個結合簡便性、優化文件結構和良好 OOP 設計的 TensorFlow 項目模板。該模板可以幫助你快速啟動自己的 TensorFlow 項目,直接從實現自己的核心思想開始。

這個簡單的模板可以幫助你直接從構建模型、訓練等任務開始工作。

目錄

  • 概述

  • 詳述

  • 項目架構

  • 文件夾結構

  • 主要組件

  • 模型

  • 訓練器

  • 數據載入器

  • 記錄器

  • 配置

  • Main

  • 未來工作

概述

簡言之,本文介紹的是這一模板的使用方法,例如,如果你希望實現 VGG 模型,那麼你應該:

在模型文件夾中創建一個名為 VGG 的類,由它繼承「base_model」類

class VGGModel(BaseModel):
def __init__(self, config):
super(VGGModel, self).__init__(config)
#call the build_model and init_saver functions.
self.build_model()
self.init_saver()

覆寫這兩個函數 "build_model",在其中執行你的 VGG 模型;以及定義 TensorFlow 保存的「init_saver」,隨後在 initalizer 中調用它們。

def build_model(self):
# here you build the tensorflow graph of any model you want and also define the loss.
pass
def init_saver(self):
#here you initalize the tensorflow saver that will be used in saving the checkpoints.
self.saver = tf.train.Saver(max_to_keep=self.config.max_to_keep)

在 trainers 文件夾中創建 VGG 訓練器,繼承「base_train」類。

class VGGTrainer(BaseTrain):
def __init__(self, sess, model, data, config, logger):
super(VGGTrainer, self).__init__(sess, model, data, config, logger)

覆寫這兩個函數「train_step」、「train_epoch」,在其中寫入訓練過程的邏輯。

def train_epoch(self):
"""
implement the logic of epoch:
-loop ever the number of iteration in the config and call teh train step
-add any summaries you want using the sammary
"""
pass def train_step(self):
"""
implement the logic of the train step
- run the tensorflow session
- return any metrics you need to summarize
"""
pass

在主文件中創建會話,創建以下對象:「Model」、「Logger」、「Data_Generator」、「Trainer」與配置:

sess = tf.Session()
# create instance of the model you want
model = VGGModel(config) # create your data generator
data = DataGenerator(config)
# create tensorboard logger
logger = Logger(sess, config)

向所有這些對象傳遞訓練器對象,通過調用「trainer.train()」開始訓練。

trainer = VGGTrainer(sess, model, data, config, logger)
# here you train your model
trainer.train()

你會看到模板文件、一個示例模型和訓練文件夾,向你展示如何快速開始你的第一個模型。

詳述

模型架構

快速開啟你的第一個項目:TensorFlow項目架構模板

主要組件

模型

  • 基礎模型

基礎模型是一個必須由你所創建的模型繼承的抽象類,其背後的思路是:絕大多數模型之間都有很多東西是可以共享的。基礎模型包含:

  • Save-此函數可保存 checkpoint 至桌面。

  • Load-此函數可載入桌面上的 checkpoint。

  • Cur-epoch、Global_step counters-這些變數會跟蹤訓練 epoch 和全局步。

  • Init_Saver-一個抽象函數,用於初始化保存和載入 checkpoint 的操作,注意:請在要實現的模型中覆蓋此函數。

  • Build_model-是一個定義模型的抽象函數,注意:請在要實現的模型中覆蓋此函數。

  • 你的模型

以下是你在模型中執行的地方。因此,你應該:

  • 創建你的模型類並繼承 base_model 類。

  • 覆寫 "build_model",在其中寫入你想要的 tensorflow 模型。

  • 覆寫"init_save",在其中你創建 tensorflow 保存器,以用它保存和載入檢查點。

  • 在 initalizer 中調用"build_model" 和 "init_saver"

訓練器

  • 基礎訓練器

基礎訓練器(Base trainer)是一個只包裝訓練過程的抽象的類。

  • 你的訓練器

以下是你應該在訓練器中執行的。

  • 創建你的訓練器類,並繼承 base_trainer 類。

  • 覆寫這兩個函數,在其中你執行每一步和每一 epoch 的訓練過程。

數據載入器

這些類負責所有的數據操作和處理,並提供一個可被訓練器使用的易用介面。

記錄器(Logger)

這個類負責 tensorboard 總結。在你的訓練器中創建一個有關所有你想要的 tensorflow 變數的詞典,並將其傳遞給 logger.summarize()。

配置

我使用 Json 作為配置方法,接著解析它,因此寫入所有你想要的配置,然後用"utils/config/process_config"解析它,並把這個配置對象傳遞給所有其他對象。

Main

以下是你整合的所有之前的部分。

1. 解析配置文件。

2. 創建一個 TensorFlow 會話。

3. 創建 "Model"、"Data_Generator" 和 "Logger"實例,並解析所有它們的配置。

4. 創建一個"Trainer"實例,並把之前所有的對象傳遞給它。

5. 現在你可通過調用"Trainer.train()"訓練你的模型。

未來工作

未來,該項目計劃通過新的 TensorFlow 數據集 API 替代數據載入器。

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 機器之心 的精彩文章:

谷歌構建的「明日之城」,共享單車、網約車、AR技術都在裡面了
谷歌正式發布TensorFlow 1.5,究竟提升了哪些功能?

TAG:機器之心 |