File: test_optimizers.py

package info (click to toggle)
python-thinc 8.1.7-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 5,804 kB
  • sloc: python: 15,818; javascript: 1,554; ansic: 342; makefile: 20; sh: 13
file content (99 lines) | stat: -rw-r--r-- 3,047 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import pytest
from thinc.api import registry, Optimizer
import numpy


def _test_schedule_valid():
    while True:
        yield 0.456


def _test_schedule_invalid():
    yield from []


@pytest.fixture(
    params=[
        (lambda: 0.123, 0.123, 0.123, 0.123),
        (lambda: _test_schedule_valid(), 0.456, 0.456, 0.456),
        (lambda: (i for i in [0.2, 0.1, 0.4, 0.5, 0.6, 0.7, 0.8]), 0.2, 0.1, 0.4),
        (lambda: (i for i in [0.333, 0.666]), 0.333, 0.666, 0.666),
        (lambda: [0.9, 0.8, 0.7], 0.9, 0.8, 0.7),
        (lambda: [0.0, 0.123], 0.0, 0.123, 0.123),
    ],
    scope="function",
)
def schedule_valid(request):
    """Provide (schedule, expected_lr_1, expected_lr_2, expected_lr_3).

    Each param wraps its schedule in a zero-arg factory so that every test
    receives a fresh iterator — otherwise the first test to run would
    consume a shared generator and starve the rest.
    """
    make_schedule, expected1, expected2, expected3 = request.param
    return make_schedule(), expected1, expected2, expected3


@pytest.fixture(
    params=[
        (lambda: "hello"),
        (lambda: _test_schedule_invalid()),
        (lambda: (_ for _ in [])),
        (lambda: []),
    ],
    scope="function",
)
def schedule_invalid(request):
    """Provide values that must be rejected as learning-rate schedules.

    Params are zero-arg factories so each test gets a fresh (unconsumed)
    object rather than an iterator drained by an earlier test.
    """
    make_invalid = request.param
    return make_invalid()


@pytest.mark.parametrize("name", ["RAdam.v1", "Adam.v1", "SGD.v1"])
def test_optimizers_from_config(name):
    """Each registered optimizer can be built from a config dict."""
    expected_lr = 0.123
    cfg = {"config": {"@optimizers": name, "learn_rate": expected_lr}}
    resolved = registry.resolve(cfg)
    assert resolved["config"].learn_rate == expected_lr


def test_optimizer_schedules_from_config(schedule_valid):
    """A schedule passed through config drives learn_rate across steps."""
    lr, *expected_rates = schedule_valid
    cfg = {"@optimizers": "Adam.v1", "learn_rate": lr}
    optimizer = registry.resolve({"cfg": cfg})["cfg"]
    # Check the rate before any step, then after each of two steps.
    for step, expected in enumerate(expected_rates):
        if step:
            optimizer.step_schedules()
        assert optimizer.learn_rate == expected
    # Assigning a plain float overrides the schedule.
    optimizer.learn_rate = 1.0
    assert optimizer.learn_rate == 1.0


def test_optimizer_schedules_valid(schedule_valid):
    """A schedule passed directly to Optimizer drives learn_rate per step."""
    lr, *expected_rates = schedule_valid
    optimizer = Optimizer(learn_rate=lr)
    # Check the rate before any step, then after each of two steps.
    for step, expected in enumerate(expected_rates):
        if step:
            optimizer.step_schedules()
        assert optimizer.learn_rate == expected
    # Assigning a plain float overrides the schedule.
    optimizer.learn_rate = 1.0
    assert optimizer.learn_rate == 1.0


def test_optimizer_schedules_invalid(schedule_invalid):
    """Invalid schedules are rejected at Optimizer construction time."""
    bad_schedule = schedule_invalid
    with pytest.raises(ValueError):
        Optimizer(learn_rate=bad_schedule)


def test_optimizer_init():
    """Smoke-test Optimizer construction options and the update call."""
    optimizer = Optimizer(
        learn_rate=0.123,
        use_averages=False,
        use_radam=True,
        L2=0.1,
        L2_is_weight_decay=False,
    )
    # An empty gradient should come back unchanged.
    _, gradient = optimizer((0, "x"), numpy.zeros((1, 2)), numpy.zeros(0))
    assert numpy.array_equal(gradient, numpy.zeros(0))
    # A real weight/gradient pair is accepted without error.
    weights = numpy.asarray([1.0, 0.0, 0.0, 1.0], dtype="f").reshape((4,))
    grads = numpy.asarray([[-1.0, 0.0, 0.0, 1.0]], dtype="f").reshape((4,))
    optimizer((0, "x"), weights, grads)
    # Custom Adam betas are accepted too.
    optimizer = Optimizer(learn_rate=0.123, beta1=0.1, beta2=0.1)
    optimizer((1, "x"), weights, grads)