早速、前回の記事でインストールした dtreeviz を使ってみます。
※この記事では dtreevizの version 0.8.2 を使っています。
1.0.0 では一部引数の名前などが違う様です。(X_train が x_dataになるなど。)
とりあえず、データと可視化する木がないと話にならないので、いつものirisで作っておきます。
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
iris = load_iris()
clf = DecisionTreeClassifier(min_samples_split=5)
clf.fit(
iris.data,
iris.target
)
さて、これで学習したモデル(コード中のclf
)を可視化します。
リポジトリのコードを見ながらやってみます。
まず、一番シンプルな可視化は、 dtreeviz.trees.dtreeviz
にモデルと必要なデータを全部渡すものの様です。
(省略不可能な引数だけ設定して実行しましたが、結構多いですね。)
from dtreeviz.trees import dtreeviz
tree_viz = dtreeviz(
tree_model=clf,
X_train=iris.data,
y_train=iris.target,
feature_names=iris.feature_names,
target_name="types",
class_names=iris. target_names.tolist(),
)
tree_viz
出力がこちら。
graphvizで決定木を可視化 でやったのと比べて、とてもスタイリッシュで解釈しやすいですね。
orientation(デフォルトは’TD’)に’LR’を指定すると、向きを縦から横に変更できます。
tree_viz = dtreeviz(
tree_model=clf,
X_train=iris.data,
y_train=iris.target,
feature_names=iris.feature_names,
target_name="types",
class_names=iris. target_names.tolist(),
orientation='LR',
)
tree_viz
出力がこちら。
木のサイズによってはこれも選択肢に入りそうですね。