File: test_callback.py

package info (click to toggle)
optuna 4.1.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 4,784 kB
  • sloc: python: 40,634; sh: 97; makefile: 30
file content (34 lines) | stat: -rw-r--r-- 1,188 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from optuna.study import create_study
from optuna.study import Study
from optuna.terminator import BaseTerminator
from optuna.terminator import TerminatorCallback
from optuna.trial import TrialState


class _DeterministicTerminator(BaseTerminator):
    """Terminator whose decision depends only on completed trial numbers.

    Used to make ``TerminatorCallback`` behaviour deterministic in tests.

    Args:
        termination_trial_number:
            ``should_terminate`` returns ``True`` once any COMPLETE trial has a
            number greater than or equal to this value.
    """

    def __init__(self, termination_trial_number: int) -> None:
        self._termination_trial_number = termination_trial_number

    def should_terminate(self, study: Study) -> bool:
        """Return whether the highest COMPLETE trial number reached the threshold.

        Args:
            study: The study whose trials are inspected.

        Returns:
            ``True`` if some COMPLETE trial has number
            ``>= termination_trial_number``, otherwise ``False``. When the study
            has no COMPLETE trials yet (e.g. every trial so far failed or was
            pruned), ``False`` is returned rather than letting ``max`` raise
            ``ValueError`` on an empty sequence.
        """
        trials = study.get_trials(states=[TrialState.COMPLETE])
        if not trials:
            # Nothing has completed yet, so the threshold cannot have been reached.
            return False
        latest_number = max(t.number for t in trials)
        return latest_number >= self._termination_trial_number


def test_terminator_callback_terminator() -> None:
    """The study stops the first time the terminator's ``should_terminate`` is ``True``.

    ``_DeterministicTerminator`` fires once trial number ``stop_at`` has
    completed. Trial numbers are zero-based, so exactly ``stop_at + 1`` trials
    are expected even though ``n_trials`` would allow many more.
    """
    stop_at = 10
    terminator_callback = TerminatorCallback(
        terminator=_DeterministicTerminator(stop_at),
    )

    study = create_study()
    study.optimize(
        lambda _: 0.0,
        callbacks=[terminator_callback],
        n_trials=100,
    )

    expected_trial_count = stop_at + 1
    assert len(study.trials) == expected_trial_count