File: test_callbacks.py

package info (click to toggle)
scikit-optimize 0.10.2-5
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 7,684 kB
  • sloc: python: 10,659; javascript: 438; makefile: 136; sh: 6
file content (132 lines) | stat: -rw-r--r-- 3,897 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import tempfile
from collections import namedtuple

import numpy as np
import pytest
from numpy.testing import assert_almost_equal

from skopt import dummy_minimize, gp_minimize
from skopt.benchmarks import bench1, bench3
from skopt.callbacks import (
    CheckpointSaver,
    DeadlineStopper,
    DeltaYStopper,
    HollowIterationsStopper,
    StdStopper,
    ThresholdStopper,
    TimerCallback,
)
from skopt.utils import load


@pytest.mark.fast_test
def test_timer_callback():
    """Check that ``TimerCallback`` records one duration per optimizer call.

    Runs ``dummy_minimize`` on ``bench1`` for a fixed number of calls and
    inspects the ``iter_time`` list collected by the callback.
    """
    n_calls = 10
    callback = TimerCallback()
    dummy_minimize(bench1, [(-1.0, 1.0)], callback=callback, n_calls=n_calls)

    # An upper bound alone would also pass if the callback had never been
    # invoked (empty list), so require the exact number of recorded timings.
    assert len(callback.iter_time) == n_calls
    # The total may legitimately be 0.0 on a coarse clock, hence ``<=``.
    assert 0.0 <= sum(callback.iter_time)


@pytest.mark.fast_test
def test_deltay_stopper():
    """Exercise ``DeltaYStopper(0.2, 3)`` on hand-written objective histories.

    A minimal namedtuple exposing only ``func_vals`` stands in for an
    optimization result.
    """
    FakeResult = namedtuple("Result", ["func_vals"])
    stopper = DeltaYStopper(0.2, 3)

    # The three smallest values (0, 0.1, 0.19) lie within 0.2 of each other.
    tightly_clustered = FakeResult([0, 1, 2, 3, 4, 0.1, 0.19])
    assert stopper(tightly_clustered)

    # Here the three smallest values are 0, 0.1 and 1.
    widely_spread = FakeResult([0, 1, 2, 3, 4, 0.1])
    assert not stopper(widely_spread)

    # With only two values the stopper gives no verdict at all.
    too_few_values = FakeResult([0, 1])
    assert stopper(too_few_values) is None


@pytest.mark.fast_test
def test_threshold_stopper():
    """Exercise ``ThresholdStopper(3.0)`` on hand-written objective histories.

    A minimal namedtuple exposing only ``func_vals`` stands in for an
    optimization result.
    """
    History = namedtuple("Result", ["func_vals"])
    stopper = ThresholdStopper(3.0)

    # Every value is strictly greater than 3.0.
    all_above = History([3.1, 4, 4.6, 100])
    # Contains values equal to and lower than 3.0.
    reaches_threshold = History([3.0, 3, 2.9, 0, 0.0])

    assert not stopper(all_above)
    assert stopper(reaches_threshold)


@pytest.mark.fast_test
def test_std_stopper():
    """Check that a ``StdStopper`` run ends with a model near the threshold.

    Runs ``gp_minimize`` on ``bench1`` with a ``StdStopper(0.35)`` callback
    and compares ``y_train_std_`` of the final fitted model with the
    stopper's ``threshold`` to one decimal place.
    """
    std = StdStopper(0.35)
    result = gp_minimize(
        bench1, [(-1.0, 1.0)], callback=std, n_calls=50, random_state=1
    )
    # Inspect the most recently fitted model: it reflects the state of the
    # run when it ended, whereas ``models[0]`` only describes the first fit.
    # NOTE(review): StdStopper is documented to act on the latest model's
    # standard deviation -- confirm against skopt.callbacks.
    assert_almost_equal(result.models[-1].y_train_std_, std.threshold, decimal=1)


