兼容 Scikit-Learn的PyTorch 神經網路庫——skorch
Skorch 是一個兼容 Scikit-Learn 的 PyTorch 神經網路庫。
資源
文檔:
https://skorch.readthedocs.io/en/latest/?badge=latest
源代碼
https://github.com/dnouri/skorch/
示例
更詳細的例子,請查看此鏈接:
https://github.com/dnouri/skorch/tree/master/notebooks/README.md
import numpy as np
from sklearn.datasets import make_classification
import torch
from torch import nn
import torch.nn.functional as F
from skorch.net import NeuralNetClassifier
X, y = make_classification(1000,20, n_informative=10, random_state=)
X = X.astype(np.float32)
y = y.astype(np.int64)
classMyModule(nn.Module):
def__init__(self, num_units=10, nonlin=F.relu):
super(MyModule,self).__init__()
self.dense= nn.Linear(20, num_units)
self.nonlin = nonlin
self.dropout = nn.Dropout(.5)
self.dense1 = nn.Linear(num_units,10)
self.output = nn.Linear(10,2)
defforward(self, X, **kwargs):
X =self.nonlin(self.dense(X))
X =self.dropout(X)
X = F.relu(self.dense1(X))
X = F.softmax(self.output(X), dim=-1)
returnX
net = NeuralNetClassifier(
MyModule,
max_epochs=10,
lr=.1,
)
net.fit(X, y)
y_proba = net.predict_proba(X)
In an sklearn Pipeline:
fromsklearn.pipelineimportPipeline
fromsklearn.preprocessingimportStandardScaler
pipe = Pipeline([
("scale", StandardScaler()),
("net", net),
])
pipe.fit(X, y)
y_proba = pipe.predict_proba(X)
With grid search
fromsklearn.model_selectionimportGridSearchCV
params = {
"lr": [0.01,0.02],
"max_epochs": [10,20],
"module__num_units": [10,20],
}
gs = GridSearchCV(net, params, refit=False, cv=3, scoring="accuracy")
gs.fit(X, y)
print(gs.best_score_, gs.best_params_)
安裝
pip 安裝
pipinstall-U skorch
建議使用虛擬環境。
源代碼安裝
如果你想使用 skorch 最新的案例或者開發幫助,請使用源代碼安裝
用 conda
如果你需要一個工作conda安裝, 從這裡為的的系統獲取正確的 miniconda:
https://conda.io/miniconda.html
如果你只是使用 skorch:
git clone https://github.com/dnouri/skorch.git
cdskorch
conda env create
sourceactivate skorch
# install pytorchversionforyoursystem(see below)
pythonsetup.pyinstall
如果你只想幫助開發,運行:
git clone https://github.com/dnouri/skorch.git
cdskorch
conda env create
sourceactivate skorch
# install pytorchversionforyoursystem(see below)
conda install --filerequirements-dev.txt
pythonsetup.pydevelop
py.test # unit tests
pylint skorch # static code checks
用 pip
如果你只是使用 skorch:
git clone https://github.com/dnouri/skorch.git
cdskorch
# createandactivateavirtual environment
pip install -r requirements.txt
# install pytorchversionforyoursystem(see below)
pythonsetup.pyinstall
如果你想使用幫助開發:
git clone https://github.com/dnouri/skorch.git
cdskorch
# createandactivateavirtual environment
pip install -r requirements.txt
# install pytorchversionforyoursystem(see below)
pip install -r requirements-dev.txt
pythonsetup.pydevelop
py.test # unit tests
pylint skorch # static code checks
從Python入門-如何成為AI工程師
BAT資深演算法工程師獨家研發課程
最貼近生活與工作的好玩實操項目
班級管理助學搭配專業的助教答疑
學以致用拿offer,學完即推薦就業
新人福利
關注 AI 研習社(okweiwu),回復1領取
【超過 1000G 神經網路 / AI / 大數據資料】
Scikit-learn(sklearn)官方文檔中文版
※OpenAI Baselines 更新,新增 HER 強化學習演算法
※2018 年,是時候來一場「體面」的人工智慧安防峰會了
TAG:AI研習社 |