ジニ不純度について

ここ数週間にわたってエントロピー関連の話が続きましたので、ついでみたいになるのですが決定木の評価関数として同じようによく使われるジニ不純度についても紹介しておきます。
ジニ係数(はじパタなど)や、ジニ指数(カステラ本など)といった呼び名もあるのですが、経済学で所得格差の指標に用いられるジニ係数と紛らわしくなるので、この記事ではジニ不純度と呼びます。

ジニ不純度の定義

ジニ不純度は、決定木やランダムフォレストなどの機械学習のアルゴリズムにおいて、ノードの純度を計測するための指標の一つとして用いられます。要するに、そのノードに分類される観測値のクラスの内訳が全部同じだとなると不純度は小さく、同じノードに複数のクラスの観測値が混在してたら不純度は高くなります。エントロピーと似た挙動ですね。

数式としては次のように定義されます。$n$クラスに分類する問題の場合、$i$番目のクラスのサンプル比率を$p_i$とすると、ジニ不純度は次のように定義されます。

$$I = 1 – \sum_{i=1}^n p_i^2.$$

例えば、3クラスの分類問題を考えると、あるノードの観測値が全部同じクラスだったとします。すると、$p_1 = 1, p_2=0, p_3 = 0$みたいになるのですが、この時$I=0$と、不純度が0になります。また、3クラスが均等に混ざっていた場合、$p_1 = p_2 = p_3 = 2/3$となり、時に不純度は$I = 2/3$となります。$n$クラスの分類問題の場合、この$1-1/n$が最大値です。

もう7年近く前になりますが、初めてこの定義式を見た時は、1から確率の平方和を引くことに対してちょっと違和感を感じました。ただ、これは計算を効率的にやるためにこの形にしているだけで、実際は次の形で定義される値と考えておくと納得しやすいです。

$$I = \sum_{i=1}^n p_i (1-p_i).$$

これ展開すると、$\sum_{i=1}^n p_i – p_i^2$となり、$\sum_{i=1}^n p_i =1$ですから、先に出した定義式と一致することがわかります。

ジニ不純度の解釈

それでは、$p_i(1-p_i)$とは何か、という話なのですが解釈が二つあります。

まず一つ目。ある観測値がクラスiであれば1、クラスi出なければ0というベルヌーイ試行を考えるとその分散がこの$p_i(1-p_i)$になります。これを各クラス分足し合わせたのがジニ不純度ですね。

クラスiである確率が0の場合と1の2種類の場合がまじりっけ無しで不純度0ということでとてもわかりやすいです。

もう一つは、そのノードにおける誤分類率です。あるノードにおける観測値を、それぞれ確率$p_i$で、クラス$i$である、とジャッジすると、それが本当はクラス$i$ではない確率が$(1-p_i)$ありますから、誤分類率は$\sum_{i=1}^n p_i (1-p_i)$となります。

あるノードの観測値のクラスが全部一緒なら誤分類は発生しませんし、$n$クラスが均等に含まれていたら誤分類率は最大になりますね。不純度の定義としてわかりやすいです。

エントロピーと比較した場合のメリット

決定木の分割の評価として使う場合、エントロピーとジニ不純度は少しだけ異なる振る舞いをし、一長一短あるのですが、ジニ不純度には計算効率の面で優位性があるとも聞きます。というのも、$p\log{p}$より$p^2$のほうが計算が簡単なんですよね。scikit-learnがginiの方をデフォルトにしているのはこれが理由じゃないかなと思ってます。(個人的な予想ですが。)

ただ、実際に実験してみると計算速度はそこまで大差はないです(np.log2の実装は優秀)。なんなら自分で書いて試したらエントロピーの方が速かったりもします。とはいえ、$p\log{p}$の計算はちゃんとやるなら、$p=0$ではないことをチェックして場合分けを入れたりしないといけないのでそこまできちんと作ったらジニ不純度の方が速くなりそうです。

Pythonでの実装

数式面の説明が続きましたが、最後に実装の話です。実は主要ライブラリではジニ不純度を計算する便利な関数等は提供されていません。まぁ、簡単なので自分で書けば良いのですが。

scikit-learnのソースコードを見ると、ここで実装されて内部で使われていますね。

確率$p$のリストが得られたら、全部二乗して足して1から引けば良いです。
$1/6+1/3+1/2=1$なので、確率のリストは$[1/6, 1/3, 1/2]$としてやってみます。

import numpy as np


p_list = np.array([1/6, 1/3, 1/2])
print(1-(p_list**2).sum())
# 0.6111111111111112

まぁ、確かにこれならわざわざライブラリ用意してもらわなくてもって感じですね。

割合ではなく要素数でデータがある場合は全体を合計で割って割合にすればよいだけですし。

以上で、ジニ不純度の説明を終えます。エントロピーと比べて一長一短あるのでお好みの方を使ってみてください。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です