File: _param_server.py

from typing import Dict, Tuple, Optional, Any

from ..types import FloatsXd
from ..util import get_array_module


KeyT = Tuple[int, str]


class ParamServer:
    """Serve parameters for a single process."""

    _params: Dict[KeyT, FloatsXd] = {}
    _grads: Dict[KeyT, FloatsXd] = {}
    proxy: Optional[Any]

    def __init__(
        self,
        params: Dict[KeyT, FloatsXd] = {},
        grads: Dict[KeyT, FloatsXd] = {},
        *,
        proxy=None
    ):
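        # Copy the incoming dicts so the server owns its own state.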
        self._params = dict(params)
        self._grads = dict(grads)
        # Allow a 'proxy' to be provided to support remote parameters. This
        # is experimental; it's the mechanism we use in the Ray integration.
        self.proxy = proxy

    @property
    def param_keys(self) -> Tuple[KeyT, ...]:
        """Get the names of registered parameter (including unset)."""
        return tuple(self._params.keys())

    @property
    def grad_keys(self) -> Tuple[KeyT, ...]:
        """Get the keys of parameters that have a gradient set."""
        return tuple(key for key in self.param_keys if self.has_grad(*key))

    def has_param(self, model_id: int, name: str) -> bool:
        return (model_id, name) in self._params

    def has_grad(self, model_id: int, name: str) -> bool:
        return (model_id, name) in self._grads

    def get_param(self, model_id: int, name: str) -> FloatsXd:
        key = (model_id, name)
        if self.proxy is not None:
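            # Refresh the local copy from the proxy on every access, so that
            # remote updates are visible.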
            self._params[key] = self.proxy.get_param(model_id, name)
        return self._params[key]

    def get_grad(self, model_id: int, name: str) -> FloatsXd:
        key = (model_id, name)
        return self._grads[key]

    def set_param(self, model_id: int, name: str, value: FloatsXd) -> None:
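        # Write through to the proxy when one is set; always update the local copy.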
        if self.proxy is not None:
            self.proxy.set_param(model_id, name, value)
        self._params[(model_id, name)] = value

    def set_grad(self, model_id: int, name: str, value: FloatsXd) -> None:
        if self.proxy is not None:
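            # With a proxy, the gradient lives remotely and is not mirrored locally.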
            self.proxy.set_grad(model_id, name, value)
        else:
            self._grads[(model_id, name)] = value

    def inc_grad(self, model_id: int, name: str, value: FloatsXd) -> None:
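        # Add to an existing gradient, or store (a contiguous copy of) the
        # first contribution if no gradient is registered yet.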
        key = (model_id, name)
        if self.proxy is not None:
            self.proxy.inc_grad(model_id, name, value)
        elif not self.has_grad(model_id, name):  # pragma: no cover
            if hasattr(value, "copy"):
                # Adjustment for Jax
                self._grads[key] = value.copy()
            elif not value.flags["C_CONTIGUOUS"]:
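                # Take a C-contiguous copy rather than keeping a reference to
                # a non-contiguous view of the caller's array.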
                xp = get_array_module(value)
                self._grads[key] = xp.ascontiguousarray(value)
            else:
                self._grads[key] = value
        else:
            self._grads[key] += value
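
A minimal usage sketch, not part of the module above: the model id 0, the parameter name "W", and the numpy arrays are illustrative values only, and ParamServer is imported directly from the private module shown here.

import numpy

from thinc.backends._param_server import ParamServer

server = ParamServer()

# Register a parameter under the (model_id, name) key.
server.set_param(0, "W", numpy.zeros((2, 2), dtype="f"))

# The first inc_grad stores a copy; later calls add to it in place.
server.inc_grad(0, "W", numpy.ones((2, 2), dtype="f"))
server.inc_grad(0, "W", numpy.ones((2, 2), dtype="f"))

assert server.has_grad(0, "W")
print(server.get_param(0, "W"))  # still zeros
print(server.get_grad(0, "W"))   # all twos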