當前位置:
首頁 > 知識 > 基於線性SVM的CIFAR-10圖像集分類

基於線性SVM的CIFAR-10圖像集分類

1. SVM的基本思想

簡單來說,支持向量機SVM就是在特徵空間中找到一條最佳的分類超平面,能夠讓正、負樣本距離該超平面的間隔(margin)最大化。

以二維平面為例,確定一條直線對正負樣本進行分類,如下圖所示:

基於線性SVM的CIFAR-10圖像集分類

很明顯,雖然分類線H1、H2、H3都能夠將正負樣本完全分開,但是毫無疑問H3更好一些。原因是正負樣本距離H3都足夠遠,即間隔「margin」最大。這就是SVM的基本思想:盡量讓所有樣本距離分類超平面越遠越好。

2. 線性分類與得分函數

在線性分類器演算法中,輸入為x,輸出為y,令權重係數為W,常數項係數為b。我們定義得分函數s為:

s=Wx+b

s=Wx+b

這是線性分類器的一般形式,得分函數s所屬類別值越大,表示預測該類別的概率越大。

以圖像識別為例,共有3個類別「cat,dog,ship」。令輸入x的特徵維度為4「即包含4個像素值」,W的維度是3x4,b的維度是3x1。在W和b確定後,得到各個類別的得分函數s為:

基於線性SVM的CIFAR-10圖像集分類

由上圖可知,因為總有3個類別,得分函數s是3x1的向量。其中,cat score=-96.8,dog score=437.9,ship score=61.95。從s的值來說,dog score最高,cat score最低,則預測為狗的概率更大一些。而該圖片真實標籤是一隻貓,顯然,從得分函數s上來看,該線性分類器的預測結果是錯誤的。

通常為了簡化計算,我們直接將W和b整合成一個矩陣,同時將x額外增加一個全為1的維度。這樣,得分函數s的表達式得到了簡化:

W:=[W b]

W:=[W b]

x:=[x; 1]

x:=[x; 1]

s=Wx

s=Wx

示例圖如下:

基於線性SVM的CIFAR-10圖像集分類

3. 優化策略與損失函數

通常來說,SVM的優化策略是樣本到分類超平面的距離最大化。也就是說盡量讓正負樣本距離分類超平面有足夠寬的間隔,這是基於距離的衡量優化方式。針對上文提到的例子,圖片真實標籤是一隻貓,但是得到的s值卻是最低的,顯然這不是我們希望看到的。最好的情況應該是cat score最高。這樣才能保證預測cat的概率更大。此時,利用SVM的間隔最大化的思想,就要求cat score不僅僅要大於其它類別的s值,而且要達到一定的程度,可以說有個最低閾值。

因此,這種新的SVM優化策略可以這樣理解:正確類別對應的得分函數s應該比其它類別的得分函數s大一個閾值 Δ

Δ:

s

y

i

≥s

j

syi≥sj+Δ

接下來,我們就可以根據這種思想定義SVM的損失函數:

L

i

=∑

j≠y

i

max(0,s

j

?s

y

i

+Δ)

Li=∑j≠yimax(0,sj?syi+Δ)

其中,y

i

yi表示正確的類別,j表示錯誤類別。從L

i

Li的表達式可以看出,只有當s

y

i

syi比s

j

sj大超過閾值 Δ

Δ 時,L

i

Li才為零,否則L

i

Li大於零。這種策略類似於距離最大化策略。

舉個例子來解釋L

i

Li的計算過程:例如得分函數s=[-1, 5, 4],y

1

y1是真實樣本,令Δ=3

Δ=3,則:

L

i

=max(0,?1?5+3)+max(0,4?5+3)=0+2=2

Li=max(0,?1?5+3)+max(0,4?5+3)=0+2=2

該損失函數由兩部分組成:y

1

y1與y

0

y0,y

1

y1與y

2

y2。由於y

1

y1與y

0

y0的差值大於閾值 Δ

Δ,則其損失函數為0;雖然y

1

y1比y

2

y2大,但差值小於閾值 Δ

Δ,則計算得到其損失函數為2。總的損失函數即為2。

這類損失函數的表達式一般稱作合頁損失函數「Hinge Loss Function」:

