當前位置:
首頁 > 知識 > 回歸樹的原理及其 Python 實現

回歸樹的原理及其 Python 實現

(點擊

上方公號

,快速關注我們)




本文來自作者「

李小文

」的投稿


公號平台,首發「Python開發者」




提到回歸樹,相信大家應該都不會覺得陌生(不陌生你點進來幹嘛[捂臉]),大名鼎鼎的 GBDT 演算法就是用回歸樹組合而成的。本文就回歸樹的基本原理進行講解,並手把手、肩並肩地帶您實現這一演算法。

完整實現代碼請參考 github: 

https://github.com/tushushu/Imylu/blob/master/regression_tree.py


1. 原理篇


我們用人話而不是大段的數學公式,來講講回歸樹是怎麼一回事。


1.1 最簡單的模型


如果預測某個連續變數的大小,最簡單的模型之一就是用平均值。比如同事的平均年齡是 28 歲,那麼新來了一批同事,在不知道這些同事的任何信息的情況下,直覺上用平均值 28 來預測是比較準確的,至少比 0 歲或者 100 歲要靠譜一些。我們不妨證明一下我們的直覺:





1.2 加一點難度


仍然是預測同事年齡,這次我們預先知道了同事的職級,假設職級的範圍是整數1-10,如何能讓這個信息幫助我們更加準確的預測年齡呢?


一個思路是根據職級把同事分為兩組,這兩組分別應用我們之前提到的「平均值」模型。比如職級小於 5 的同事分到A組,大於或等於5的分到 B 組,A 組的平均年齡是 25 歲,B 組的平均年齡是 35 歲。如果新來了一個同事,職級是 3,應該被分到 A 組,我們就預測他的年齡是 25 歲。


1.3 最佳分割點


還有一個問題待解決,如何取一個最佳的分割點對不同職級的同事進行分組呢?


我們嘗試所有 m 個可能的分割點 P_i,沿用之前的損失函數,對 A、B 兩組分別計算 Loss 並相加得到 L_i。最小的 L_i 所對應的 P_i 就是我們要找的「最佳分割點」。


1.4 運用多個變數


再複雜一些,如果我們不僅僅知道了同事的職級,還知道了同事的工資(貌似不科學),該如何預測同事的年齡呢?


我們可以分別根據職級、工資計算出職級和工資的最佳分割點P_1, P_2,對應的Loss L_1, L_2。然後比較L_1和L2,取較小者。假設L_1 < L_2,那麼按照P_1把不同職級的同事分為A、B兩組。在A、B組內分別計算工資所對應的分割點,再分為C、D兩組。這樣我們就得到了AC, AD, BC, BD四組同事以及對應的平均年齡用於預測。


1.5 答案揭曉


如何實現這種1 to 2, 2 to 4, 4 to 8的演算法呢?


熟悉數據結構的同學自然會想到二叉樹,這種樹被稱為回歸樹,顧名思義利用樹形結構求解回歸問題。


2. 實現篇


本人用全宇宙最簡單的編程語言——Python實現了回歸樹演算法,沒有依賴任何第三方庫,便於學習和使用。簡單說明一下實現過程,更詳細的注釋請參考本人github上的代碼。


2.1 創建Node類


初始化,存儲預測值、左右結點、特徵和分割點



class

Node

(

object

)

:


    

def __init__

(

self

,

score

=

None

)

:


        

self

.

score

=

score


        

self

.

left

=

None


        

self

.

right

=

None


        

self

.

feature

=

None


        

self

.

split

=

None




2.2 創建回歸樹類


初始化,存儲根節點和樹的高度。



class

RegressionTree

(

object

)

:


    

def __init__

(

self

)

:


        

self

.

root

=

Node

()


        

self

.

height

=

0




2.3 計算分割點、MSE


根據自變數X、因變數y、X元素中被取出的行號idx,列號feature以及分割點split,計算分割後的MSE。注意這裡為了減少計算量,用到了方差公式:





2.4 計算最佳分割點


