當前位置:
首頁 > 知識 > 手把手教你用 PyTorch 辨別自然語言

手把手教你用 PyTorch 辨別自然語言

AI 研習社按:本文作者甄冉冉,原載於作者個人博客,雷鋒網 AI 研習社已獲授權。

最近在學pyTorch的實際應用例子。這次說個簡單的例子:給定一句話,判斷是什麼語言。這個例子是比如給定一句話:

Give it to me

判斷是 ENGLISH

me gusta comer en la cafeteria

判斷是 SPANISH

就是這麼簡單的例子。

來看怎麼實現:

準備數據 格式 [(語句,類型),...]

data是train的時候用的語句,test_data是test的時候用的語句

data = [ ("me gusta comer en la cafeteria".split(), "SPANISH"),

("Give it to me".split(), "ENGLISH"),

("No creo que sea una buena idea".split(), "SPANISH"),

("No it is not a good idea to get lost at sea".split(), "ENGLISH") ]

test_data = [("Yo creo que si".split(), "SPANISH"),

("it is lost on me".split(), "ENGLISH")]

因為文本計算機室識別不出來的,他們只認識01串,也就是數字。所以我們得把文本映射到數字上。

word_to_ix = {}

for sent, _ in data + test_data:

for word in sent:

if word not in word_to_ix:

word_to_ix[word] = len(word_to_ix)

print(word_to_ix)

輸出word_to_ix (意思是word to index)是:

{ me : 0, gusta : 1, comer : 2, en : 3, la : 4, cafeteria : 5, Give : 6, it : 7, to : 8, No : 9, creo : 10, que : 11, sea : 12, una : 13, buena : 14, idea : 15, is : 16, not : 17, a : 18, good : 19, get : 20, lost : 21, at : 22, Yo : 23, si : 24, on : 25}

這裡先提前設置下接下來要用到的參數

VOCAB_SIZE = len(word_to_ix)

NUM_LABELS = 2#只有兩類 ENGLISH SPANISH

固定模板

def init(self, num_labels, vocab_size):初始化,就是輸入和輸出的大小。這裡我們要輸入是一個句子,句子最大就是擁有所有字典的詞,這裡也就是vocab_size(下面再說怎麼將一句話根據字典轉換成一個數字序列的),輸出就是分類,這裡分為2類,即num_labels。這裡我們用的是線性分類 ,即nn.Linear()。

def forward(self, bow_vec):bow_vec是一個句子的數字化序列,經過self.linear()得到一個線性結果(也就是預測結果),之後對這個結果進行softmax(這裡用log_softmax是因為下面的損失函數用的是NLLLoss() 即負對數似然損失,需要log以下)

class BoWClassifier(nn.Module):#nn.Module 這是繼承torch的神經網路模板

def __init__(self, num_labels, vocab_size):

super(BoWClassifier, self).__init__()

self.linear = nn.Linear(vocab_size, num_labels)

def forward(self, bow_vec):

return F.log_softmax(self.linear(bow_vec))

def make_bow_vector(sentence, word_to_ix)

大概能看懂什麼意思吧。就是把一個句子sentence通過word_to_ix轉換成數字化序列.比如 sentence=我 是 一隻 小 小 鳥 word_to_id= make_bow_vector之後的結果是[0,1,0,0,1,0,2,0,1]。view()就是改變下向量維數。

這裡是講len(word_to_ix)1->1len(word_to_ix)

def make_bow_vector(sentence, word_to_ix):

vec = torch.zeros(len(word_to_ix))

for word in sentence:

vec[word_to_ix[word]] += 1

return vec.view(1, -1)

這個就不用說了吧 一樣。(如果想知道torch.LongTensor啥意思的話。可以看看。Torch中,Tensor主要有ByteTensor(無符號char),CharTensor(有符號),ShortTensor(shorts), IntTensor(ints), LongTensor(longs), FloatTensor(floats), DoubleTensor(doubles),默認存放為double類型,如果需要特別指出,通過torch.setdefaulttensortype()方法進行設定。例如torch.setdefaulttensortype(『torch.FloatTensor』)。 )

def make_target(label, label_to_ix):

return torch.LongTensor([label_to_ix[label]])

這裡再介紹下model.parameters()這個函數。他的返回結果是model里的所有參數。這裡我們用的是線性函數,所以就是f(x)=Ax+b中的A和b(x即輸入的數據),這些參數在之後的反饋和更新參數需要的。

model = BoWClassifier(NUM_LABELS, VOCAB_SIZE)

for param in model.parameters():

print("param:", param)

可以看出A是2len(vocab_size),b是21

param: Parameter containing:

Columns 0 to 9

0.0786 0.1596 0.1259 0.0054 0.0558 -0.0911 -0.1804 -0.1526 -0.0287 -0.1086

-0.0651 -0.1096 -0.1807 -0.1907 -0.0727 -0.0179 0.1530 -0.0910 0.1943 -0.1148

Columns 10 to 19

0.0452 -0.0786 0.1776 0.0425 0.1194 -0.1330 -0.1877 -0.0412 -0.0269 -0.1572