@pytest.mark.fast_test
def test_deadline_stopper():
    """Run ``gp_minimize`` with a tiny and with a generous ``DeadlineStopper``.

    In both runs at least one iteration time must be recorded. With the tiny
    budget the summed iteration times exceed ``total_time``; with the
    generous one they stay below it.
    """
    # Budget so small that the recorded time is expected to overshoot it.
    tight = DeadlineStopper(0.0001)
    gp_minimize(bench3, [(-1.0, 1.0)], callback=tight, n_calls=10, random_state=1)
    assert len(tight.iter_time) >= 1
    spent_tight = np.sum(tight.iter_time)
    assert spent_tight > tight.total_time

    # Budget large enough that all ten calls are expected to fit inside it.
    generous = DeadlineStopper(60)
    gp_minimize(
        bench3, [(-1.0, 1.0)], callback=generous, n_calls=10, random_state=1
    )
    assert len(generous.iter_time) >= 1
    spent_generous = np.sum(generous.iter_time)
    assert spent_generous < generous.total_time


@pytest.mark.fast_test
def test_hollow_iterations_stopper():
    """Exercise ``HollowIterationsStopper`` directly and through ``gp_minimize``.

    The first part feeds hand-written ``func_vals`` histories to the stopper
    using a minimal namedtuple stand-in for a result. The second part runs
    ``gp_minimize`` on ``bench3`` and checks how many evaluations were made
    before the stopper ended the run.
    """
    Trace = namedtuple("Result", ["func_vals"])

    zero_threshold = HollowIterationsStopper(3, 0)
    # Will run at least n_iterations + 1 times.
    assert not zero_threshold(Trace([10, 11, 12]))
    assert zero_threshold(Trace([10, 11, 12, 13]))

    # Merely tying the best value so far is not enough to keep going.
    assert zero_threshold(Trace([10, 11, 12, 10]))

    # Every time a new minimum is made, there are n_iterations rounds to
    # beat it.
    assert not zero_threshold(Trace([10, 9, 8, 7, 7, 7]))
    assert zero_threshold(Trace([10, 9, 8, 7, 7, 7, 7]))

    wide_threshold = HollowIterationsStopper(3, 1.1)
    assert not wide_threshold(Trace([10, 11, 12, 8.89]))
    assert wide_threshold(Trace([10, 11, 12, 8.9]))

    # Improvements that are individually below the threshold add up: a total
    # gain of 0.3 still stops, while repeated steps of 0.5 do not.
    assert wide_threshold(Trace([10, 9.9, 9.8, 9.7]))
    assert not wide_threshold(Trace([10, 9.5, 9, 8.5, 8, 7.5]))

    # Full optimizer runs: (stopper threshold, expected number of evaluations).
    end_to_end_cases = ((0, 10), (0.1, 5), (0.2, 4))
    for threshold, expected_evaluations in end_to_end_cases:
        stopper = HollowIterationsStopper(3, threshold)
        outcome = gp_minimize(
            bench3, [(-1.0, 1.0)], callback=stopper, n_calls=100, random_state=1
        )
        assert len(outcome.func_vals) == expected_evaluations


@pytest.mark.fast_test
def test_checkpoint_saver():
    """Check that ``CheckpointSaver`` writes a loadable checkpoint file.

    Runs ``dummy_minimize`` on ``bench1`` with a ``CheckpointSaver`` callback,
    then verifies that the checkpoint exists and that the ``x`` loaded from
    it equals the ``x`` of the returned result.
    """
    # A private temporary directory keeps the checkpoint out of the working
    # directory (no clash with stale files or concurrent runs) and is removed
    # on exit even when one of the assertions below fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_path = os.path.join(tmp_dir, "test_checkpoint.pkl")

        checkpoint_saver = CheckpointSaver(checkpoint_path, compress=9)
        result = dummy_minimize(
            bench1, [(-1.0, 1.0)], callback=checkpoint_saver, n_calls=10
        )

        assert os.path.exists(checkpoint_path)
        assert load(checkpoint_path).x == result.x