遍歷特徵某一列的所有的不重複的點,找出MSE最小的點作為最佳分割點。如果特徵中沒有不重複的元素則返回None。



def _choose_split_point

(

self

,

X

,

y

,

idx

,

feature

)

:


    

unique

=

set

([

X

[

i

][

feature

]

for

i

in

idx

])


    

if

len

(

unique

)

==

1

:


        

return

None


 


    

unique

.

remove

(

min

(

unique

))


    

mse

,

split

,

split_avg

=

min

(


        

(

self

.

_get_split_mse

(

X

,

y

,

idx

,

feature

,

split

)


            

for

split

in

unique

),

key

=

lambda

x

:

x

[

0

])


    

return

mse

,

feature

,

split

,

split_avg




2.5 選擇最佳特徵


遍歷所有特徵,計算最佳分割點對應的MSE,找出MSE最小的特徵、對應的分割點,左右子節點對應的均值和行號。如果所有的特徵都沒有不重複元素則返回None



def _choose_feature

(

self

,

X

,

y

,

idx

)

:


    

m

=

len

(

X

[

0

])


    

split_rets

=

[

x

for

x

in

map

(

lambda

x

:

self

.

_choose_split_point

(


        

X

,

y

,

idx

,

x

),

range

(

m

))

if

x

is

not

None

]


 


    

if

split_rets

==

[]

:


        

return

None


    

_

,

feature

,

split

,

split_avg

=

min

(


        

split_rets

,

key

=

lambda

x

:

x

[

0

])


 


    

idx_split

=

[[],

[]]


    

while

idx

:


        

i

=

idx

.

pop

()


        

xi

=

X

[

i

][

feature

]


        

if

xi

<

split

:


            

idx_split

[

0

].

append

(

i

)


        

else

:


            

idx_split

[

1

].

append

(

i

)


    

return

feature

,

split

,

split_avg

,

idx_split




2.6 規則轉文字


將規則用文字表達出來,方便我們查看規則。



def _expr2literal

(

self

,

expr

)

:


    

feature

,

op

,

split

=

expr


    

op

=

">="

if

op

==

1

else

"<"


    

return

"Feature%d %s %.4f"

%

(

feature

,

op

,

split

)




2.7 獲取規則


將回歸樹的所有規則都用文字表達出來,方便我們了解樹的全貌。這裡用到了隊列+廣度優先搜索。有興趣也可以試試遞歸或者深度優先搜索。



def _get_rules

(

self

)

:


    

que

=

[[

self

.

root

,

[]]]


    

self

.

rules

=

[]


 


    

while

que

:


        

nd

,

exprs

=

que

.

pop

(

0

)


        

if

not

(

nd

.

left

or

nd

.

right

)

:


            

literals

=

list

(

map

(

self

.

_expr2literal

,

exprs

))


            

self

.

rules

.

append

([

literals

,

nd

.

score

])


 


        

if

nd

.

left

:


            

rule_left

=

copy

(

exprs

)


            

rule_left

.

append

([

nd

.

feature

,

-

1

,

nd

.

split

])


            

que

.

append

([

nd

.

left

,

rule_left

])


 


        

if

nd

.

right

:


            

rule_right

=

copy

(

exprs

)


            

rule_right

.

append

([

nd

.

feature

,

1

,

nd

.

split

])


            

que

.

append

([

nd

.

right

,

rule_right

])




2.8 訓練模型


仍然使用隊列+廣度優先搜索,訓練模型的過程中需要注意:




  1. 控制樹的最大深度max_depth;



  2. 控制分裂時最少的樣本量min_samples_split;



  3. 葉子結點至少有兩個不重複的y值;



  4. 至少有一個特徵是沒有重複值的。





def fit

(

self

,

X

,

y

,

max_depth

=

5

,

min_samples_split

=

2

)

:


    

self

.

root

=

Node

()


    

que

=

[[

0

,

self

.

root

,

list

(

range

(

len

(

y

)))]]


 


    

while

que

:


        

depth

,

nd

,

idx

=

que

.

pop

(

0

)


 


        

if

depth

==

max_depth

:


            