-0.0361 0.1909 0.1558 0.1309 0.1461 -0.0822 0.1078 -0.1354 -0.1877 0.0184

Columns 20 to 25

0.1818 -0.1401 0.1118 0.1002 0.1438 0.0790

0.1812 -0.1414 -0.1876 0.1569 0.0804 -0.1897

[torch.FloatTensor of size 2x26]

param: Parameter containing:

0.1859

0.1245

[torch.FloatTensor of size 2]

我們再看看model的def forward(self, bow_vec):怎麼用。這裡就想下面的代碼一樣,直接在mode()填一個參數即可,就調用forward函數。

sample = data[0]

bow_vector = make_bow_vector(sample[0], word_to_ix)

log_probs = model(autograd.Variable(bow_vector))

print("log_probs", log_probs)

輸出是:(就是log_softmax後的值)

log_probs Variable containing:

-0.6160 -0.7768

[torch.FloatTensor of size 1x2]

我們這裡看看在test上的預測

label_to_ix = { "SPANISH": 0, "ENGLISH": 1 }

for instance, label in test_data:

bow_vec = autograd.Variable(make_bow_vector(instance, word_to_ix))

log_probs = model(bow_vec)

print log_probs

print next(model.parameters())[:,word_to_ix["creo"]]

結果是

Variable containing:

-0.5431 -0.8698

[torch.FloatTensor of size 1x2]

Variable containing:

-0.7405 -0.6480

[torch.FloatTensor of size 1x2]

Variable containing:

-0.0467

0.1065

[torch.FloatTensor of size 2]

下面就該進行重要的部分了。

循環訓練和更新參數

這裡我們用的損失函數是nn.NLLLoss()負對數似然損失,優化依然用的最常見的optim.SGD() 梯度下降法,一般訓練5-30次最終優化基本不再變化。

每一步過程:

a.首先都要model.zero_grad(),因為接下來要極端梯度,得清零,以防問題

b.將數據向量化(也可以說是數字序列化,轉成計算機能看懂的形式)

c.得到預測值

d.求損失loss_function

e.求梯度loss.backward()

f.更新參數optimizer.step()

loss_function = nn.NLLLoss()

optimizer = optim.SGD(model.parameters(), lr=0.1)

for epoch in range(100):

for instance, label in data:

model.zero_grad()

bow_vec = autograd.Variable(make_bow_vector(instance, word_to_ix))

target = autograd.Variable(make_target(label, label_to_ix))

log_probs = model(bow_vec)

loss = loss_function(log_probs, target)

loss.backward()

optimizer.step()

在測試集上測試

for instance, label in test_data:

bow_vec = autograd.Variable(make_bow_vector(instance, word_to_ix))

log_probs = model(bow_vec)

print log_probs

我們在結果上很容易看到第一個例子預測是SPANISH最大,第二個是ENGLISH最大。成功了。

Variable containing:

-0.0842 -2.5161

[torch.FloatTensor of size 1x2]

Variable containing:

-2.4886 -0.0867

[torch.FloatTensor of size 1x2]

開發者專場 | 英偉達深度學習學院現場授課

學習形式:線下授課 + 交流答疑

時間:7 月 8 日

地點:深圳市福田區福華路大中華喜來登酒店

培訓價格:1999 元,前五十名報名者提供五折早鳥票,先到先得!

點擊展開全文

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 唯物 的精彩文章:

Yann LeCun 最新研究成果:可以幫助 GAN 使用離散數據的 ARAE
一文教你如何用神經網路識別驗證碼!
谷歌開源物體檢測系統 API
如何實現模擬人類視覺注意力的循環神經網路?
一文詳解如何用 R 語言繪製熱圖

TAG:唯物 |

您可能感興趣

如何辨別iPhone手機的真偽?
怎麼辨別iPhone手機的真偽?
Natures蘭葹教你辨別流言,告別護膚套路!
辨別iPhone XS手機真假的幾種方法
Longines浪琴錶如何辨別真偽
如何使用netstat命令辨別DDOS入侵
lars larsen手錶真假辨別 四個辨別要點須知
Cartier卡地亞手錶藍氣球如何辨別真偽
lachinata希那塔:教你如何辨別真假橄欖油
你的YEEZY買假了?球鞋鑒定大師The Shoe Surgeon教你辨別假鞋!
卡地亞手錶paris系列,請高手辨別真假
Goyard購物袋真假辨別:要說lv nf是購物袋鼻祖,Goyard是一萬個大寫滴不服!
教你如何辨別AirJordan 1 「Top 3 」鴛鴦真假
Google Photos 現在可以辨別貓、狗的品種了
iPhone XS如何辨別真偽?四個真假辨別技巧教給你!
用戶可在Chrome瀏覽器上辨別某項目是否為騙局
白色版Virgil Abloh x Nike VaporMax發售之際,雙手奉上Fake Or Real辨別指南!
iPhone真假辨別——數據線
如何辨別自己購買的 iPhone 是不是翻新機?
PPmoney開放日八字箴言 教你如何辨別真假P2P