當前位置:
首頁 > 新聞 > 機器學習之確定最佳聚類數目的10種方法

機器學習之確定最佳聚類數目的10種方法

雷鋒網 AI科技評論按,本文作者貝爾塔,原文載於知乎專欄數據分析與可視化,雷鋒網 AI科技評論獲其授權發布。

在聚類分析的時候確定最佳聚類數目是一個很重要的問題,比如kmeans函數就要你提供聚類數目這個參數,總不能兩眼一抹黑亂填一個吧。之前也被這個問題困擾過,看了很多博客,大多泛泛帶過。今天把看到的這麼多方法進行匯總以及代碼實現並盡量弄清每個方法的原理。

數據集選用比較出名的wine數據集進行分析

library(gclus)

data(wine)

head(wine)

Loading required package: cluster

因為我們要找一個數據集進行聚類分析,所以不需要第一列的種類標籤信息,因此去掉第一列。

同時注意到每一列的值差別很大,從1到100多都有,這樣會造成誤差,所以需要歸一化,用scale函數

dataset

dataset

去掉標籤之後就可以開始對數據集進行聚類分析了,下面就一一介紹各種確定最佳聚類數目的方法

判定方法

1.mclust包

mclust包是聚類分析非常強大的一個包,也是上課時老師給我們介紹的一個包,每次導入時有一種科技感 :) 幫助文檔非常詳盡,可以進行聚類、分類、密度分析

Mclust包方法有點「暴力」,聚類數目自定義,比如我選取的從1到20,然後一共14種模型,每一種模型都計算聚類數目從1到20的BIC值,最終確定最佳聚類數目,這種方法的思想很直接了當,但是弊端也就顯然易見了——時間複雜度太高,效率低。

library(mclust)

m_clust

summary(m_clust)

Gaussian finite mixture model fitted by EM algorithm

Mclust EVE (ellipsoidal, equal volume and orientation) model with 3 components:

log.likelihood   n  df       BIC       ICL

    -3032.45 178 156 -6873.257 -6873.549

Clustering table:

1  2  3

63 51 64

可見該函數已經把數據集聚類為3種類型了。數目分別為63、51、64。再畫出14個指標隨著聚類數目變化的走勢圖

plot(m_clust, "BIC")

<img src="https://static.leiphone.com/uploads/new/article/pic/201710/9c2e7ed5c2a689fc2caf452112362db6.png" data-rawwidth="531" data-rawheight="527" class="origin_image zh-lightbox-thumb" width="531" data-original="https://pic4.zhimg.com/v2-5357b2dfa1d132078fb52f7ec8ea8faf_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201710/9c2e7ed5c2a689fc2caf452112362db6.png"/>

下表是這些模型的意義

<img src="https://static.leiphone.com/uploads/new/article/pic/201710/92cc4b3be70b912cc2f43a3e8d204f9d.png" data-rawwidth="715" data-rawheight="532" class="origin_image zh-lightbox-thumb" width="715" data-original="https://pic2.zhimg.com/v2-5e9833600f6cff8e5233c12e61dc15b9_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201710/92cc4b3be70b912cc2f43a3e8d204f9d.png"/>

它們應該分別代表著相關性(完全正負相關——對角線、稍強正負相關——橢圓、無關——圓)等參數的改變對應的模型,研究清楚這些又是非常複雜的問題了,先按下表,知道BIC值越大則說明所選取的變數集合擬合效果越好。上圖中除了兩個模型一直遞增,其他的12模型數基本上都是在聚類數目為3的時候達到峰值,所以該演算法由此得出最佳聚類數目為3的結論。

mclust包還可以用於分類、密度估計等,這個包值得好好把玩。

注意:此BIC並不是貝葉斯信息準則!!!

最近上課老師講金融模型時提到了BIC值,說BIC值越小模型效果越好,頓時想起這裡是在圖中BIC極大值為最佳聚類數目,然後和老師探討了這個問題,之前這裡誤導大家了,Mclust包裡面的BIC並不是貝葉斯信息準則。

1.維基上的貝葉斯信息準則定義

<img src="https://static.leiphone.com/uploads/new/article/pic/201710/3b66ee7c7a243bf9c7c3b22a23ad55af.png" data-rawwidth="292" data-rawheight="53" class="content_image" width="292" _src="https://static.leiphone.com/uploads/new/article/pic/201710/3b66ee7c7a243bf9c7c3b22a23ad55af.png"/>

與log(likelihood)成反比,極大似然估計是值越大越好,那麼BIC值確實是越小模型效果越好

2.Mclust包中的BIC定義[3]

