論文精讀:Deep Neural Decision Trees

Deep Neural Decision Trees

Soft binning function

Soft binning function 這個函數的功能爲:輸入一個標量 x x x ,生成標量 x x x 屬於的區間的索引。具體如何實現的呢?往下看:

假設我們有一個連續的變量 x x x,我們想把它分隔成 n + 1 n+1 n+1 個間隔。這樣就需要 n n n 個切割點(cut points),這 n n n 個切割點是可以訓練的變量。將 n n n 個切割點記做 [ β 1 , β 2 , . . . , β n ] [β_1, β_2, . . . , β_n] [β1,β2,...,βn],並且 β 1 < β 2 < ⋅ ⋅ ⋅ < β n . β_1 < β_2 < ··· < β_n. β1<β2<<βn.

我們用 Softmax 作爲**函數構造一個單層神經網絡:
π = f w , b , τ ( x ) = s o f t m a x ( ( w x + b ) / τ ) π = f_{w,b,τ}(x) = softmax((wx + b)/τ ) π=fw,b,τ(x)=softmax((wx+b)/τ)
這裏的 w w w 是常量而不是可以訓練的變量。將 w w w 的值記爲: $w = [1, 2, . . . , n + 1]. $ b b b 記作:
b = [ 0 , − β 1 , − β 1 − β 2 , . . . , − β 1 − β 2 − ⋅ ⋅ ⋅ − β n ] . b=[0,−β_1,−β_1 −β_2,...,−β_1 −β_2 −···−β_n]. b=[0,β1,β1β2,...,β1β2βn].
並且 $ τ > 0$ 是一個係數. 當 τ → 0 τ → 0 τ0 時輸出趨向於一個 one-hot 向量。

舉個栗子:假設有三個連續的 logits : o i − 1 , o i , o i + 1 o_{i−1}, o_{i}, o_{i+1} oi1,oi,oi+1 , 當同時滿足 o i > o i − 1 o_{i} > o_{i−1} oi>oi1 (即 x > β i x > β_i x>βi) 和 $ o_i > o_{i+1}$ (即 $ x < β_{i+1}$), x x x 就一定落在 ( β i , β i + 1 ) (β_i , β_{i+1} ) (βi,βi+1) 範圍內。

比如我們有一個範圍爲 [ 0 , 1 ] [0,1] [01] 的標量 x x x,兩個切割爲在 0.33 0.33 0.33 0.66 0.66 0.66,即 β 1 = 0.33 , β 2 = 0.66 β_1=0.33, β_2=0.66 β1=0.33,β2=0.66。那麼根據上面兩個公式可得到三個 logits: o 1 = x , o 2 = 2 x − 0.33 , o 3 = 3 x − 0.99 o_{1}=x, o_{2}=2x-0.33, o_{3}=3x-0.99 o1=x,o2=2x0.33,o3=3x0.99 。如果 o 2 > o 1 o_{2} > o_{1} o2>o1 那麼 2 x − 0.33 > x 2x-0.33 > x 2x0.33>x x > β 1 = 0.33 x > β_1 = 0.33 x>β1=0.33, 如果 o 2 > o 3 o_{2} > o_{3} o2>o3 那麼 2 x − 0.33 > 3 x − 0.99 2x-0.33 > 3x - 0.99 2x0.33>3x0.99 x < ( 0.99 − 0.33 ) = ( β 2 − 0.33 ) x < (0.99 - 0.33) = (β_{2} - 0.33) x<(0.990.33)=(β20.33)。這樣的話,當滿足 o 2 > o 1 o_{2} > o_{1} o2>o1 o 2 > o 3 o_{2} > o_{3} o2>o3 時, x x x 落在區間 ( β 1 , β 2 ) (β_1 , β_{2} ) (β1,β2) 內。

下圖可以看到 Soft binning function 的函數曲線:

QQ20201103-211105@2x

