深度學習工程模板:簡化載入數據、構建網路、訓練模型和預測樣本的流程
使用方式
下載工程
創建和激活虛擬環境
安裝Python依賴庫
開發流程
定義自己的數據載入類,繼承DataLoaderBase;
定義自己的網路結構類,繼承ModelBase;
定義自己的模型訓練類,繼承TrainerBase;
定義自己的樣本預測類,繼承InferBase;
定義自己的配置文件,寫入實驗的相關參數;
執行訓練模型和預測樣本操作。
示例工程
識別MNIST庫中手寫數字,工程
訓練:
預測:
網路結構
TensorBoard
工程架構
框架圖
文件夾結構
主要組件
DataLoader
操作步驟:
創建自己的載入數據類,繼承DataLoaderBase基類;
覆寫和,返回訓練和測試數據;
Model
操作步驟:
創建自己的網路結構類,繼承ModelBase基類;
覆寫,創建網路結構;
在構造器中,調用;
注意:支持繪製網路結構;
Trainer
操作步驟:
創建自己的訓練類,繼承TrainerBase基類;
參數:網路結構model、訓練數據data;
覆寫,fit數據,訓練網路結構;
注意:支持在訓練中調用callbacks,額外添加模型存儲、TensorBoard、FPR度量等。
Infer
操作步驟:
創建自己的預測類,繼承InferBase基類;
覆寫,提供模型載入功能;
覆寫,提供樣本預測功能;
Config
定義在模型訓練過程中所需的參數,JSON格式,支持:學習率、Epoch、Batch等參數。
Main
訓練:
創建配置文件config;
創建數據載入類dataloader;
創建網路結構類model;
創建訓練類trainer,參數是訓練和測試數據、模型;
執行訓練類trainer的train();
預測:
創建配置文件config;
處理預測樣本test;
創建預測類infer;
執行預測類infer的predict();
原文:https://github.com/SpikeKing/DL-Project-Template
- 加入人工智慧學院系統學習 -
※摩根大通研發新的區塊鏈技術平台,CEO被打臉!
※基於PCA降維和BP神經網路的人臉識別
TAG:AI講堂 |