這是Mclust包裡面作者定義的「BIC值」,此BIC非彼BIC,這裡是作者自己定義的BIC,可以看到,這裡的BIC與極大似然估計是成正比的,所以這裡是BIC值越大越好,與貝葉斯信息準則值越小模型越好的結論並不衝突

2.Nbclust包

Nbclust包是我在《R語言實戰》上看到的一個包,思想和mclust包比較相近,也是定義了幾十個評估指標,然後聚類數目從2遍歷到15(自己設定),然後通過這些指標看分別在聚類數為多少時達到最優,最後選擇指標支持數最多的聚類數目就是最佳聚類數目。

library(NbClust)

set.seed(1234) #因為method選擇的是kmeans,所以如果不設定種子,每次跑得結果可能不同

nb_clust      min.nc=2, max.nc=15, method = "kmeans",

     index = "alllong", alphaBeale = 0.1)

*** : The Hubert index is a graphical method of determining the number of clusters.

             In the plot of Hubert index, we seek a significant knee that corresponds to a

             significant increase of the value of the measure i.e the significant peak in Hubert

             index second differences plot.

*** : The D index is a graphical method of determining the number of clusters.

             In the plot of D index, we seek a significant knee (the significant peak in Dindex

             second differences plot) that corresponds to a significant increase of the value of

             the measure.

*******************************************************************

* Among all indices:                                              

* 5 proposed 2 as the best number of clusters

* 16 proposed 3 as the best number of clusters

* 1 proposed 10 as the best number of clusters

* 1 proposed 12 as the best number of clusters

* 1 proposed 14 as the best number of clusters

* 3 proposed 15 as the best number of clusters

                ***** Conclusion *****                          

* According to the majority rule, the best number of clusters is  3

*******************************************************************

barplot(table(nb_clust$Best.nc[1,]),xlab = "聚類數",ylab = "支持指標數")

<img src="https://static.leiphone.com/uploads/new/article/pic/201710/343439c0965ccbde22dbc1d2185b1a13.png" data-rawwidth="533" data-rawheight="530" class="origin_image zh-lightbox-thumb" width="533" data-original="https://pic3.zhimg.com/v2-a08e3cf366ada95c9fa43d429a00f16a_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201710/343439c0965ccbde22dbc1d2185b1a13.png"/>

可以看到有16個指標支持最佳聚類數目為3,5個指標支持聚類數為2,所以該方法推薦的最佳聚類數目為3.

3. 組內平方誤差和——拐點圖

想必之前動輒幾十個指標,這裡就用一個最簡單的指標——sum of squared error (SSE)組內平方誤差和來確定最佳聚類數目。這個方法也是出於《R語言實戰》,自定義的一個求組內誤差平方和的函數。

wssplot

 wss

 for (i in 2:nc){

     set.seed(seed)

     wss[i]

     }

 plot(1:nc, wss, type="b", xlab="Number of Clusters",

     ylab="Within groups sum of squares")}

wssplot(dataset)

<img src="https://static.leiphone.com/uploads/new/article/pic/201710/39ae40d81ccf72c41b48a28e609d661e.png" data-rawwidth="533" data-rawheight="531" class="origin_image zh-lightbox-thumb" width="533" data-original="https://pic2.zhimg.com/v2-25b396108e9b5da6094c2097888f2251_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201710/39ae40d81ccf72c41b48a28e609d661e.png"/>

隨著聚類數目增多,每一個類別中數量越來越少,距離越來越近,因此WSS值肯定是隨著聚類數目增多而減少的,所以關注的是斜率的變化,但WWS減少得很緩慢時,就認為進一步增大聚類數效果也並不能增強,存在得這個「肘點」就是最佳聚類數目,從一類到三類下降得很快,之後下降得很慢,所以最佳聚類個數選為三

另外也有現成的包(factoextra)可以調用

library(factoextra)

library(ggplot2)

set.seed(1234)

fviz_nbclust(dataset, kmeans, method = "wss") +

 geom_vline(xintercept = 3, linetype = 2)

Loading required package: ggplot2

<img src="https://static.leiphone.com/uploads/new/article/pic/201710/45de0c17eee70394a867f5889f7d3098.png" data-rawwidth="537" data-rawheight="534" class="origin_image zh-lightbox-thumb" width="537" data-original="https://pic3.zhimg.com/v2-b80ac3d187e619c5fdaf46a3d3c9361e_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201710/45de0c17eee70394a867f5889f7d3098.png"/>

選定為3類為最佳聚類數目

用該包下的fviz_cluster函數可視化一下聚類結果

km.res

fviz_cluster(km.res, data = dataset)

