File: util.py

package info (click to toggle)
python-confection 1.0.0~dev0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 260 kB
  • sloc: python: 2,588; sh: 13; makefile: 4
file content (129 lines) | stat: -rw-r--r-- 3,475 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
Registered functions used for config tests.
"""
import contextlib
import dataclasses
import shutil
import tempfile
from pathlib import Path
from typing import Generator, Generic, Iterable, List, Optional, TypeVar, Union

import catalogue
from pydantic.types import StrictBool  # type: ignore

import confection

# Accepted type for the mock ``Adam`` hyperparameters: a single float, a list
# of floats, or a generator (presumably a schedule such as ``warmup_linear``;
# confirm against the config tests).
FloatOrSeq = Union[float, List[float], Generator]
# Type parameters of the generic ``Cat`` record: the types of its
# ``value_in`` and ``value_out`` fields respectively.
InT = TypeVar("InT")
OutT = TypeVar("OutT")


@dataclasses.dataclass()
class Cat(Generic[InT, OutT]):
    """Generic record pairing a name with an input value and an output value.

    ``InT`` and ``OutT`` parametrize the types of ``value_in`` and
    ``value_out``. Instances are returned by the registered ``cats``
    functions in this module (e.g. ``int_cat``).
    """

    # Identifier of the cat (e.g. "int_cat").
    name: str
    # Input-side value, typed by the first type parameter.
    value_in: InT
    # Output-side value, typed by the second type parameter.
    value_out: OutT


# Namespace for every registry table created below. The class reuses this
# constant rather than repeating the literal, so the two cannot drift apart.
my_registry_namespace = "config_tests"


class my_registry(confection.registry):
    """Function registry used by the config tests.

    Each table attribute is a ``catalogue`` registry in the
    ``my_registry_namespace`` namespace. Entry-point discovery is disabled
    (``entry_points=False``), so the tables only contain functions that are
    registered explicitly.
    """

    namespace = my_registry_namespace
    cats = catalogue.create(namespace, "cats", entry_points=False)
    optimizers = catalogue.create(namespace, "optimizers", entry_points=False)
    schedules = catalogue.create(namespace, "schedules", entry_points=False)
    initializers = catalogue.create(namespace, "initializers", entry_points=False)
    layers = catalogue.create(namespace, "layers", entry_points=False)


@my_registry.cats.register("catsie.v1")
def catsie_v1(evil: StrictBool, cute: bool = True) -> str:
    """Return ``"scratch!"`` when ``evil`` is truthy, otherwise ``"meow"``.

    ``cute`` is accepted but not used by this version.
    """
    return "scratch!" if evil else "meow"


@my_registry.cats.register("catsie.v2")
def catsie_v2(evil: StrictBool, cute: bool = True, cute_level: int = 1) -> str:
    """Return the cat's reaction.

    Returns ``"scratch!"`` when ``evil`` is truthy. Otherwise returns
    ``"meow <3"`` if ``cute_level`` exceeds 2, else ``"meow"``. ``cute`` is
    accepted but not used.
    """
    if evil:
        return "scratch!"
    return "meow <3" if cute_level > 2 else "meow"


@my_registry.cats("catsie.v3")
def catsie(arg: Cat) -> Cat:
    """Return the given ``Cat`` unchanged (identity function for config tests)."""
    return arg


@dataclasses.dataclass
class _Optimizer:
    """Plain record returned by the mock ``Adam`` factory below.

    Defined at module level, not inside ``Adam``, so that every call returns
    an instance of the same class. Dataclass equality requires identical
    types, and pickling requires an importable class.
    """

    learn_rate: FloatOrSeq
    beta1: FloatOrSeq
    beta2: FloatOrSeq
    use_averages: bool


@my_registry.optimizers("Adam.v1")
def Adam(
    learn_rate: FloatOrSeq = 0.001,
    *,
    beta1: FloatOrSeq = 0.001,
    beta2: FloatOrSeq = 0.001,
    use_averages: bool = True,
):
    """
    Mocks optimizer generation. Note that the returned object is not actually an optimizer. This function is merely used
    to illustrate how to use the function registry, e.g. with thinc.

    Args:
        learn_rate: Learning rate: a float, a list of floats, or a generator.
        beta1: First beta hyperparameter (keyword-only).
        beta2: Second beta hyperparameter (keyword-only).
        use_averages: Flag stored verbatim on the result (keyword-only).

    Returns:
        A ``_Optimizer`` record holding the arguments unchanged.
    """
    return _Optimizer(
        learn_rate=learn_rate, beta1=beta1, beta2=beta2, use_averages=use_averages
    )


@my_registry.schedules("warmup_linear.v1")
def warmup_linear(
    initial_rate: float, warmup_steps: int, total_steps: int
) -> Iterable[float]:
    """Yield an endless learning-rate schedule.

    For the first ``warmup_steps`` steps the rate ramps up linearly from 0
    towards ``initial_rate``. After that it declines linearly, reaching 0 at
    ``total_steps`` and staying at 0 afterwards.
    """
    current = 0
    while True:
        if current < warmup_steps:
            # Warm-up phase: fraction of the warm-up completed so far.
            scale = current / max(1, warmup_steps)
        else:
            # Decay phase: fraction of the post-warm-up span still remaining,
            # clamped so it never goes negative.
            remaining = total_steps - current
            span = max(1.0, total_steps - warmup_steps)
            scale = max(0.0, remaining / span)
        yield scale * initial_rate
        current += 1


@my_registry.cats("int_cat.v1")
def int_cat(
    value_in: Optional[int] = None, value_out: Optional[int] = None
) -> Cat[Optional[int], Optional[int]]:
    """Build a ``Cat`` named ``"int_cat"`` that holds the given optional ints."""
    cat_name = "int_cat"
    return Cat(value_in=value_in, value_out=value_out, name=cat_name)


@my_registry.optimizers.register("my_cool_optimizer.v1")
def make_my_optimizer(learn_rate: List[float], beta1: float):
    """Build the mock ``Adam`` optimizer from a learning-rate list and ``beta1``.

    ``beta2`` and ``use_averages`` keep ``Adam``'s defaults.
    """
    return Adam(learn_rate=learn_rate, beta1=beta1)


@my_registry.schedules("my_cool_repetitive_schedule.v1")
def decaying(base_rate: float, repeat: int) -> List[float]:
    """Return a constant schedule: ``base_rate`` repeated ``repeat`` times."""
    schedule = repeat * [base_rate]
    return schedule


@contextlib.contextmanager
def make_tempdir() -> Generator[Path, None, None]:
    """Create a temporary directory, yield its path, and delete it on exit.

    The directory is removed even if the ``with`` body raises, so a failing
    test does not leak directories into the system temp location.

    Yields:
        Path of the freshly created directory.
    """
    d = Path(tempfile.mkdtemp())
    try:
        yield d
    finally:
        # contextmanager re-raises the body's exception at ``yield``; the
        # finally clause ensures cleanup on both normal and error exits.
        shutil.rmtree(str(d))