當前位置:
首頁 > 知識 > 詳解Siamese網路

詳解Siamese網路

摘要

Siamese網路用途,原理,如何訓練?

背景

在人臉識別中,存在所謂的one-shot問題。舉例來說,就是對公司員工進行人臉識別,每個員工只給你一張照片(訓練集樣本少),並且員工會離職、入職(每次變動都要重新訓練模型)。有這樣的問題存在,就沒辦法直接訓練模型來解決這樣的分類問題了。

為了解決one-shot問題,我們會訓練一個模型來輸出給定兩張圖像的相似度,所以模型學習得到的是similarity函數。

哪些模型能通過學習得到similarity函數呢?Siamese網路就是這樣的一種模型。

Siamese網路原理

Siamese網路要給出輸入圖像X1和X2的相似度,所以它必須能接受兩個圖像作為輸入,如下圖:

詳解Siamese網路

打開今日頭條,查看更多精彩圖片

圖中上下兩個模型,都由CNN構成,兩個模型的參數值完全相同。不同於傳統CNN的地方,是Siamese網路並不直接輸出類別,而是輸出一個向量(比如上圖中是128個數值組成的一維向量):

若輸入的圖像X1和X2為同一個人,則上下兩個模型輸出的一維向量歐氏距離較小

若輸入的圖像X1和X2不是同一個人,則上下兩個模型輸出的一維向量歐氏距離較大

所以通過對上下兩個模型輸出的向量做歐氏距離計算,就能得到輸入兩幅圖像的相似度。

詳解Siamese網路

又因為上下兩個模型具有相同的參數,所以訓練模型時,只需要訓練一個模型即可。那問題來了,這樣的模型該怎麼訓練呢?模型的輸出label該標註為什麼呢?

如何訓練Siamese網路

模型的訓練,就是給定cost function後,用梯度下降法尋找最優值的過程。

訓練Siamese網路,需要引入新的cost function。我們先看模型的學習目標(下圖),再一步一步講解cost function的最終表達式。

詳解Siamese網路

對圖中的一幅照片A,如果給定了同一個人的另一幅照片P,則模型的輸出向量f(A)和f§應該是距離比較小的。如果給定了另一個人的照片N,則模型的輸出向量f(A)和f(N)之間的距離就比較小。所以d(A,P)<d(A,N)。

根據這個目標,就得到了cost function的定義:

詳解Siamese網路

其目的,是遍歷所有三元組(A,P,N),求其L的最小。公式中的參數α,是一個超參數,用於做margin,能避免模型輸出的都是零向量。

有了這個cost function,用梯度下降法就能找到模型的最優值。這個過程是不需要我們對模型的向量值進行人工標註的。

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 程序員小新人學習 的精彩文章:

Linux運行級別和找回root密碼
頁面跳轉的兩種方式(轉發和重定向)區別及應用場景分析

TAG:程序員小新人學習 |