<img src="https://static.leiphone.com/uploads/new/article/pic/201710/4c6ed79b2deec361f41a79be29eca253.png" data-rawwidth="532" data-rawheight="530" class="origin_image zh-lightbox-thumb" width="532" data-original="https://pic2.zhimg.com/v2-9ff52d6a4c1dfeedff83fa03c895e2b9_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201710/4c6ed79b2deec361f41a79be29eca253.png"/>

4. PAM(Partitioning Around Medoids) 圍繞中心點的分割演算法

k-means演算法取得是均值,那麼對於異常點其實對其的影響非常大,很可能這種孤立的點就聚為一類,一個改進的方法就是PAM演算法,也叫k-medoids clustering

首先通過fpc包中的pamk函數得到最佳聚類數目

3

pamk函數不需要提供聚類數目,也會直接自動計算出最佳聚類數,這裡也得到為3

得到聚類數提供給cluster包下的pam函數並進行可視化

library(cluster)

clusplot(pam(dataset, pamk.best$nc))

<img src="https://static.leiphone.com/uploads/new/article/pic/201710/3399bd76503f61d284bdccca0023d899.png" data-rawwidth="530" data-rawheight="532" class="origin_image zh-lightbox-thumb" width="530" data-original="https://pic3.zhimg.com/v2-387e44d991195f07e55ffbdeb0c479be_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201710/3399bd76503f61d284bdccca0023d899.png"/>

5.Calinsky criterion

這個評估標準定義[5]如下:

其中,k是聚類數,N是樣本數,SSw是我們之前提到過的組內平方和誤差, SSb是組與組之間的平方和誤差,SSw越小,SSb越大聚類效果越好,所以Calinsky criterion值一般來說是越大,聚類效果越好

library(vegan)

ca_clust

ca_clust$results

<img src="https://static.leiphone.com/uploads/new/article/pic/201710/a12ebaef7daa5788c09e5eb3ba08c6d0.png" data-rawwidth="592" data-rawheight="142" class="origin_image zh-lightbox-thumb" width="592" data-original="https://pic4.zhimg.com/v2-aa28be2976a4f350d80506c45026df97_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201710/a12ebaef7daa5788c09e5eb3ba08c6d0.png"/>

可以看到該函數把組內平方和誤差和Calinsky都計算出來了,可以看到calinski在聚類數為3時達到最大值。

3

畫圖出來觀察一下

plot(fit, sortg = TRUE, grpmts.plot = TRUE)

<img src="https://static.leiphone.com/uploads/new/article/pic/201710/699416f926e98813797da0ae31c45ce4.png" data-rawwidth="534" data-rawheight="533" class="origin_image zh-lightbox-thumb" width="534" data-original="https://pic4.zhimg.com/v2-fc2cda58b5ba4856c681a731bc4906eb_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201710/699416f926e98813797da0ae31c45ce4.png"/>

注意到那個紅點就是對應的最大值,自帶的繪圖橫軸縱軸取的可能不符合我們的直覺,把數據取出來自己單獨畫一下

calinski

calinski$cluster

library(ggplot2)

ggplot(calinski,aes(x = calinski[,2], y = calinski[,1]))+geom_line()

Warning message:

"Removed 1 rows containing missing values (geom_path)."

<img src="https://static.leiphone.com/uploads/new/article/pic/201710/d1a5ca803f794122475c83a0f2ca700a.png" data-rawwidth="533" data-rawheight="533" class="origin_image zh-lightbox-thumb" width="533" data-original="https://pic2.zhimg.com/v2-7655fc4a181aa14c3614f060e7981575_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201710/d1a5ca803f794122475c83a0f2ca700a.png"/>

這個看上去直觀多了。這就很清晰的可以看到在聚類數目為3時,calinski指標達到了最大值,所以最佳數目為3

6.Affinity propagation (AP) clustering

這個本質上是類似kmeans或者層次聚類一樣,是一種聚類方法,因為不需要像kmeans一樣提供聚類數,會自動算出最佳聚類數,因此也放到這裡作為一種計算最佳聚類數目的方法。

AP演算法的基本思想是將全部樣本看作網路的節點,然後通過網路中各條邊的消息傳遞計算出各樣本的聚類中心。聚類過程中,共有兩種消息在各節點間傳遞,分別是吸引度( responsibility)和歸屬度(availability) 。AP演算法通過迭代過程不斷更新每一個點的吸引度和歸屬度值,直到產生m個高質量的Exemplar(類似於質心),同時將其餘的數據點分配到相應的聚類中[7]

library(apcluster)

ap_clust

length(ap_clust@clusters)

15

