當前位置:
首頁 > 新聞 > 劍指TensorFlow,PyTorch Hub官方模型庫一行代碼復現主流模型

劍指TensorFlow,PyTorch Hub官方模型庫一行代碼復現主流模型

經典預訓練模型、新型前沿研究模型是不是比較難調用?PyTorch 團隊今天發布了模型調用神器 PyTorch Hub,只需一行代碼,BERT、GPT、PGAN 等最新模型都能玩起來。

項目地址:https://pytorch.org/hub

機器學習領域,可復現性是一項重要的需求。但是,許多機器學習出版成果難以復現,甚至無法復現。隨著數量上逐年增長的出版成果,包括數以萬計的 arXiv 文章和大會投稿,對於研究的可復現性比以往更加重要了。雖然許多研究都附帶了代碼和訓練模型,儘管他們對使用者有所幫助,但仍然需要使用者自己去研究如何使用。

今天,PyTorch 團隊發布了 PyTorch Hub,一個簡單的 API 和工作流代碼庫,它為機器學習研究的復現提供了基礎構建單元。PyTorch Hub 包括預訓練模型庫,專門用來幫助研究的復現、協助新研究的開展。它同時內置支持 Colab,集成 Papers With Code 網站,並已經有廣泛的一套預訓練模型,包括分類器、分割器、生成器和 Transformer 等等。

劍指TensorFlow,PyTorch Hub官方模型庫一行代碼復現主流模型

研究者發布模型

PyTorch Hub 支持在 GitHub 上發布預訓練模型(定義模型結構和預訓練權重),這隻需要增加一個簡單的 hubconf.py 文件。該文件會列舉所支持的模型,以及模型需要的依賴項。

用戶可以從以下代碼倉庫找到使用案例:

  • https://github.com/pytorch/vision/blob/master/hubconf.py
  • https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/hubconf.py
  • https://github.com/facebookresearch/pytorch_GAN_zoo

現在,我們可以看看最簡單的案例,torchvision 的 hubconf.py:

劍指TensorFlow,PyTorch Hub官方模型庫一行代碼復現主流模型

在 torchvision,模型有以下幾部分:

  • 每個模型文件都可以獨立的執行
  • 這些模型不依賴 PyTorch 以外的包(在 hubconf.py 中以及集成了相關依賴:dependencies["torch"])
  • 這些模型不需要單獨的模型入口(entry-point),因為這些模型一經創建,就可以無縫地提取使用

減少包的依賴可以減少用戶導入模型時出現的各種問題,而且這種導入可能只是臨時的調用。一個直觀的例子是 HuggingFace"s BERT 模型。其 hubconf.py 文件如下:

劍指TensorFlow,PyTorch Hub官方模型庫一行代碼復現主流模型

每個模型都需要創建一個模型入口,以下指定了 bertForMaskedLM 模型入口,並希望獲得預訓練模型權重:

劍指TensorFlow,PyTorch Hub官方模型庫一行代碼復現主流模型

這些入口可以作為複雜模型的包裝器,我們能提供注釋文檔或額外的幫助函數。最後,有了 hubconf.py,研究者就能發送 pull 請求。當 PyTorch 接受了該請求後,研究者的模型就會出現在 PyTorch Hub 頁面上。

用戶工作流

PyTorch Hub 允許用戶只用簡單的幾步就完成很多任務,例如 1)探索可用模型;2)載入預訓練模型;3)理解載入模型的方法與運行參數。下面讓我們通過一些案例體會體會 PyTorch Hub 的便捷吧。

探索可用模型

我們可以使用 torch.hub.list() API 查看倉庫內所有可用的模型。

劍指TensorFlow,PyTorch Hub官方模型庫一行代碼復現主流模型

注意,PyTorch 還允許使用預訓練模型之外的輔助模塊,例如使用 bertTokenizer 來完成 BERT 模型的預處理過程,它們都會使工作流更加順暢。

載入模型

現在我們已經知道有哪些預訓練模型,下面就可以使用 torch.hub.load() API 載入這些模型了。使用 API 載入模型時,它只需要一行命令,而不需要額外安裝 wheel。另外,torch.hub.help() API 也能提供一些有用的信息來幫助演示如何使用預訓練模型。

劍指TensorFlow,PyTorch Hub官方模型庫一行代碼復現主流模型

其實這些預訓練模型會經常更新,不論是修復 Bug 還是提升性能。而 PyTorch Hub 令用戶可以極其簡單地獲取最後的更新版:

劍指TensorFlow,PyTorch Hub官方模型庫一行代碼復現主流模型