break


 


        

if

len

(

idx

)

<

min_samples_split

or


                

set

(

map

(

lambda

i

:

y

[

i

],

idx

))

==

1

:


            

continue


 


        

feature_rets

=

self

.

_choose_feature

(

X

,

y

,

idx

)


        

if

feature_rets

is

None

:


            

continue


 


        

nd

.

feature

,

nd

.

split

,

split_avg

,

idx_split

=

feature_rets


        

nd

.

left

=

Node

(

split_avg

[

0

])


        

nd

.

right

=

Node

(

split_avg

[

1

])


        

que

.

append

([

depth

+

1

,

nd

.

left

,

idx_split

[

0

]])


        

que

.

append

([

depth

+

1

,

nd

.

right

,

idx_split

[

1

]])


 


    

self

.

height

=

depth


    

self

.

_get_rules

()




2.9 列印規則


模型訓練完畢,查看一下模型生成的規則



def print_rules

(

self

)

:


    

for

i

,

rule

in

enumerate

(

self

.

rules

)

:


        

literals

,

score

=

rule


        print

(

"Rule %d: "

%

i

,

" | "

.

join

(


            

literals

)

+

" => split_hat %.4f"

%

score

)




2.10 預測一個樣本



def _predict

(

self

,

row

)

:


    

nd

=

self

.

root


    

while

nd

.

left

and

nd

.

right

:


        

if

row

[

nd

.

feature

]

<

nd

.

split

:


            

nd

=

nd

.

left


        

else

:


            

nd

=

nd

.

right


    

return

nd

.

score




2.11 預測多個樣本



def predict

(

self

,

X

)

:


    

return

[

self

.

_predict

(

Xi

)

for

Xi

in

X

]




3 效果評估



3.1 main函數


使用著名的波士頓房價數據集,按照7:3的比例拆分為訓練集和測試集,訓練模型,並統計準確度。



@

run_time


def main

()

:


    

print

(

"Tesing the accuracy of RegressionTree..."

)


    

# Load data


    

X

,

y

=

load_boston_house_prices

()


    

# Split data randomly, train set rate 70%


    

X_train

,

X_test

,

y_train

,

y_test

=

train_test_split

(


        

X

,

y

,

random_state

=

10

)


    

# Train model


    

reg

=

RegressionTree

()


    

reg

.

fit

(

X

=

X_train

,

y

=

y_train

,

max_depth

=

4

)


    

# Show rules


    

reg

.

print_rules

()


    

# Model accuracy


    

get_r2

(

reg

,

X_test

,

y_test

)




3.2 效果展示


最終生成了15條規則,擬合優度0.801,運行時間1.74秒,效果還算不錯~



3.3 工具函數


本人自定義了一些工具函數,可以在github上查看 

https://

github.com/tushushu/Imylu/blob/master/utils.py

 1. run_time – 測試函數運行時間 2. load_boston_house_prices – 載入波士頓房價數據 3. train_test_split – 拆分訓練集、測試機 4. get_r2 – 計算擬合優度


總結


回歸樹的原理:



損失最小化,平均值大法。 最佳行與列,效果頂呱呱。


回歸樹的實現:



一頓操作猛如虎,加減乘除二叉樹。


https://zhuanlan.zhihu.com/p/41688007




【關於作者】


李小文:先後從事過數據分析、數據挖掘工作,主要開發語言是Python,現任一家小型互聯網公司的演算法工程師。Github: 

https://gith

ub.com/tushu

shu




【關於投稿】




如果大家有原創好文投稿,請直接給公號發送留言。




① 留言格式:


【投稿】+《 文章標題》+ 文章鏈接

② 示例:


【投稿】

《不要自稱是程序員,我十多年的 IT 職場總結》:


http://blog.jobbole.com/94148/



③ 最後請附上您的個人簡介哈~




看完本文有收穫?請轉

發分享給更多人


關注「P

ython開發者」,提升Python技能



喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 Python開發者 的精彩文章:

面向對象:我相信你正在與我相遇的路上馬不停蹄

TAG:Python開發者 |