1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
|
"""
Registered functions used for config tests.
"""
import contextlib
import dataclasses
import shutil
import tempfile
from pathlib import Path
from typing import Generator, Generic, Iterable, List, Optional, TypeVar, Union
import catalogue
from pydantic.types import StrictBool # type: ignore
import confection
FloatOrSeq = Union[float, List[float], Generator]
InT = TypeVar("InT")
OutT = TypeVar("OutT")
@dataclasses.dataclass
class Cat(Generic[InT, OutT]):
name: str
value_in: InT
value_out: OutT
my_registry_namespace = "config_tests"
class my_registry(confection.registry):
namespace = "config_tests"
cats = catalogue.create(namespace, "cats", entry_points=False)
optimizers = catalogue.create(namespace, "optimizers", entry_points=False)
schedules = catalogue.create(namespace, "schedules", entry_points=False)
initializers = catalogue.create(namespace, "initializers", entry_points=False)
layers = catalogue.create(namespace, "layers", entry_points=False)
@my_registry.cats.register("catsie.v1")
def catsie_v1(evil: StrictBool, cute: bool = True) -> str:
if evil:
return "scratch!"
else:
return "meow"
@my_registry.cats.register("catsie.v2")
def catsie_v2(evil: StrictBool, cute: bool = True, cute_level: int = 1) -> str:
if evil:
return "scratch!"
else:
if cute_level > 2:
return "meow <3"
return "meow"
@my_registry.cats("catsie.v3")
def catsie(arg: Cat) -> Cat:
return arg
@my_registry.optimizers("Adam.v1")
def Adam(
learn_rate: FloatOrSeq = 0.001,
*,
beta1: FloatOrSeq = 0.001,
beta2: FloatOrSeq = 0.001,
use_averages: bool = True,
):
"""
Mocks optimizer generation. Note that the returned object is not actually an optimizer. This function is merely used
to illustrate how to use the function registry, e.g. with thinc.
"""
@dataclasses.dataclass
class Optimizer:
learn_rate: FloatOrSeq
beta1: FloatOrSeq
beta2: FloatOrSeq
use_averages: bool
return Optimizer(
learn_rate=learn_rate, beta1=beta1, beta2=beta2, use_averages=use_averages
)
@my_registry.schedules("warmup_linear.v1")
def warmup_linear(
initial_rate: float, warmup_steps: int, total_steps: int
) -> Iterable[float]:
"""Generate a series, starting from an initial rate, and then with a warmup
period, and then a linear decline. Used for learning rates.
"""
step = 0
while True:
if step < warmup_steps:
factor = step / max(1, warmup_steps)
else:
factor = max(
0.0, (total_steps - step) / max(1.0, total_steps - warmup_steps)
)
yield factor * initial_rate
step += 1
@my_registry.cats("int_cat.v1")
def int_cat(
value_in: Optional[int] = None, value_out: Optional[int] = None
) -> Cat[Optional[int], Optional[int]]:
"""Instantiates cat with integer values."""
return Cat(name="int_cat", value_in=value_in, value_out=value_out)
@my_registry.optimizers.register("my_cool_optimizer.v1")
def make_my_optimizer(learn_rate: List[float], beta1: float):
return Adam(learn_rate, beta1=beta1)
@my_registry.schedules("my_cool_repetitive_schedule.v1")
def decaying(base_rate: float, repeat: int) -> List[float]:
return repeat * [base_rate]
@contextlib.contextmanager
def make_tempdir():
d = Path(tempfile.mkdtemp())
yield d
shutil.rmtree(str(d))
|