x x x 軸是連續輸入變量 x ∈ [ 0 , 1 ] x∈[0,1] x[01] 的值。左上:logits 的原始值;右上:應用 τ = 1 τ= 1 τ=1 的 Softmax 函數後的值;左下: τ = 0.1 τ= 0.1 τ=0.1 ;右下: τ = 0.01 τ= 0.01 τ=0.01

通過上圖中的左下可以得知,如果 x = 0.15 x = 0.15 x=0.15 ,此時 o 1 > o 2 > o 3 o_1 > o_2 > o_3 o1>o2>o3 ,那麼 x > 2 x − 0.33 x > 2x - 0.33 x>2x0.33 0.33 > x 0.33 > x 0.33>x,那麼落在了第一個切割點 β 1 β_1 β1 的左面,同理有了這三個曲線,我們就能比較它們在 x x x 取不同值的時候的大小,這樣就能確定它們位於哪些分隔點之間。

所以使用這個函數就能根據輸入的 x x x 生成近似於 one-hot 的向量,尤其是在小 τ τ τ 時。看上圖的右下角, τ τ τ 越來越小的時候函數變得非常置信。當 x x x 在區間 0.4 − 0.6 0.4 - 0.6 0.40.6 區間時, [ o 1 , o 2 , o 3 ] [o_1, o_2, o_3] [o1,o2,o3] 近似等於 [ 0 , 1 , 0 ] [0, 1, 0] [0,1,0], 同理在區間 0.0 − 0.4 0.0 - 0.4 0.00.4 時爲 [ 1 , 0 , 0 ] [1, 0, 0] [1,0,0], 在區間 0.6 − 1.0 0.6 - 1.0 0.61.0 區間時爲 [ 0 , 0 , 1 ] [0, 0, 1] [0,0,1]。這樣的話就能和 Soft binning function 這個函數的功能對應上了:輸入一個標量 x x x ,生成標量 x x x 屬於的區間的索引。

Construct decision tree

有了 binning function,那麼還需要用到 Kronecker product ⊗ ⊗ 操作。下圖是一個 Kronecker product 的例子:

QQ20201103-231459@2x

假設我們有一個實例輸入 x ∈ R D x ∈ R^D xRD D D D 個特徵。對 D D D 個特徵中每一個特徵 x d x_d xd 都進行 binning function 操作通過自己的 neural network f d ( x d ) f_d(x_d) fd(xd),這樣我們就能查找最終的節點通過 Kronecker product 操作:
z = f 1 ( x 1 ) ⊗ f 2 ( x 2 ) ⊗ ⋅ ⋅ ⋅ ⊗ f D ( x D ) . z = f_1(x_1) ⊗ f_2(x_2) ⊗ · · · ⊗ f_D(x_D). z=f1(x1)f2(x2)fD(xD).
這裏的 z z z 也近似是一個 one-hot 向量來指代 x x x 到達的葉子節點的索引。最後假設每個葉子 z z z 處都有一個線性分類器用來分類到達這裏的實例。

下圖是在 Iris 數據集上學習到的 DNDT(只用了兩個特徵:Petal Length 和 Petal Width),其中紅色的字體指代的是可以訓練的參數,而黑色的字體是常量。下面是訓練後的結果,我門使用這顆樹進行預測。

2

假設我們有一個新的數據 P e t a l   L e n g t h = 3 , P e t a l   W i d t h = 2 Petal \ Length = 3, Petal \ Width = 2 Petal Length=3,Petal Width=2 輸入到下面這顆學習好的神經網絡決策樹中。計算的流程如下:
f 1 ( 3 ) = s o f t m a x ( ( [ 1 , 2 ] ⋅ 3 + [ 0 , − 2.58 ] ) / τ ) = s o f t m a x ( ( [ 3 , 3.42 ] ) / τ ) f_{1}(3) = softmax(([1, 2]\cdot3 + [0, -2.58])/τ) \\=softmax(([3, 3.42])/τ) f1(3)=softmax(([1,2]3+[0,2.58])/τ)=softmax(([3,3.42])/τ)
τ τ τ 很小的時候, f 1 ( 3 ) f_{1}(3) f1(3) 近似於一個 one-hot 向量 [ 0 , 1 ] [0, 1] [0,1]。同理可得到 f 2 ( 2 ) ≈ [ 0 , 1 ] f_{2}(2) \approx [0, 1] f2(2)[0,1]。使用公式 z = f 1 ( 3 ) ⊗ f 2 ( 2 ) = [ 0 , 1 ] ⊗ [ 0 , 1 ] = [ 0 , 0 , 0 , 1 ] z = f_1(3) ⊗ f_2(2) = [0,1] ⊗ [0,1]=[0,0,0,1] z=f1(3)f2(2)=[0,1][0,1]=[0,0,0,1]

