1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
|
"""
.. _optuna_callback:

Callback for Study.optimize
===========================

This tutorial showcases how to use and implement an Optuna ``Callback`` for :func:`~optuna.study.Study.optimize`.

A ``Callback`` is called after every evaluation of ``objective``; it takes a :class:`~optuna.study.Study`
and a :class:`~optuna.trial.FrozenTrial` as arguments and can do arbitrary work with them
(for example, logging results or stopping the study).

`MLflowCallback <https://optuna-integration.readthedocs.io/en/stable/reference/generated/optuna_integration.MLflowCallback.html>`__ is a great example.
"""
###################################################################################################
# Stop optimization after some trials are pruned in a row
# -------------------------------------------------------
#
# This example implements a stateful callback that stops the optimization
# once a certain number of trials have been pruned in a row.
# That number of consecutive pruned trials is specified by ``threshold``.
import optuna
class StopWhenTrialKeepBeingPrunedCallback:
    """Stop a study once ``threshold`` trials in a row have been pruned.

    The callback is stateful: it counts how many of the most recently finished
    trials were pruned consecutively, and resets that count whenever a trial
    finishes in any other state.

    Args:
        threshold: Number of consecutive pruned trials after which
            :meth:`~optuna.study.Study.stop` is called. Note that a value of
            ``0`` or less stops the study after the very first trial, because
            the count is never negative.
    """

    def __init__(self, threshold: int):
        self.threshold = threshold
        # How many trials have been pruned in a row so far.
        self._consecutive_pruned_count = 0

    def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
        """Record ``trial``'s outcome and stop ``study`` if the threshold is reached.

        Args:
            study: The study being optimized.
            trial: The trial that has just been evaluated.
        """
        if trial.state == optuna.trial.TrialState.PRUNED:
            self._consecutive_pruned_count += 1
        else:
            # Any non-pruned trial breaks the streak.
            self._consecutive_pruned_count = 0

        if self._consecutive_pruned_count >= self.threshold:
            study.stop()
###################################################################################################
# This objective prunes every trial except the first 5 (``trial.number`` starts at 0).
def objective(trial):
    """Suggest ``x`` for trials 0-4 and prune every later trial.

    Trial numbers start at 0, so the first five trials complete normally and
    any trial with ``trial.number`` of 5 or more raises
    :class:`optuna.TrialPruned`.
    """
    if trial.number <= 4:
        return trial.suggest_float("x", 0, 1)
    raise optuna.TrialPruned()
###################################################################################################
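# Here, we set the threshold to ``2``: optimization finishes once two trials are pruned in a row.
# Trials 0-4 complete, trials 5 and 6 are pruned, and the callback then stops the study,
# so we expect it to stop after 7 trials even though ``n_trials=10``.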
import logging
import sys
# Also send Optuna's log messages to stdout so that they are shown.
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(stream=sys.stdout))

# The handler is attached first so the study-creation message is captured too.
study = optuna.create_study()
study_stop_cb = StopWhenTrialKeepBeingPrunedCallback(threshold=2)
study.optimize(objective, n_trials=10, callbacks=[study_stop_cb])
###################################################################################################
# As you can see in the log above, the study stopped after 7 trials as expected.
|