該聚類方法推薦的最佳聚類數目為15,再用熱力圖可視化一下

heatmap(ap_clust)

<img src="https://static.leiphone.com/uploads/new/article/pic/201710/dbc870375b6d45be952ee72e882fbbc0.png" data-rawwidth="527" data-rawheight="530" class="origin_image zh-lightbox-thumb" width="527" data-original="https://pic3.zhimg.com/v2-b982fd6c85860755226f4ae347397252_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201710/dbc870375b6d45be952ee72e882fbbc0.png"/>

選x或者y方向看(對稱),可以數出來「葉子節點」一共15個

7. 輪廓係數Average silhouette method

輪廓係數是類的密集與分散程度的評價指標。

a(i)是測量組內的相似度,b(i)是測量組間的相似度,s(i)範圍從-1到1,值越大說明組內吻合越高,組間距離越遠——也就是說,輪廓係數值越大,聚類效果越好[9]

require(cluster)

library(factoextra)

fviz_nbclust(dataset, kmeans, method = "silhouette")

<img src="https://static.leiphone.com/uploads/new/article/pic/201710/320ffea3de355eba9f40c34a92397420.png" data-rawwidth="535" data-rawheight="531" class="origin_image zh-lightbox-thumb" width="535" data-original="https://pic3.zhimg.com/v2-f24229d24e86ded6219e4636e9849292_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201710/320ffea3de355eba9f40c34a92397420.png"/>

可以看到也是在聚類數為3時輪廓係數達到了峰值,所以最佳聚類數為3

8. Gap Statistic

之前我們提到了WSSE組內平方和誤差,該種方法是通過找「肘點」來找到最佳聚類數,肘點的選擇並不是那麼清晰,因此斯坦福大學的Robert等教授提出了Gap Statistic方法,定義的Gap值為[9]

<img src="https://static.leiphone.com/uploads/new/article/pic/201710/c62dd70be24d437c31a08065d289ac33.png" data-rawwidth="300" data-rawheight="72" class="content_image" width="300" _src="https://static.leiphone.com/uploads/new/article/pic/201710/c62dd70be24d437c31a08065d289ac33.png"/>

取對數的原因是因為Wk的值可能很大

通過這個式子來找出Wk跌落最快的點,Gap最大值對應的k值就是最佳聚類數

library(cluster)

set.seed(123)

gap_clust

gap_clust

Clustering Gap statistic ["clusGap"] from call:

clusGap(x = dataset, FUNcluster = kmeans, K.max = 10, B = 500,     verbose = interactive())

B=500 simulated reference sets, k = 1..10; spaceH0="scaledPCA"

--> Number of clusters (method 'firstSEmax', SE.factor=1): 3

       logW   E.logW       gap     SE.sim

[1,] 5.377557 5.863690 0.4861333 0.01273873

[2,] 5.203502 5.758276 0.5547745 0.01420766

[3,] 5.066921 5.697322 0.6304006 0.01278909

[4,] 5.023936 5.651618 0.6276814 0.01243239

[5,] 4.993720 5.615174 0.6214536 0.01251765

[6,] 4.962933 5.584564 0.6216311 0.01165595

[7,] 4.943241 5.556310 0.6130690 0.01181831

[8,] 4.915582 5.531834 0.6162518 0.01139207

[9,] 4.881449 5.508514 0.6270646 0.01169532

[10,] 4.855837 5.487005 0.6311683 0.01198264

library(factoextra)

fviz_gap_stat(gap_clust)

<img src="https://static.leiphone.com/uploads/new/article/pic/201710/dd77400b222298012bc677bcc05cb7a8.png" data-rawwidth="532" data-rawheight="534" class="origin_image zh-lightbox-thumb" width="532" data-original="https://pic3.zhimg.com/v2-64bd9c3b553db08a095966397e88e26e_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201710/dd77400b222298012bc677bcc05cb7a8.png"/>

可以看到也是在聚類數為3的時候gap值取到了最大值,所以最佳聚類數為3

9.層次聚類

層次聚類是通過可視化然後人為去判斷大致聚為幾類,很明顯在共同父節點的一顆子樹可以被聚類為一個類

h_dist

h_clust

plot(h_clust, hang = -1, labels = FALSE)

rect.hclust(h_clust,3)

<img src="https://static.leiphone.com/uploads/new/article/pic/201710/ff45e8a530a86955edfe5cbe6478c20d.png" data-rawwidth="529" data-rawheight="532" class="origin_image zh-lightbox-thumb" width="529" data-original="https://pic1.zhimg.com/v2-3d1b5ba7f76ce0ff1240ca497a613440_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201710/ff45e8a530a86955edfe5cbe6478c20d.png"/>

