當前位置:
首頁 > 知識 > 兼容 Scikit-Learn的PyTorch 神經網路庫——skorch

兼容 Scikit-Learn的PyTorch 神經網路庫——skorch

Skorch 是一個兼容 Scikit-Learn 的 PyTorch 神經網路庫。

資源

文檔:

https://skorch.readthedocs.io/en/latest/?badge=latest

源代碼

https://github.com/dnouri/skorch/

示例

更詳細的例子,請查看此鏈接:

https://github.com/dnouri/skorch/tree/master/notebooks/README.md

import numpy as np

from sklearn.datasets import make_classification

import torch

from torch import nn

import torch.nn.functional as F

from skorch.net import NeuralNetClassifier

X, y = make_classification(1000,20, n_informative=10, random_state=)

X = X.astype(np.float32)

y = y.astype(np.int64)

classMyModule(nn.Module):

def__init__(self, num_units=10, nonlin=F.relu):

super(MyModule,self).__init__()

self.dense= nn.Linear(20, num_units)

self.nonlin = nonlin

self.dropout = nn.Dropout(.5)

self.dense1 = nn.Linear(num_units,10)

self.output = nn.Linear(10,2)

defforward(self, X, **kwargs):

X =self.nonlin(self.dense(X))

X =self.dropout(X)

X = F.relu(self.dense1(X))

X = F.softmax(self.output(X), dim=-1)

returnX

net = NeuralNetClassifier(

MyModule,

max_epochs=10,

lr=.1,

)

net.fit(X, y)

y_proba = net.predict_proba(X)

In an sklearn Pipeline:

fromsklearn.pipelineimportPipeline

fromsklearn.preprocessingimportStandardScaler

pipe = Pipeline([

("scale", StandardScaler()),

("net", net),

])

pipe.fit(X, y)

y_proba = pipe.predict_proba(X)

With grid search

fromsklearn.model_selectionimportGridSearchCV

params = {

"lr": [0.01,0.02],

"max_epochs": [10,20],

"module__num_units": [10,20],

}

gs = GridSearchCV(net, params, refit=False, cv=3, scoring="accuracy")

gs.fit(X, y)

print(gs.best_score_, gs.best_params_)

安裝

pip 安裝

pipinstall-U skorch

建議使用虛擬環境。

源代碼安裝

如果你想使用 skorch 最新的案例或者開發幫助,請使用源代碼安裝

用 conda

如果你需要一個工作conda安裝, 從這裡為的的系統獲取正確的 miniconda:

https://conda.io/miniconda.html

如果你只是使用 skorch:

git clone https://github.com/dnouri/skorch.git

cdskorch

conda env create

sourceactivate skorch

# install pytorchversionforyoursystem(see below)

pythonsetup.pyinstall

如果你只想幫助開發,運行:

git clone https://github.com/dnouri/skorch.git

cdskorch

conda env create

sourceactivate skorch

# install pytorchversionforyoursystem(see below)

conda install --filerequirements-dev.txt

pythonsetup.pydevelop

py.test # unit tests

pylint skorch # static code checks

用 pip

如果你只是使用 skorch:

git clone https://github.com/dnouri/skorch.git

cdskorch

# createandactivateavirtual environment

pip install -r requirements.txt

# install pytorchversionforyoursystem(see below)

pythonsetup.pyinstall

如果你想使用幫助開發:

git clone https://github.com/dnouri/skorch.git

cdskorch

# createandactivateavirtual environment

pip install -r requirements.txt

# install pytorchversionforyoursystem(see below)

pip install -r requirements-dev.txt

pythonsetup.pydevelop

py.test # unit tests

pylint skorch # static code checks

從Python入門-如何成為AI工程師

BAT資深演算法工程師獨家研發課程

最貼近生活與工作的好玩實操項目

班級管理助學搭配專業的助教答疑

學以致用拿offer,學完即推薦就業

新人福利

關注 AI 研習社(okweiwu),回復1領取

【超過 1000G 神經網路 / AI / 大數據資料】

Scikit-learn(sklearn)官方文檔中文版


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 AI研習社 的精彩文章:

OpenAI Baselines 更新,新增 HER 強化學習演算法
2018 年,是時候來一場「體面」的人工智慧安防峰會了

TAG:AI研習社 |