PyTorch 團隊相信這個特性能幫助預訓練模型的擁有者減輕負擔,即重複發布包的成本會降低,他們也能更加專註於研究(預訓練模型)本身。此外,該特性對用戶也有很大優勢,我們可以快速獲得最新的預訓練模型。

另一方面,穩定性對於用戶而言非常重要。因此,模型提供者能以特定的分支或 Tag 為用戶提供支持,而不直接在 master 分支上提供。這種方式能確保代碼的穩定性,例如 pytorch_GAN_zoo 可以用 hub 分支來支持對其模型的使用。

劍指TensorFlow,PyTorch Hub官方模型庫一行代碼復現主流模型

注意傳遞到 hub.load() 中的 args 和 kwargs,它們都用於實例化模型。在上面的例子中,pretrained=True 和 useGPU=False 都被賦予不同的預訓練模型。

探索已載入模型

當我們從 PyTorch Hub 中載入了模型時,我們能從以下工作流探索可用的方法,並更好地理解運行它們需要什麼樣的參數。

dir(model) 方法可以查看模型的所有方法,下面我們可以看看 bertForMaskedLM 模型的可用方法。

劍指TensorFlow,PyTorch Hub官方模型庫一行代碼復現主流模型

help(model.forward) 方法將提供要令模型能正常跑,其所需要的參數。

下面提供了 BERT 和 DeepLabV3 兩個例子,我們可以看看這些模型載入後都能怎樣使用。

  • BERT:https://pytorch.org/hub/huggingface_pytorch-pretrained-bert_bert/
  • DeepLabV3:https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/

PyTorch Hub 中的可用模型同樣支持 Colab,它們都會連接到 Papers With Code 網站。

TensorFlow 和 PyTorch 你選誰?

此前 TensorFlow 也發布了 TensorFlow Hub,它同樣用於發布、探索和使用機器學習模型中可復用的部分。最近關注便捷性的 TensorFlow 2.0 Beta 也已經發布,但很多讀者還是傾向於使用 PyTorch。既然這兩大框架越來越「相似」,那麼我們到底該使用哪個?下面機器之心簡要總結了這兩個深度學習框架的發展歷程,我們也相信,用哪個都能開發出想要的炫酷應用。

TensorFlow:

劍指TensorFlow,PyTorch Hub官方模型庫一行代碼復現主流模型

PyTorch:

劍指TensorFlow,PyTorch Hub官方模型庫一行代碼復現主流模型

TensorFlow 和 PyTorch 都是經典的機器學習代碼庫。隨著學界和工業界對機器學習的需求的增長,兩者的社區也在不斷發展壯大。雖然 TensorFlow 是老牌的機器學習代碼庫,但由於 1.x 及之前版本存在的諸多問題,許多用戶逐漸轉向對用戶友好、學習門檻低、使用方便的 PyTorch。在 2018 年,TensorFlow 逐漸意識到這一問題,並在 2.x 版本逐漸提升了用戶體驗。

與此同時,基於兩個經典機器學習代碼庫的進一步工具開發也是近年來的趨勢。過去有部分基於 TensorFlow 的 Keras 和基於 PyTorch 的 fast.ai,最近一兩年則有大量的模型庫和方便用戶快速訓練和部署模型的代碼庫,如 Tensor2Tensor,以及針對特定領域的代碼庫,如基於 PyTorch 的 NLP 代碼庫 PyText 和圖神經網路庫 PyG。

目前來看,TensorFlow 的生態系統更為多樣和完善,且具有多語言的支持,其廣受詬病的難以使用的缺點也在逐漸改善。另一方面,由於 PyTorch 本身用戶友好的特性,基於這一代碼庫的應用開發進度似乎也趕上了 TensorFlow,儘管在多語言支持等方面 PyTorch 依然有較大差距。這一機器學習生態之戰究竟會走向何方,目前尚不明朗。未來的機器學習代碼框架的發展趨勢是,模型的訓練、部署工作量將會越來越低,類似「搭積木」方式的應用部署方式將會越來越流行。研究者和開發者的精力將會完全轉向模型結構的設計、部署和完善,而非糾結於框架的選擇和其他底層工程問題上。

參考鏈接:https://pytorch.org/blog/towards-reproducible-research-with-pytorch-hub/

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 機器之心 的精彩文章:

只要戴上這副眼鏡,你也可以去歌神演唱會抓逃犯了
AMD停止授權中國x86新技術,「晶元國產化」路子怎麼走?

TAG:機器之心 |