基於線性SVM的CIFAR-10圖像集分類

顯然,只有當s

j

?s

y

i

+Δ<0

sj?syi+Δ<0 時,損失函數才為零。

這種合頁損失函數的優點是體現了SVM距離最大化的思想;而且,損失函數大於零時,是線性函數,便於梯度下降演算法求導。

除了這種線性hinge loss SVM之外,還有squared hinge loss SVM,即採用平方的形式:

L

i

=∑

j≠y

i

max(0,s

j

?s

y

i

+Δ)

2

Li=∑j≠yimax(0,sj?syi+Δ)2

這種squared hinge loss SVM與linear hinge loss SVM相比較,特點是對違背間隔閾值要求的點加重懲罰,違背的越大,懲罰越大。某些實際應用中,squared hinge loss SVM的效果更好一些。具體使用哪個,可以根據實際問題,進行交叉驗證再確定。

對於超參數閾值 Δ

Δ,一般設置 Δ=1

Δ=1。因為,權重係數W是可伸縮的,直接影響著得分函數s的大小。所以說,Δ=1

Δ=1或 Δ=10

Δ=10,實際上沒有差別,對W的伸縮完全可以抵消掉 Δ

Δ 的數值影響。因此,通常把 Δ

Δ 設置為1即可。此時的損失函數為:

L

i

=∑

j≠y

i

max(0,s

j

?s

y

i

+1)

Li=∑j≠yimax(0,sj?syi+1)

SVM中,為了防止模型過擬合,可以使用正則化「Regularization」方法。例如使用L2正則化:

R(W)=∑

k

l

w

2

k,l

R(W)=∑k∑lwk,l2

引入正則化項之後的損失函數為:

L=1

N

L

i

+λR(W)

L=1NLi+λR(W)

其中,N是訓練樣本個數,λ

λ 是正則化參數,可調。一般來說,λ

λ 越大,對權重W的懲罰越大;λ

λ 越小,對權重W的懲罰越小。λ

λ 實際上是權衡損失函數第一項和第二項之間的關係:λ

λ 越大,對W的懲罰更大,犧牲正負樣本之間的間隔,可能造成欠擬合「underfit」;λ

λ 越小,得到的正負樣本間隔更大,但是W數值會變大,可能造成過擬合「overfit」。實際應用中,可通過交叉驗證,選擇合適的正則化參數λ

λ。

常數項b是否需要正則化?其實一般b是否正則化對模型的影響很小。可以對b進行正則化,也可以選擇不。實際應用中,通常只對權重係數W進行正則化。

4. 線性SVM實戰

首先,簡單介紹一下我們將要用到的經典數據集:CIFAR-10。

CIFAR-10數據集由60000張3×32×32的 RGB 彩色圖片構成,共10個分類。50000張訓練,10000張測試(交叉驗證)。這個數據集最大的特點在於將識別遷移到了普適物體,而且應用於多分類,是非常經典和常用的數據集。

基於線性SVM的CIFAR-10圖像集分類

這個數據集網上可以下載,我直接給大家下好了,放在雲盤裡,需要的自行領取。

鏈接:https://pan.baidu.com/s/1iZPwt72j-EpVUbLKgEpYMQ

密碼:vy1e

下面的代碼是隨機選擇每種類別下的5張圖片並顯示:

