Optunaで枝刈りをやってみる

自分がいつも使っているhpyeroptではなく、新たにOptunaを覚えて使ってみようと思ったメインの目的が、枝刈りの機能です。
これは要するに、学習の途中で見込みのないハイパーパラメーターは処理を打ち切ってしまって、
短時間で効率的に探索を進めようという機能です。

チュートリアルとしてはこちらにコードのサンプルがあります。
Pruning Unpromising Trials

また、打ち切りに使う基準には複数のアルゴリズムが使え、こちらのページにリファレンスがあります。
Pruners

では早速やってみましょう。コードはチュートリアルの例をベースに少し書き直しました。


from sklearn.datasets import load_iris
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split
import optuna

iris = load_iris()
classes = list(set(iris.target))

train_x, test_x, train_y, test_y = \
    train_test_split(iris.data, iris.target, test_size=0.25, random_state=0)


def objective(trial):

    alpha = trial.suggest_loguniform('alpha', 1e-5, 1e-1)
    clf = SGDClassifier(alpha=alpha)

    for step in range(100):
        clf.partial_fit(train_x, train_y, classes=classes)

        # Report intermediate objective value.
        intermediate_value = 1.0 - clf.score(test_x, test_y)
        trial.report(intermediate_value, step)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            print("step", step, "で打ち切り")  # 何回めのエポックで打ち切ったか見るために追加
            raise optuna.structs.TrialPruned()
    return 1.0 - clf.score(test_x, test_y)


# Set up the median stopping rule as the pruning condition.
study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
print(study.best_params)
print(study.best_value)

出力は次のようになります。


[I 2019-10-09 00:45:38,562] Finished trial#0 resulted in value: 0.368421052631579. Current best value is 0.368421052631579 with parameters: {'alpha': 0.0002196017670543267}.
[I 2019-10-09 00:45:38,757] Finished trial#1 resulted in value: 0.10526315789473684. Current best value is 0.10526315789473684 with parameters: {'alpha': 0.0006773222557376204}.
[I 2019-10-09 00:45:38,967] Finished trial#2 resulted in value: 0.39473684210526316. Current best value is 0.10526315789473684 with parameters: {'alpha': 0.0006773222557376204}.
[I 2019-10-09 00:45:39,201] Finished trial#3 resulted in value: 0.02631578947368418. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
[I 2019-10-09 00:45:39,462] Finished trial#4 resulted in value: 0.3421052631578947. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
[I 2019-10-09 00:45:39,758] Finished trial#5 resulted in value: 0.3157894736842105. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
[I 2019-10-09 00:45:40,094] Finished trial#6 resulted in value: 0.052631578947368474. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
step 4 で打ち切り
[I 2019-10-09 00:45:40,126] Setting status of trial#7 as TrialState.PRUNED. 
step 1 で打ち切り
[I 2019-10-09 00:45:40,211] Setting status of trial#8 as TrialState.PRUNED. 
[I 2019-10-09 00:45:40,625] Finished trial#9 resulted in value: 0.39473684210526316. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
[I 2019-10-09 00:45:41,195] Finished trial#10 resulted in value: 0.07894736842105265. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
[I 2019-10-09 00:45:41,675] Finished trial#11 resulted in value: 0.02631578947368418. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
[I 2019-10-09 00:45:42,132] Finished trial#12 resulted in value: 0.23684210526315785. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
[I 2019-10-09 00:45:42,605] Finished trial#13 resulted in value: 0.3157894736842105. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
step 11 で打ち切り
[I 2019-10-09 00:45:42,691] Setting status of trial#14 as TrialState.PRUNED. 
[I 2019-10-09 00:45:43,242] Finished trial#15 resulted in value: 0.07894736842105265. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
[I 2019-10-09 00:45:43,894] Finished trial#16 resulted in value: 0.02631578947368418. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
step 1 で打ち切り
[I 2019-10-09 00:45:43,929] Setting status of trial#17 as TrialState.PRUNED. 
step 1 で打ち切り
[I 2019-10-09 00:45:44,067] Setting status of trial#18 as TrialState.PRUNED. 
[I 2019-10-09 00:45:44,756] Finished trial#19 resulted in value: 0.1842105263157895. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
{'alpha': 0.0037602050725428606}
0.02631578947368418

〜 as TrialState.PRUNED. とメッセージが出てるのを見るとわかる通り、結構な頻度で早い段階で打ち切られています。
alpha や intermediate_value の値も随時print出力すると、挙動の理解が深まるのでおおすすめです。

さて、期待の時間短縮効果ですが、試してみたところこの例ではほとんどないどころか、余計に時間がかかるようでした。
コード中の以下の部分をコメントアウトして実行すると、もっと早く探索が終わります。


        # Report intermediate objective value.
        intermediate_value = 1.0 - clf.score(test_x, test_y)
        trial.report(intermediate_value, step)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            print("step", step, "で打ち切り")  # 何回めのエポックで打ち切ったか見るために追加
            raise optuna.structs.TrialPruned()

時間が余計にかかるようになってしまった原因は、
intermediate_value = 1.0 – clf.score(test_x, test_y)
の部分で、評価を頻繁に行っているせいだと思われます。

partial_fit にもっと長時間かかるサンプルであればきっと時短効果が得られると思うので、
次はそういう例で試そうと思います。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です