plot_roc_curveを試す

今回もscikit-learn 0.22.0 の新機能を試します。
今回はROC曲線を書いてくれる、plot_roc_curveです。
ドキュメント: sklearn.metrics.plot_roc_curve

元々、 roc_curve という機能はあったのですが、これを可視化するのは少しだけ面倒だったので結構期待していました。

ただ、試してみたところ、plot_roc_curveは2値分類のモデルにしか対応していないようです。
roc_curveは引数のpos_labelに対象のラベルを指定してあげれば多値分類にも対応しているので、今後に期待します。

ということで、2値分類のダミーデータを作成して、試してみます。
綺麗に分類できる問題だとROC曲線の価値がよくわからないので線形分離不可能な問題にしています。
最後にグラフを二つ出力していますが、左のが、生成したテストデータとモデルの決定境界、右側が、ROC曲線です。


from sklearn.metrics import plot_roc_curve
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt

# 綺麗に線形分類することのできないダミーデータの生成。
X0 = np.clip(2 * np.random.randn(500, 2) + np.array([1, 1]), -7, 7)
X1 = np.clip(2 * np.random.randn(500, 2) + np.array([-1, -1]), -7, 7)
X = np.concatenate([X0, X1])
y = np.array([0]*500 + [1]*500)

# 学習データとテストデータに分ける
X_train, X_test, y_train, y_test = train_test_split(
                                    X,
                                    y,
                                    test_size=0.2,
                                    stratify=y,
                                )
# モデルの作成と学習
clf = LogisticRegression()
clf.fit(X_train, y_train)

# ここから可視化
fig = plt.figure(figsize=(12, 6), facecolor="w")
ax = fig.add_subplot(1, 2, 1, aspect='equal', xlim=(-7, 7), ylim=(-7, 7))
ax.set_title("テストデータと決定境界")

# データを散布図で表示する
ax.scatter(X_test[y_test == 0, 0], X_test[y_test == 0, 1], alpha=0.7)
ax.scatter(X_test[y_test == 1, 0], X_test[y_test == 1, 1], alpha=0.7)

# 決定境界を可視化する
X_mesh, Y_mesh = np.meshgrid(np.linspace(-7, 7, 401), np.linspace(-7, 7, 401))
Z_mesh = clf.predict(np.array([X_mesh.ravel(), Y_mesh.ravel()]).T)
Z_mesh = Z_mesh.reshape(X_mesh.shape)
ax.contourf(X_mesh, Y_mesh, Z_mesh, alpha=0.1)

ax = fig.add_subplot(1, 2, 2, aspect='equal', xlim=(0, 1), ylim=(0, 1))
ax.set_title("ROC曲線",)
plot_roc_curve(clf, X_test, y_test, ax=ax)

plt.show()

出力された図がこちら。

コード全体が長いのでわかりにくくて恐縮ですが、 ROC曲線自体は1行で出力できており、非常に手軽です。
また、AUCも同時に出力してくれています。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です