當前位置:
首頁 > 科技 > 谷歌今天又開源了,這次是Sketch-RNN

谷歌今天又開源了,這次是Sketch-RNN



谷歌今天又開源了,這次是Sketch-RNN



前不久,谷歌公布了一項最新技術,可以教機器畫畫。今天,谷歌開源了代碼。在我們研究其代碼之前,首先先按要求設置Magenta環境。(https://github.com/tensorflow/magenta/blob/master/README.md)


本文詳細解釋了Sketch-RNN的TensorFlow代碼,即之前發布的兩篇文章《Teaching Machines to Draw》和《A Neural Representation of Sketch Drawings》中描述的循環神經網路模型(RNN)。


模型概覽

sketch-rnn是序列到序列的變體自動編碼器。編碼器RNN是雙向RNN,解碼器是自回歸混合密度RNN。你可以使用enc_model,dec_model,enc_size,dec_size設置指定要使用的RNN單元格的類型和RNN的大小。


編碼器將採用一個潛在代碼z,一個維度為z_size的浮點矢量。像VAE一樣,我們可以對z強制執行高斯IID分布,並使用kl_weight來控制KL發散損失項的強度。KL散度損失與重建損失之間將會有一個權衡。我們還允許潛在的代碼存儲信息的一些空間,而不是純高斯IID。一旦KL損失期限低於kl_tolerance,我們將停止對該期限的優化。



谷歌今天又開源了,這次是Sketch-RNN



對於中小型數據集,丟失(dropout)和數據擴充是避免過度擬合的非常有用的技術。我們提供了輸入丟失、輸出丟失、不存在內存丟失的循環丟失三個選項。實際上,我們只使用循環丟失,通常根據數據集將其設置在65%到90%之間。層次歸一化和反覆丟失可以一起使用,形成了一個強大的組合,用於在小型數據集上訓練循環神經網路。


谷歌提供了兩種數據增強技術。第一個是隨機縮放訓練圖像大小的random_scale_factor。第二種增加技術(sketch-rnn論文中未使用)剔除線筆劃中的隨機點。給定一個具有超過2點的線段,我們可以隨機放置線段內的點,並且仍然保持類似的矢量圖像。這種類型的數據增強在小數據集上使用時非常強大,並且對矢量圖是唯一的,因為難以在文本或MIDI數據中刪除隨機字元或音符,並且也不可能在像素圖像數據中丟棄隨機像素而不引起大的視覺差異。我們通常將數據增加參數設置為10%至20%。如果在與普通示例相比較的情況下,人類觀眾幾乎沒有差異,那麼我們應用數據增強技術,而不考慮訓練數據集的大小。


有效地使用丟棄和數據擴充,可以避免過度擬合到一個小的訓練集。


訓練模型


要訓練模型,首先需要一個包含訓練/驗證/測試例子的數據集。我們提供了指向aaron_sheep數據集的鏈接,默認情況下,該模型將使用此輕量級數據集。


使用示例:

sketch_rnn_train --log_root=checkpoint_path --data_dir=dataset_path --hparams={"data_set"="dataset_filename.npz"}


我們建議你在模型和數據集內部創建子目錄,以保存自己的數據和檢查點。 TensorBoard日誌將存儲在checkpoint_path內,用於查看訓練/驗證/測試數據集中各種損失的訓練曲線。


以下是模型的完整選項列表以及默認設置:


data_set="aaron_sheep.npz", # Our dataset.


save_every=500, # Number of batches percheckpoint creation.


dec_rnn_size=512, # Size of decoder.


dec_model="lstm", # Decoder: lstm, layer_norm orhyper.


enc_rnn_size=256, # Size of encoder.


enc_model="lstm", # Encoder: lstm, layer_norm orhyper.


z_size=128, # Size of latent vector z.Recommend 32, 64 or 128.

kl_weight=0.5, # KL weight of loss equation.Recommend 0.5 or 1.0.


kl_weight_start=0.01, # KL start weight when annealing.


kl_tolerance=0.2, # Level of KL loss at which to stopoptimizing for KL.


batch_size=100, # Minibatch size. Recommendleaving at 100.


grad_clip=1.0, # Gradient clipping. Recommendleaving at 1.0.


num_mixture=20, # Number of mixtures in Gaussianmixture model.


learning_rate=0.001, # Learning rate.


decay_rate=0.9999, # Learning rate decay per minibatch.


kl_decay_rate=0.99995, # KL annealing decay rate per minibatch.


min_learning_rate=0.00001, # Minimum learning rate.

use_recurrent_dropout=True, # Recurrent Dropout without Memory Loss.Recomended.


recurrent_dropout_prob=0.90, # Probabilityof recurrent dropout keep.


use_input_dropout=False, # Input dropout. Recommend leaving False.


input_dropout_prob=0.90, # Probability of input dropout keep.


use_output_dropout=False, # Output droput. Recommend leaving False.


output_dropout_prob=0.90, # Probability of output dropout keep.


random_scale_factor=0.15, # Random scaling data augmentionproportion.


augment_stroke_prob=0.10, # Point dropping augmentation proportion.


conditional=True, # If False, use decoder-only model.


以下是一些可能需要用於在非常大的數據集上訓練模型的選項,並使用HyperLSTM作為RNN單元。對於小於10K的訓練樣本的小數據集,具有層規範化(包括enc_model和dec_model的layer_norm)的LSTM效果最佳。

sketch_rnn_train --log_root=models/big_model --data_dir=datasets/big_dataset --hparams={"data_set"="big_dataset_filename.npz","dec_model":"hyper","dec_rnn_size":2048,"enc_model":"layer_norm","enc_rnn_size":512,"save_every":5000,"grad_clip":1.0,"use_recurrent_dropout":0}


對於Python 2.7,我們已經在TensorFlow 1.0和1.1上測試了這個模型。


數據集


由於大小限制,此報告不包含任何數據集。


我們已經準備好了許多使用Sketch-RNN開箱即用的數據集。Google QuickDraw數據集(https://quickdraw.withgoogle.com/data)是涵蓋345個類別的50M矢量草圖的集合。在quickdraw數據集中,有一個名為Sketch-RNNQuickDraw Dataset的部分描述了可用於此項目的預處理數據文件。每個類別類都存儲在其自己的文件中,如cat.npz,並包含70000/2500/2500示例的訓練/驗證/測試集大小。


從Google雲(https://console.cloud.google.com/storage/quickdraw_dataset/sketchrnn)


下載.npz數據集,以供本地使用。我們建議你創建一個名為datasets / quickdraw的子目錄,並將這些.npz文件保存在此子目錄中。


除了QuickDraw數據集之外,我們還在較小的數據集上測試了該模型。在sketch-rnn-datasets(https://github.com/hardmaru/sketch-rnn-datasets)報告中,還有3個數據集:AaronKoblin Sheep Market、Kanji和Omniglot。如果你希望在本地使用它們,我們建議你為每個數據集創建一個子目錄,如datasets/ aaron_sheep。如前所述,在小型數據集上訓練模型以避免過度擬合時,應使用循環退出和數據增加。


創建自己的數據集


請創建你自己有趣的數據集並訓練這些演算法!創建新的數據集是樂趣的一部分。你很可能發現有趣的矢量線圖數據集,為什麼要用現有的預先打包好的數據集呢?在我們的實驗中,由幾千個例子組成的數據集大小足以產生一些有意義的結果。在這裡,我們描述模型期望看到的數據集文件的格式。

數據集中的每個示例都存儲為坐標偏移的列表:Δx,Δy用來二進位值表示筆是否從紙張提起。這種格式,我們稱之為stroke-3,在論文中有描述(https://arxiv.org/abs/1308.0850)。 請注意,論文中描述的數據格式有5個元素(stroke-5格式),此轉換在DataLoader內自動完成。以下是使用以下格式的烏龜示例草圖:



谷歌今天又開源了,這次是Sketch-RNN



圖:作為(Δx,Δy,二進位筆狀態)序列的示例草圖點和渲染形式。在渲染草圖中,線條顏色對應於順序筆畫排列。


在我們的數據集中,示例列表中的每個示例都用np.int16數據類型表示為np.array。你可以將它們存儲為np.int8,你可以將其存儲起來以節省存儲空間。如果你的數據必須是浮點格式,也可以使用np.float16。np.float32可能會浪費存儲空間。在我們的數據中,Δx和Δy偏移通常用像素位置表示,它們大於神經網路模型喜歡看到的數字範圍,所以在模型中內置了歸一化縮放過程。當我們載入訓練數據時,模型將自動轉換為np.float並在訓練前相應規範化。


如果要創建自己的數據集,則必須為訓練/驗證/測試集創建三個示例列表,以避免過度擬合到訓練集。該模型將使用驗證集來處理早期停止。對於aaron_sheep數據集,我們使用了7400/300/300的示例,並將每個內容放在python列表中,名為train_data,validation_data和test_data。之後,我們創建了一個名為datasets / aaron_sheep的子目錄,我們使用內置的savez_compressed方法將數據集的壓縮版本保存在aaron_sheep.npz文件中。在我們的所有實驗中,每個數據集的大小是100的確切倍數。


我們還通過執行簡單的筆畫簡化來預處理數據,稱為Ramer-Douglas-Peucker。 在這裡應用這個演算法有一些易於使用的開源代碼(https://github.com/fhirschmann/rdp)。 實際上,我們可以將epsilon參數設置為0.2到3.0之間的值,具體取決於我們想要簡單的線條。 在本文中,我們使用了一個2.0的epsilon參數。 我們建議你建立最大序列長度小於250的數據集。


如果你有大量簡單的SVG圖像,則可以使用一些可用的庫(https://pypi.python.org/pypi/svg.path)來將SVG的子集轉換為線段,然後可以在將數據轉換為stroke-3格式之前對線段應用RDP。


預訓練模型


我們為aaron_sheep數據集提供了預先訓練的模型,用於條件和無條件訓練模式,使用vanilla LSTM單元以及帶有層規範化的LSTM單元。這些型號將通過運行Jupyter Notebook下載。它們存儲在:

/tmp/sketch_rnn/models/aaron_sheep/lstm


/tmp/sketch_rnn/models/aaron_sheep/lstm_uncond


/tmp/sketch_rnn/models/aaron_sheep/layer_norm


/tmp/sketch_rnn/models/aaron_sheep/layer_norm_uncond


此外,我們為選定的QuickDraw數據集提供了預先訓練的模型:


/tmp/sketch_rnn/models/owl/lstm


/tmp/sketch_rnn/models/flamingo/lstm_uncond


/tmp/sketch_rnn/models/catbus/lstm


/tmp/sketch_rnn/models/elephantpig/lstm


使用Jupyter notebook的模型


谷歌今天又開源了,這次是Sketch-RNN



讓我們來模擬貓和公車之間的插值!


我們涵蓋了一個簡單的Jupyter notebook(http://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn.ipynb),向你展示如何載入預先訓練的模型並生成矢量草圖。你能夠在兩個矢量圖像之間進行編碼,解碼和變形,並生成新的隨機圖像。採樣圖像時,可以調整temperature參數來控制不確定度。


來源:


https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/README.md

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 機器人圈 的精彩文章:

深度對抗學習整裝待發,或將改變傳統AI格局
谷歌欲用FHIR進行精準醫療預測,「AI+醫療」時代踏步前進
Google重磅推出第二代TPU,即將進入雲
詳解谷歌AutoML演算法——神經網路是如何「自我升級」的?
谷歌剛剛發布的TensorFlow研究雲到底是什麼?

TAG:機器人圈 |

您可能感興趣

AutoML又一利器來了,谷歌宣布開源AdaNet
Facebook Mask R-CNN2Go已開源
Facebook開源Mask R-CNN的PyTorch 1.0基準,更快、更省內存
LeCun:30年前知道DeepFake,我還該不該開源CNN?
8.27 VR掃描:Oculus創始人Palmer Luckey:下半年開源VR暈動症解決方案
Facebook開源Mask R-CNN的PyTorch 1.0基準,比mmdetection更快、更省內存
ONF宣布啟動下一代開源SDN交換平台Stratum
業界 | Facebook開源Mask R-CNN的PyTorch 1.0基準,比mmdetection更快、更省內存
ArXiv最受歡迎開源深度學習框架榜:TensorFlow第一,PyTorch第四
上交大盧策吾團隊開源 AlphaPose, 在 MSCOCO 上穩超 Mask-RCNN 8 個百分點
上交大盧策吾團隊開源 AlphaPose,在MSCOCO 上穩超 Mask-RCNN 8 個百分點
開始使用 Sandstorm 吧,一個開源 Web 應用平台
終結谷歌每小時20美元的AutoML!開源的AutoKeras了解下
GitHub被收購,Stack Overflow在裁員:後開源時代,開源的未來往哪邊?
ArXiv最受歡迎開源深度學習框架榜單:TensorFlow第一,PyTorch第四
「CVPR Oral」TensorFlow實現StarGAN代碼全部開源,1天訓練完
直接拿來用!谷歌開源網路庫TensorNetwork,GPU處理提升100倍
Karpathy更新深度學習開源框架排名:TensorFlow第一,PyTorch第二
9月GitHub上面Python排名前十得到開源項目,建議收藏學習
谷歌開源空間音頻Resonance Audio,加速AR-VR普及