得到的 Kron Product 結果放入一個分類期,得到分類結果:
z ⋅ W = [ 0 , 0 , 0 , 1 ] ⋅ [ [ . . . ] , [ . . . ] , [ . . . ] , [ − 3.24 , − 2.51 , 6.56 ] ] = [ − 3.24 , − 2.51 , 6.56 ] z \cdot W=[0,0,0,1] \cdot [[...],[...],[...],[-3.24,-2.51,6.56]] \\ = [-3.24, -2.51, 6.56] zW=[0,0,0,1][[...],[...],[...],[3.24,2.51,6.56]]=[3.24,2.51,6.56]
對於向量 [ − 3.24 , − 2.51 , 6.56 ] [-3.24, -2.51, 6.56] [3.24,2.51,6.56] 索引位置 3 上的值是最大的,此時可以判斷這個新數據是第三分類,即 Virginica。

下圖是普通決策樹構建的過程。

3

Learning the Tree

現在我們知道了如何找到輸入實例的路徑,並且分類它。那麼訓練的時候就需要訓練 cut points 和 leaf classifiers。但是由於神經網絡 mini-batch 風格的訓練,DNDT 可以很好地擴展實例的數量。但是,到目前爲止,該設計的一個關鍵缺點是:由於使用了Kronecker Product,因此就 feature 數量而言無法擴展。在我們目前的實現中,我們通過訓練具有隨機子空間的森林來避免「寬」數據集的問題 - 但這會以可解釋性爲代價。

也就是說訓練多棵樹,每棵樹的訓練基於所有特徵的子集合。子集合的選取是隨機的,這樣就能通過多棵樹把所有特徵都考慮了,這樣就能變向的解決 「寬」 數據集的問題。

更好的解決方案(可以不借助不可解釋的森林的方案)是在訓練過程中探索最後 binning function 結果的的稀疏性:非空葉的數量增長比葉總數慢得多。

Experiments

代碼:DNDT

下面是實驗基於的數據集:

1

對於 BaseLine 模型決策樹(DT),我們將兩個關鍵超參數設置爲「 gini」,將分割器設置爲「 best」。 對於神經網絡(NN),我們對所有數據集使用兩個包含 50 個神經元的隱藏層的體系結構。 DNDT 還具有一個超參數,即每個要素的切點數量(分支因子),對於數據集,我們將其設置爲1。

對於具有 12 個以上特徵的數據集,我們使用 DNDT 的 ensemble 版本,其中每棵樹隨機選擇 10 個特徵,總共有 10 棵樹。 最終的預測是由多數投票給出的。

Results

下面是在這些數據集上三種模型的表現。總體而言,性能最好的模型是 DT。 DT 的良好性能不足爲奇,因爲這些數據集主要是表格形式的,並且特徵維相對較低

傳統上,神經網絡在此類數據上沒有明顯的優勢。 但是,DNDT 略優於普通神經網絡,因爲它在設計上更接近決策樹。 當然,這只是一個指示性的結果,因爲所有這些模型都具有可調整的超參數。 然而,有趣的是,沒有任何一種模型具有主導優勢。 這讓人想起沒有免費的午餐定理。

啥是沒有免費的午餐定理咱也不知道,可能大佬寫文章就喜歡整這些文縐縐的東西吧。

2