# Visualize some examples from the dataset.
# We show a few examples of training images from each class.
classes = ["plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
num_classes = len(classes)
samples_per_class = 7
for y, cls in enumerate(classes):
idxs = np.flatnonzero(y_train == y)
idxs = np.random.choice(idxs, samples_per_class, replace=False)
for i, idx in enumerate(idxs):
plt_idx = i * num_classes + y + 1
plt.subplot(samples_per_class, num_classes, plt_idx)
plt.imshow(X_train[idx].astype("uint8"))
plt.axis("off")
if i == 0:
plt.title(cls)
plt.show()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16

基於線性SVM的CIFAR-10圖像集分類

接下來,就是對SVM計算hinge loss,包含L2正則化,代碼如下:

scores = X.dot(W)
correct_class_score = scores[range(num_train), list(y)].reshape(-1,1) # (N,1)
margin = np.maximum(0, scores - correct_class_score + 1)
margin[range(num_train), list(y)] = 0
loss = np.sum(margin) / num_train + 0.5 * reg * np.sum(W * W)
1
2
3
4
5

計算W梯度的代碼如下:

num_classes = W.shape[1]
inter_mat = np.zeros((num_train, num_classes))
inter_mat[margin > 0] = 1
inter_mat[range(num_train), list(y)] = 0
inter_mat[range(num_train), list(y)] = -np.sum(inter_mat, axis=1)
dW = (X.T).dot(inter_mat)
dW = dW/num_train + reg*W
1
2
3
4
5
6
7
8

根據SGD演算法,每次迭代後更新W:

W -= learning_rate * dW
1

訓練過程中,使用交叉驗證的方法選擇最佳的學習因子 learning_rate 和正則化參數 reg,代碼如下:

learning_rates = [1.4e-7, 1.5e-7, 1.6e-7]
regularization_strengths = [8000.0, 9000.0, 10000.0, 11000.0, 18000.0, 19000.0, 20000.0, 21000.0]
results = {}
best_lr = None
best_reg = None
best_val = -1 # The highest validation accuracy that we have seen so far.
best_svm = None # The LinearSVM object that achieved the highest validation rate.
for lr in learning_rates:
for reg in regularization_strengths:
svm = LinearSVM()
loss_history = svm.train(X_train, y_train, learning_rate = lr, reg = reg, num_iters = 2000)
y_train_pred = svm.predict(X_train)
accuracy_train = np.mean(y_train_pred == y_train)
y_val_pred = svm.predict(X_val)
accuracy_val = np.mean(y_val_pred == y_val)
if accuracy_val > best_val:
best_lr = lr
best_reg = reg
best_val = accuracy_val
best_svm = svm
results[(lr, reg)] = accuracy_train, accuracy_val
print("lr: %e reg: %e train accuracy: %f val accuracy: %f" %
(lr, reg, results[(lr, reg)][0], results[(lr, reg)][1]))
print("Best validation accuracy during cross-validation:
lr = %e, reg = %e, best_val = %f" %
(best_lr, best_reg, best_val))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27

訓練結束後,選擇最佳的學習因子 learning_rate 和正則化參數 reg,在測試圖片集上進行驗證,代碼如下:

# Evaluate the best svm on test set
y_test_pred = best_svm.predict(X_test)
test_accuracy = np.mean(y_test == y_test_pred)
print("linear SVM on raw pixels final test set accuracy: %f" % test_accuracy)
1
2
3
4
5

linear SVM on raw pixels final test set accuracy: 0.384000

最後,有個比較好玩的操作,我們可以將訓練好的權重W可視化:

# Visualize the learned weights for each class.
# Depending on your choice of learning rate and regularization strength, these may
# or may not be nice to look at.
w = best_svm.W[:-1,:] # strip out the bias
w = w.reshape(32, 32, 3, 10)
w_min, w_max = np.min(w), np.max(w)
classes = ["plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
for i in range(10):
plt.subplot(2, 5, i + 1)
# Rescale the weights to be between 0 and 255
wimg = 255.0 * (w[:, :, :, i].squeeze() - w_min) / (w_max - w_min)
plt.imshow(wimg.astype("uint8"))
plt.axis("off")
plt.title(classes[i])
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

基於線性SVM的CIFAR-10圖像集分類

可以明顯看出,由W重構的圖片具有所屬樣本類別相似的地方,這正是線性SVM學習到的東西。

5. 總結

本文講述的線性SVM利用距離間隔最大的思想,利用hinge loss的優化策略,來構建一個機器學習模型,並將這個簡單模型應用到CIFAR-10圖片集中進行訓練和測試。實際測試的準確率在40%左右。準確率雖然不是很高,但是此SVM是線性模型,沒有引入核函數構建非線性模型,也沒有使用AlexNet,VGG,GoogLeNet,ResNet等卷積網路。測試結果比隨機猜測10%要好很多,是一個不錯的可實操的有趣模型。

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 程序員小新人學習 的精彩文章:

大眾點評點餐小程序開發經驗之發布與推廣
Swoole實現基於WebSocket的群聊私聊

TAG:程序員小新人學習 |