scikit-learn でグリッドサーチ

機械学習のハイパーパラメーターを決定するとき、グリッドサーチという手法を使うことがあります。
よほど学習時にかかるケース以外では、ほぼ確実に行なっています。

そのとき、scikit-learn の GridSearchCV というクラスを使うことが多いのでその使い方をメモしておきます。
今回は題材として、 digits という手書き数字のデータセットを利用します。


from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

最初にデータを準備します。
# データの読み込み
digits = load_digits()
X = digits.data
y = digits.target
# 訓練データとテストデータに分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

次にサーチするパラメーターを指定します。


# グリッドサーチするパラメーターを指定。変数名と値のリストの辞書 、それが複数ある場合はそれらの配列。
param_grid = [
    {
        'C': [1, 10, 100, 1000],
        'kernel': ['linear']
    },
    {
        'C': [0.1, 1, 10, 100, 1000],
        'kernel': ['rbf'],
        'gamma': [0.001, 0.0001, 'auto']
    },
    {
        'C': [0.1, 1, 10, 100, 1000],
        'kernel': ['poly'], 'degree': [2, 3, 4],
        'gamma': [0.001, 0.0001, 'auto']
    },
    {
        'C': [0.1, 1, 10, 100, 1000],
        'kernel':['sigmoid'],
        'gamma': [0.001, 0.0001, 'auto']
    }
]

モデルを作って、グリッドサーチの実行


from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# モデルの準備
model = GridSearchCV(
    SVC(),  # 予測機
    param_grid,  # サーチ対象のパラメーター
    cv=5,  # 交差検証の数
    # このほか、評価指標(scoring) や、パラレル実行するJob数なども指定可能(n_jobs)
)
# グリッドサーチの実行
model.fit(X_train, y_train)

最良のパラメーターを確認する


print(model.best_params_)

# 出力
{'C': 1, 'gamma': 0.001, 'kernel': 'rbf'}

最後に、テスト用に取っておいたデータで、出来上がったモデルを評価します。


from sklearn.metrics import classification_report

# 学習したモデルで予測
y_predict = model.predict(X_test)
# 作成したモデルの評価
print(classification_report(y_test, y_predict))

# 出力
             precision    recall  f1-score   support

          0       1.00      1.00      1.00        38
          1       0.97      1.00      0.99        34
          2       1.00      1.00      1.00        38
          3       0.97      1.00      0.99        34
          4       1.00      1.00      1.00        36
          5       1.00      0.97      0.99        35
          6       0.97      0.97      0.97        39
          7       0.97      1.00      0.98        31
          8       0.97      0.97      0.97        38
          9       1.00      0.95      0.97        37

avg / total       0.99      0.99      0.99       360

なかなか良い正解率ですね。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です