

Pruning Unpromising Trials



from sklearn.datasets import load_iris
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split
import optuna

# Load the iris data set and collect its distinct class labels; the label
# list is handed to partial_fit's `classes` argument inside objective().
iris = load_iris()
classes = [*set(iris.target)]

# Hold out 25% of the samples for evaluation, with a fixed seed so the
# split is the same on every run.
train_x, test_x, train_y, test_y = train_test_split(
    iris.data,
    iris.target,
    test_size=0.25,
    random_state=0,
)

def objective(trial):
    """Train an SGDClassifier on iris and return its test error rate.

    The regularization strength ``alpha`` is sampled log-uniformly from
    [1e-5, 1e-1]. The classifier is trained incrementally with 100 calls to
    ``partial_fit``. After each call the current error rate is reported to
    Optuna, so the study's pruner can stop unpromising trials early.

    Args:
        trial: The Optuna trial used to sample ``alpha`` and to report
            intermediate values.

    Returns:
        The final misclassification rate (1 - accuracy) on the test split.

    Raises:
        optuna.TrialPruned: If the pruner decides to stop this trial.
    """
    # suggest_float(..., log=True) is the replacement for the deprecated
    # suggest_loguniform and samples the same log-uniform distribution.
    alpha = trial.suggest_float('alpha', 1e-5, 1e-1, log=True)
    clf = SGDClassifier(alpha=alpha)

    for step in range(100):
        clf.partial_fit(train_x, train_y, classes=classes)

        # Report intermediate objective value.
        intermediate_value = 1.0 - clf.score(test_x, test_y)
        trial.report(intermediate_value, step)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            # Added to show at which epoch the trial was pruned.
            print("step", step, "で打ち切り")
            # optuna.TrialPruned replaces the removed optuna.structs alias.
            raise optuna.TrialPruned()
    return 1.0 - clf.score(test_x, test_y)

# Set up the median stopping rule as the pruning condition: MedianPruner
# stops a trial whose intermediate value is worse than the median of the
# earlier trials' values at the same step.
study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)

# Show the best hyperparameters found (the dict at the end of the log).
print(study.best_params)


[I 2019-10-09 00:45:38,562] Finished trial#0 resulted in value: 0.368421052631579. Current best value is 0.368421052631579 with parameters: {'alpha': 0.0002196017670543267}.
[I 2019-10-09 00:45:38,757] Finished trial#1 resulted in value: 0.10526315789473684. Current best value is 0.10526315789473684 with parameters: {'alpha': 0.0006773222557376204}.
[I 2019-10-09 00:45:38,967] Finished trial#2 resulted in value: 0.39473684210526316. Current best value is 0.10526315789473684 with parameters: {'alpha': 0.0006773222557376204}.
[I 2019-10-09 00:45:39,201] Finished trial#3 resulted in value: 0.02631578947368418. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
[I 2019-10-09 00:45:39,462] Finished trial#4 resulted in value: 0.3421052631578947. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
[I 2019-10-09 00:45:39,758] Finished trial#5 resulted in value: 0.3157894736842105. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
[I 2019-10-09 00:45:40,094] Finished trial#6 resulted in value: 0.052631578947368474. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
step 4 で打ち切り
[I 2019-10-09 00:45:40,126] Setting status of trial#7 as TrialState.PRUNED. 
step 1 で打ち切り
[I 2019-10-09 00:45:40,211] Setting status of trial#8 as TrialState.PRUNED. 
[I 2019-10-09 00:45:40,625] Finished trial#9 resulted in value: 0.39473684210526316. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
[I 2019-10-09 00:45:41,195] Finished trial#10 resulted in value: 0.07894736842105265. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
[I 2019-10-09 00:45:41,675] Finished trial#11 resulted in value: 0.02631578947368418. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
[I 2019-10-09 00:45:42,132] Finished trial#12 resulted in value: 0.23684210526315785. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
[I 2019-10-09 00:45:42,605] Finished trial#13 resulted in value: 0.3157894736842105. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
step 11 で打ち切り
[I 2019-10-09 00:45:42,691] Setting status of trial#14 as TrialState.PRUNED. 
[I 2019-10-09 00:45:43,242] Finished trial#15 resulted in value: 0.07894736842105265. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
[I 2019-10-09 00:45:43,894] Finished trial#16 resulted in value: 0.02631578947368418. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
step 1 で打ち切り
[I 2019-10-09 00:45:43,929] Setting status of trial#17 as TrialState.PRUNED. 
step 1 で打ち切り
[I 2019-10-09 00:45:44,067] Setting status of trial#18 as TrialState.PRUNED. 
[I 2019-10-09 00:45:44,756] Finished trial#19 resulted in value: 0.1842105263157895. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.0037602050725428606}.
{'alpha': 0.0037602050725428606}

〜 as TrialState.PRUNED. とメッセージが出てるのを見るとわかる通り、結構な頻度で早い段階で打ち切られています。
alpha や intermediate_value の値も随時 print 出力すると、挙動の理解が深まるのでおすすめです。


        # Report intermediate objective value.
        intermediate_value = 1.0 - clf.score(test_x, test_y)
        trial.report(intermediate_value, step)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            print("step", step, "で打ち切り")  # 何回めのエポックで打ち切ったか見るために追加
            raise optuna.structs.TrialPruned()

intermediate_value = 1.0 - clf.score(test_x, test_y)

partial_fit にもっと長時間かかるサンプルであれば、きっと時短効果が得られると思います。


メールアドレスが公開されることはありません。* が付いている欄は必須項目です