逼真,特別逼真的決策樹可視化

  • 2022 年 1 月 15 日
  • 筆記

同學們好,決策樹的可視化,我以為之前介紹的方法已經夠驚艷了(決策樹可視化,被驚艷到了!),沒想到最近又發現了一個更驚艷的,而且更逼真,話不多說,先看效果圖↓

直接繪製隨機森林也不在話下

下面就向大家介紹一下這個神器 —— pybaobabdt的極簡入門用法

安裝GraphViz

pybaobabdt依賴GraphViz,首先下載安裝包

//www.graphviz.org/download/

//www.graphviz.org/download/

2、雙擊msi文件,然後一直選擇next(默認安裝路徑為C:\Program Files (x86)\Graphviz2.38\),安裝完成之後,會在windows開始菜單創建快捷資訊。

3、配置環境變數:電腦→屬性→高級系統設置→高級→環境變數→系統變數→path,在path中加入路徑:

4、驗證:在windows命令行介面,輸入dot -version,然後按回車,如果顯示如下圖所示的graphviz相關版本資訊,則安裝配置成功。

安裝pygraphviz和pybaobabdt

pip直接安裝pygraphviz的話,大概率會報錯,建議下載whl文件本地安裝。

//www.lfd.uci.edu/~gohlke/pythonlibs/#pygraphviz

pybaobabdt就簡單了,直接pip install pybaobabdt 即可

pybaobabdt用法

pybaobabdt 用起來也簡單到離譜,核心命令只有一個pybaobabdt.drawTree,下面是官方文檔示例程式碼,建議在jupyter-notebook中運行。

import pybaobabdt
import pandas as pd
from scipy.io import arff
from sklearn.tree import DecisionTreeClassifier
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import ListedColormap
from colour import Color
import matplotlib.pyplot as plt
import numpy as np

data = arff.loadarff('vehicle.arff')
df   = pd.DataFrame(data[0])
y = list(df['class'])
features = list(df.columns)

                   
features.remove('class')
X = df.loc[:, features]

clf = DecisionTreeClassifier().fit(X, y)

ax = pybaobabdt.drawTree(clf, size=10, dpi=72, features=features, colormap='Spectral')

這個圖怎麼看呢?

不同的顏色對應不同的分類(target),每個分叉處都標記了分裂的條件,所以劃分邏輯一目了然。 樹的深度也是工整的體現了出來。

樹枝的直徑也不是擺設,而是代表了樣本的個數(比例),該劃分條件下的樣本越多,樹榦也就越粗。

你是發現最最底層的樹枝太細太脆弱的時候,是不是應該考慮一下過擬合風險,比如需要調整一下最小樣本數?

繪製隨機森林

import pybaobabdt
import pandas as pd
from scipy.io import arff
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
data = arff.loadarff('vehicle.arff')

df = pd.DataFrame(data[0])
y = list(df['class'])
features = list(df.columns)
features.remove('class')
X = df.loc[:, features]

clf = RandomForestClassifier(n_estimators=20, n_jobs=-1, random_state=0)
clf.fit(X, y)
size = (15,15)
plt.rcParams['figure.figsize'] = size
fig = plt.figure(figsize=size, dpi=300)

for idx, tree in enumerate(clf.estimators_):
    ax1 = fig.add_subplot(5, 4, idx+1)
    pybaobabdt.drawTree(tree, model=clf, size=15, dpi=300, features=features, ax=ax1)
    
fig.savefig('random-forest.png', format='png', dpi=300, transparent=True)

怎麼用,是不是很酷,趕緊去試試吧!
如有收穫,可否在看、收藏、轉發一下?感謝~

//mp.weixin.qq.com/s/uIazCL9SjNDguu59up5KjA