10.clustergram

最後一種演算法是Tal Galili[10]大牛自己定義的一種聚類可視化的展示,繪製隨著聚類數目的增加,所有成員是如何分配到各個類別的。該代碼沒有被製作成R包,可以去Galili介紹頁面)裡面的github地址找到源代碼跑一遍然後就可以用這個函數了,因為源代碼有點長我就不放博客裡面了,直接放出運行代碼的截圖。

clustergram(dataset, k.range = 2:8, line.width = 0.004)

Loading required package: colorspace

Loading required package: plyr

<img src="https://static.leiphone.com/uploads/new/article/pic/201710/f31e6e16ca2b9f49e4af337493eb703a.png" data-rawwidth="527" data-rawheight="530" class="origin_image zh-lightbox-thumb" width="527" data-original="https://pic4.zhimg.com/v2-9797f9f7e0288fdf2d034531db315663_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201710/f31e6e16ca2b9f49e4af337493eb703a.png"/>

隨著K的增加,從最開始的兩類到最後的八類,圖肯定是越到後面越密集。通過這個圖判斷最佳聚類數目的方法應該是看隨著K每增加1,分出來的線越少說明在該k值下越穩定。比如k=7到k=8,假設k=7是很好的聚類數,那分成8類時應該可能只是某一類分成了兩類,其他6類都每怎麼變。反應到圖中應該是有6簇平行線,有一簇分成了兩股,而現在可以看到從7到8,線完全亂了,說明k=7時效果並不好。按照這個分析,k=3到k=4時,第一股和第三股幾本沒變,就第二股拆成了2類,所以k=3是最佳聚類數目

方法匯總與比較

wine數據集我們知道其實是分為3類的,以上10種判定方法中:

層次聚類和clustergram方法、肘點圖法,需要人工判定,雖然可以得出大致的最佳聚類數,但演算法本身不會給出最佳聚類數

除了Affinity propagation (AP) clustering 給出最佳聚類數為15,剩下6種全都是給出最佳聚類數為3

選用上次文本挖掘的矩陣進行分析(667*1623)

mclust效果很差,14種模型只有6種有結果

bclust報錯

SSE可以運行

fpc包中的pamk函數聚成2類,明顯不行

Calinsky criterion聚成2類

Affinity propagation (AP) clustering 聚成28類,相對靠譜

輪廓係數Average silhouette聚類2類

gap-Statistic跑不出結果

可見上述方法中有的因為數據太大不能運行,有的結果很明顯不對,一個可能是數據集的本身的原因(缺失值太多等),但是也告訴了我們在確定最佳聚類數目的時候需要多嘗試幾種方法,並沒有固定的套路,然後選擇一種可信度較高的聚類數目。

最後再把這10種方法總結一下:

<img src="https://static.leiphone.com/uploads/new/article/pic/201710/e62c8e97de526716f340c418282308ab.png" data-rawwidth="534" data-rawheight="532" class="origin_image zh-lightbox-thumb" width="534" data-original="https://pic3.zhimg.com/v2-d43a3c6a70d48edb2add13722f224ac2_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201710/e62c8e97de526716f340c418282308ab.png"/>

參考文獻

[1]R語言實戰第二版

[2]Partitioning cluster analysis: Quick start guide - Unsupervised Machine Learning

[3]BIC:http://www.stat.washington.edu/raftery/Research/PDF/fraley1998.pdf

[4]Cluster analysis in R: determine the optimal number of clusters

[5]Calinski-Harabasz Criterion:Calinski-Harabasz criterion clustering evaluation object

[6]Determining the optimal number of clusters: 3 must known methods - Unsupervised Machine Learning

[7] affinity-propagation:聚類演算法Affinity Propagation(AP)

[8]輪廓係數https://en.wikipedia.org/wiki/Silhouette(clustering))

[9]gap statistic-Tibshirani R, Walther G, Hastie T. Estimating the number of clusters in a data set via the gap statistic[J]. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 2001, 63(2): 411-423.

[10]ClustergramsClustergram: visualization and diagnostics for cluster analysis (R code)


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 雷鋒網 的精彩文章:

Uber 的自動駕駛汽車要給匹茲堡帶來一場革命,但革誰的命?
Efinix可編程晶元:可進一步推動人工智慧技術發展
谷歌發布會ARCore為什麼只是低調現身?
2023年之前推出20款純電動車型,通用汽車開啟了一場自我革命
人民日報再評演算法:演算法決定內容,價值取向跑偏;iPhone 8內地首爆,沒充過電,屏幕側邊裂開

TAG:雷鋒網 |