基於線性SVM的CIFAR-10圖像集分類
1. SVM的基本思想
簡單來說,支持向量機SVM就是在特徵空間中找到一條最佳的分類超平面,能夠讓正、負樣本距離該超平面的間隔(margin)最大化。
以二維平面為例,確定一條直線對正負樣本進行分類,如下圖所示:
很明顯,雖然分類線H1、H2、H3都能夠將正負樣本完全分開,但是毫無疑問H3更好一些。原因是正負樣本距離H3都足夠遠,即間隔「margin」最大。這就是SVM的基本思想:盡量讓所有樣本距離分類超平面越遠越好。
2. 線性分類與得分函數
在線性分類器演算法中,輸入為x,輸出為y,令權重係數為W,常數項係數為b。我們定義得分函數s為:
s=Wx+b
s=Wx+b
這是線性分類器的一般形式,得分函數s所屬類別值越大,表示預測該類別的概率越大。
以圖像識別為例,共有3個類別「cat,dog,ship」。令輸入x的特徵維度為4「即包含4個像素值」,W的維度是3x4,b的維度是3x1。在W和b確定後,得到各個類別的得分函數s為:
由上圖可知,因為總有3個類別,得分函數s是3x1的向量。其中,cat score=-96.8,dog score=437.9,ship score=61.95。從s的值來說,dog score最高,cat score最低,則預測為狗的概率更大一些。而該圖片真實標籤是一隻貓,顯然,從得分函數s上來看,該線性分類器的預測結果是錯誤的。
通常為了簡化計算,我們直接將W和b整合成一個矩陣,同時將x額外增加一個全為1的維度。這樣,得分函數s的表達式得到了簡化:
W:=[W b]
W:=[W b]
x:=[x; 1]
x:=[x; 1]
s=Wx
s=Wx
示例圖如下:
3. 優化策略與損失函數
通常來說,SVM的優化策略是樣本到分類超平面的距離最大化。也就是說盡量讓正負樣本距離分類超平面有足夠寬的間隔,這是基於距離的衡量優化方式。針對上文提到的例子,圖片真實標籤是一隻貓,但是得到的s值卻是最低的,顯然這不是我們希望看到的。最好的情況應該是cat score最高。這樣才能保證預測cat的概率更大。此時,利用SVM的間隔最大化的思想,就要求cat score不僅僅要大於其它類別的s值,而且要達到一定的程度,可以說有個最低閾值。
因此,這種新的SVM優化策略可以這樣理解:正確類別對應的得分函數s應該比其它類別的得分函數s大一個閾值 Δ
Δ:
s
y
i
≥s
j
+Δ
syi≥sj+Δ
接下來,我們就可以根據這種思想定義SVM的損失函數:
L
i
=∑
j≠y
i
max(0,s
j
?s
y
i
+Δ)
Li=∑j≠yimax(0,sj?syi+Δ)
其中,y
i
yi表示正確的類別,j表示錯誤類別。從L
i
Li的表達式可以看出,只有當s
y
i
syi比s
j
sj大超過閾值 Δ
Δ 時,L
i
Li才為零,否則L
i
Li大於零。這種策略類似於距離最大化策略。
舉個例子來解釋L
i
Li的計算過程:例如得分函數s=[-1, 5, 4],y
1
y1是真實樣本,令Δ=3
Δ=3,則:
L
i
=max(0,?1?5+3)+max(0,4?5+3)=0+2=2
Li=max(0,?1?5+3)+max(0,4?5+3)=0+2=2
該損失函數由兩部分組成:y
1
y1與y
0
y0,y
1
y1與y
2
y2。由於y
1
y1與y
0
y0的差值大於閾值 Δ
Δ,則其損失函數為0;雖然y
1
y1比y
2
y2大,但差值小於閾值 Δ
Δ,則計算得到其損失函數為2。總的損失函數即為2。
這類損失函數的表達式一般稱作合頁損失函數「Hinge Loss Function」:
顯然,只有當s
j
?s
y
i
+Δ<0
sj?syi+Δ<0 時,損失函數才為零。
這種合頁損失函數的優點是體現了SVM距離最大化的思想;而且,損失函數大於零時,是線性函數,便於梯度下降演算法求導。
除了這種線性hinge loss SVM之外,還有squared hinge loss SVM,即採用平方的形式:
L
i
=∑
j≠y
i
max(0,s
j
?s
y
i
+Δ)
2
Li=∑j≠yimax(0,sj?syi+Δ)2
這種squared hinge loss SVM與linear hinge loss SVM相比較,特點是對違背間隔閾值要求的點加重懲罰,違背的越大,懲罰越大。某些實際應用中,squared hinge loss SVM的效果更好一些。具體使用哪個,可以根據實際問題,進行交叉驗證再確定。
對於超參數閾值 Δ
Δ,一般設置 Δ=1
Δ=1。因為,權重係數W是可伸縮的,直接影響著得分函數s的大小。所以說,Δ=1
Δ=1或 Δ=10
Δ=10,實際上沒有差別,對W的伸縮完全可以抵消掉 Δ
Δ 的數值影響。因此,通常把 Δ
Δ 設置為1即可。此時的損失函數為:
L
i
=∑
j≠y
i
max(0,s
j
?s
y
i
+1)
Li=∑j≠yimax(0,sj?syi+1)
SVM中,為了防止模型過擬合,可以使用正則化「Regularization」方法。例如使用L2正則化:
R(W)=∑
k
∑
l
w
2
k,l
R(W)=∑k∑lwk,l2
引入正則化項之後的損失函數為:
L=1
N
L
i
+λR(W)
L=1NLi+λR(W)
其中,N是訓練樣本個數,λ
λ 是正則化參數,可調。一般來說,λ
λ 越大,對權重W的懲罰越大;λ
λ 越小,對權重W的懲罰越小。λ
λ 實際上是權衡損失函數第一項和第二項之間的關係:λ
λ 越大,對W的懲罰更大,犧牲正負樣本之間的間隔,可能造成欠擬合「underfit」;λ
λ 越小,得到的正負樣本間隔更大,但是W數值會變大,可能造成過擬合「overfit」。實際應用中,可通過交叉驗證,選擇合適的正則化參數λ
λ。
常數項b是否需要正則化?其實一般b是否正則化對模型的影響很小。可以對b進行正則化,也可以選擇不。實際應用中,通常只對權重係數W進行正則化。
4. 線性SVM實戰
首先,簡單介紹一下我們將要用到的經典數據集:CIFAR-10。
CIFAR-10數據集由60000張3×32×32的 RGB 彩色圖片構成,共10個分類。50000張訓練,10000張測試(交叉驗證)。這個數據集最大的特點在於將識別遷移到了普適物體,而且應用於多分類,是非常經典和常用的數據集。
這個數據集網上可以下載,我直接給大家下好了,放在雲盤裡,需要的自行領取。
鏈接:https://pan.baidu.com/s/1iZPwt72j-EpVUbLKgEpYMQ
密碼:vy1e
下面的代碼是隨機選擇每種類別下的5張圖片並顯示:
# Visualize some examples from the dataset.
# We show a few examples of training images from each class.
classes = ["plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
num_classes = len(classes)
samples_per_class = 7
for y, cls in enumerate(classes):
idxs = np.flatnonzero(y_train == y)
idxs = np.random.choice(idxs, samples_per_class, replace=False)
for i, idx in enumerate(idxs):
plt_idx = i * num_classes + y + 1
plt.subplot(samples_per_class, num_classes, plt_idx)
plt.imshow(X_train[idx].astype("uint8"))
plt.axis("off")
if i == 0:
plt.title(cls)
plt.show()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
接下來,就是對SVM計算hinge loss,包含L2正則化,代碼如下:
scores = X.dot(W)
correct_class_score = scores[range(num_train), list(y)].reshape(-1,1) # (N,1)
margin = np.maximum(0, scores - correct_class_score + 1)
margin[range(num_train), list(y)] = 0
loss = np.sum(margin) / num_train + 0.5 * reg * np.sum(W * W)
1
2
3
4
5
計算W梯度的代碼如下:
num_classes = W.shape[1]
inter_mat = np.zeros((num_train, num_classes))
inter_mat[margin > 0] = 1
inter_mat[range(num_train), list(y)] = 0
inter_mat[range(num_train), list(y)] = -np.sum(inter_mat, axis=1)
dW = (X.T).dot(inter_mat)
dW = dW/num_train + reg*W
1
2
3
4
5
6
7
8
根據SGD演算法,每次迭代後更新W:
W -= learning_rate * dW
1
訓練過程中,使用交叉驗證的方法選擇最佳的學習因子 learning_rate 和正則化參數 reg,代碼如下:
learning_rates = [1.4e-7, 1.5e-7, 1.6e-7]
regularization_strengths = [8000.0, 9000.0, 10000.0, 11000.0, 18000.0, 19000.0, 20000.0, 21000.0]
results = {}
best_lr = None
best_reg = None
best_val = -1 # The highest validation accuracy that we have seen so far.
best_svm = None # The LinearSVM object that achieved the highest validation rate.
for lr in learning_rates:
for reg in regularization_strengths:
svm = LinearSVM()
loss_history = svm.train(X_train, y_train, learning_rate = lr, reg = reg, num_iters = 2000)
y_train_pred = svm.predict(X_train)
accuracy_train = np.mean(y_train_pred == y_train)
y_val_pred = svm.predict(X_val)
accuracy_val = np.mean(y_val_pred == y_val)
if accuracy_val > best_val:
best_lr = lr
best_reg = reg
best_val = accuracy_val
best_svm = svm
results[(lr, reg)] = accuracy_train, accuracy_val
print("lr: %e reg: %e train accuracy: %f val accuracy: %f" %
(lr, reg, results[(lr, reg)][0], results[(lr, reg)][1]))
print("Best validation accuracy during cross-validation:
lr = %e, reg = %e, best_val = %f" %
(best_lr, best_reg, best_val))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
訓練結束後,選擇最佳的學習因子 learning_rate 和正則化參數 reg,在測試圖片集上進行驗證,代碼如下:
# Evaluate the best svm on test set
y_test_pred = best_svm.predict(X_test)
test_accuracy = np.mean(y_test == y_test_pred)
print("linear SVM on raw pixels final test set accuracy: %f" % test_accuracy)
1
2
3
4
5
linear SVM on raw pixels final test set accuracy: 0.384000
最後,有個比較好玩的操作,我們可以將訓練好的權重W可視化:
# Visualize the learned weights for each class.
# Depending on your choice of learning rate and regularization strength, these may
# or may not be nice to look at.
w = best_svm.W[:-1,:] # strip out the bias
w = w.reshape(32, 32, 3, 10)
w_min, w_max = np.min(w), np.max(w)
classes = ["plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
for i in range(10):
plt.subplot(2, 5, i + 1)
# Rescale the weights to be between 0 and 255
wimg = 255.0 * (w[:, :, :, i].squeeze() - w_min) / (w_max - w_min)
plt.imshow(wimg.astype("uint8"))
plt.axis("off")
plt.title(classes[i])
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
可以明顯看出,由W重構的圖片具有所屬樣本類別相似的地方,這正是線性SVM學習到的東西。
5. 總結
本文講述的線性SVM利用距離間隔最大的思想,利用hinge loss的優化策略,來構建一個機器學習模型,並將這個簡單模型應用到CIFAR-10圖片集中進行訓練和測試。實際測試的準確率在40%左右。準確率雖然不是很高,但是此SVM是線性模型,沒有引入核函數構建非線性模型,也沒有使用AlexNet,VGG,GoogLeNet,ResNet等卷積網路。測試結果比隨機猜測10%要好很多,是一個不錯的可實操的有趣模型。
※大眾點評點餐小程序開發經驗之發布與推廣
※Swoole實現基於WebSocket的群聊私聊
TAG:程序員小新人學習 |