機器學習實戰札記
《機器學習實戰》一書介紹的第一個演算法是k-近鄰演算法。簡單的說,k-近鄰演算法採用測量不同特徵值之間的距離方法進行分類。其工作機制非常簡單:給定測試樣本,基於某種距離度量找出訓練集中與其最靠近的k個訓練樣本,然後基於這k個「鄰居」的信息來進行預測。
《機器學習實戰》一書給出的示例都是分類演算法,其實該演算法也適用於回歸任務。在分類任務中可使用「投票法」,即選擇這個k個樣本中出現最多的類別標記作為預測結果; 在回歸任務中可使用「平均法」,即將這k個樣本的實值輸出標記的平均值作為預測結果。
k-近鄰演算法實現上也比較簡單,以分類任務為例,首先是準備訓練樣本,訓練樣本都存在標籤,也就是我們知道樣本集中每一數據與所屬分類的對應關係。輸入沒有標籤的新數據後,將新數據的每個特徵與訓練樣本對應的特徵進行比較,然後演算法提取樣本集中特徵最相似數據(最近鄰)的分類標籤。一般來說,選擇k個最相似的數據,這就是k-近鄰演算法中k的出處。
從前面的分析可以看出,k-近鄰演算法沒有顯式的訓練過程,在訓練階段僅僅是把樣本保存起來,訓練時間開銷為零,待收到測試樣本後再進行處理。這個演算法存在兩個關鍵點:
k值如何選擇。在《機器學習實戰》和西瓜書上都沒有給出,只說了k通常是不大於20的整數。從實際測試來看,k取值不同,分類結果會有所不同,這個需要根據經驗來選擇。
距離計算方式。在《機器學習實戰》中採取的是歐式距離公式:
為了避免某個屬性的取值範圍過大,從而對整個距離的計算影響太大,可以採用數值歸一化,將取值範圍處理為0到1或-1到1之間,最簡單的公式就是:
從上述演算法上可以看出,該演算法的缺點:計算複雜度高(需要計算與每個訓練樣本之間的距離,通常訓練樣本比較大)、空間複雜度高。
當然這個演算法也有許多優點:精度高、對異常值不敏感、無數據輸入假定。
書中給出了一個使用k-近鄰演算法識別手寫數字的完整例子,其錯誤率為1.2%。這已經是很高的精度了。而且西瓜書還給出了一個簡化的證明,它的泛化錯誤率不超過貝葉斯最優分類器的錯誤率的兩倍!
如果使用手寫數字的訓練樣本預測印刷體數字樣本,錯誤率會達到多少?
56.9%
這也印證了機器學習中的NFL(沒有免費的午餐)定理。我們應該清楚的認識到,脫離具體問題,空泛地談論「什麼學習演算法更好」毫無意義。
參考:
《機器學習實戰》p15 - 31
《機器學習》p225 - p226
![](https://pic.pimg.tw/zzuyanan/1488615166-1259157397.png)
![](https://pic.pimg.tw/zzuyanan/1482887990-2595557020.jpg)
※機器學習老中醫:利用學習曲線診斷模型的偏差和方差
※數據科學、機器學習和人工智慧到底有什麼區別?
TAG:機器學習 |