File: pytorch.py

Package: python-thinc 9.1.1-1
import contextlib
import itertools
from io import BytesIO
from typing import Any, Callable, Dict, Optional, cast

import srsly

from ..backends import CupyOps, context_pools, get_current_ops, set_gpu_allocator
from ..compat import torch
from ..optimizers import Optimizer
from ..types import ArgsKwargs, FloatsXd
from ..util import (
    convert_recursive,
    get_torch_default_device,
    iterate_recursive,
    torch2xp,
    xp2torch,
)
from .pytorch_grad_scaler import PyTorchGradScaler
from .shim import Shim


class PyTorchShim(Shim):
    """Interface between a PyTorch model and a Thinc Model. This container is
    *not* a Thinc Model subclass itself.

    mixed_precision:
        Enable mixed-precision. This changes whitelisted ops to run
        in half precision for better performance and lower memory use.
    grad_scaler:
        The gradient scaler to use for mixed-precision training. If this
        argument is set to "None" and mixed precision is enabled, a gradient
        scaler with the default configuration is used.
    device:
        The PyTorch device to run the model on. When this argument is
        set to "None", the default device for the currently active Thinc
        ops is used.
    serialize_model:
        Callback that receives the wrapped PyTorch model as its argument and
        returns a "bytes" representation of it. The representation should
        contain all the necessary information to fully deserialize the model.
    deserialize_model:
        Callback that receives the default PyTorch model (passed to the constructor), the
        serialized "bytes" representation and a PyTorch device. It should return a
        fully deserialized model on the target device as its result.
    """

    def __init__(
        self,
        model: Any,
        config=None,
        optimizer: Any = None,
        mixed_precision: bool = False,
        grad_scaler: Optional[PyTorchGradScaler] = None,
        device: Optional["torch.device"] = None,
        serialize_model: Optional[Callable[[Any], bytes]] = None,
        deserialize_model: Optional[Callable[[Any, bytes, "torch.device"], Any]] = None,
    ):
        super().__init__(model, config, optimizer)

        if device is None:
            device = get_torch_default_device()
        if model is not None:
            model.to(device)

        if grad_scaler is None:
            grad_scaler = PyTorchGradScaler(mixed_precision)

        grad_scaler.to_(device)

        self._grad_scaler = grad_scaler
        self._mixed_precision = mixed_precision

        self._serialize_model = (
            serialize_model
            if serialize_model is not None
            else default_serialize_torch_model
        )
        self._deserialize_model = (
            deserialize_model
            if deserialize_model is not None
            else default_deserialize_torch_model
        )

        if CupyOps.xp is not None and isinstance(get_current_ops(), CupyOps):
            pools = context_pools.get()
            if "pytorch" not in pools:
                from cupy import get_default_memory_pool

                set_gpu_allocator("pytorch")
                get_default_memory_pool().free_all_blocks()
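
    # A minimal construction sketch for illustration, assuming a small
    # `torch.nn.Linear` module; the keyword arguments mirror the parameters
    # documented in the class docstring:
    #
    #     shim = PyTorchShim(torch.nn.Linear(4, 2), mixed_precision=False)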

    def __call__(self, inputs, is_train):
        if is_train:
            return self.begin_update(inputs)
        else:
            return self.predict(inputs), lambda a: ...

    @property
    def device(self):
        p = next(self._model.parameters(), None)
        if p is None:
            return get_torch_default_device()
        else:
            return p.device

    def predict(self, inputs: ArgsKwargs) -> Any:
        """Pass inputs through to the underlying PyTorch model, and return the
        output. No conversions are performed. The PyTorch model is set into
        evaluation mode.
        """
        self._model.eval()
        with torch.no_grad():
            with torch.cuda.amp.autocast(self._mixed_precision):
                outputs = self._model(*inputs.args, **inputs.kwargs)
        self._model.train()
        return outputs
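
    # Illustrative call sketch, reusing the `shim` from the construction sketch
    # above. `ArgsKwargs` wraps the positional and keyword arguments that are
    # forwarded to the wrapped module:
    #
    #     inputs = ArgsKwargs(args=(torch.zeros(2, 4),), kwargs={})
    #     outputs = shim.predict(inputs)  # raw torch.Tensor, evaluation mode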

    def begin_update(self, inputs: ArgsKwargs):
        """Pass the inputs through to the underlying PyTorch model, keeping
        track of which items in the input are tensors requiring gradients.
        If the model returns a single value, it is converted into a one-element tuple.
        Return the outputs and a callback to backpropagate.
        """
        self._model.train()

        # Note: mixed-precision autocast must not be applied to backprop.
        with torch.cuda.amp.autocast(self._mixed_precision):
            output = self._model(*inputs.args, **inputs.kwargs)

        def backprop(grads):
            # Normally, gradient scaling is applied to the loss of a model. However,
            # since regular thinc layers do not use mixed-precision, we perform scaling
            # locally in this shim. Scaling the loss by a factor scales the gradients
            # by the same factor (see the chain rule). Therefore, we scale the gradients
            # backprop'ed through the succeeding layer to get the same effect as loss
            # scaling.
            grads.kwargs["grad_tensors"] = self._grad_scaler.scale(
                grads.kwargs["grad_tensors"], inplace=True
            )

            torch.autograd.backward(*grads.args, **grads.kwargs)

            # Unscale the gradients and check for overflows during backprop.
            grad_tensors = []
            for torch_data in itertools.chain(
                self._model.parameters(),
                iterate_recursive(lambda x: hasattr(x, "grad"), inputs),
            ):
                if torch_data.grad is not None:
                    grad_tensors.append(torch_data.grad)
            found_inf = self._grad_scaler.unscale(grad_tensors)

            # If there was an over/underflow, return zeroed-out gradients.
            if found_inf:
                grad_get = lambda x: x.grad.zero_() if x.grad is not None else x.grad
            else:
                grad_get = lambda x: x.grad

            return convert_recursive(lambda x: hasattr(x, "grad"), grad_get, inputs)

        return output, backprop
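
    # Illustrative training-mode sketch, continuing the example above. The
    # backprop callback expects an ArgsKwargs whose contents are forwarded to
    # `torch.autograd.backward`, with the upstream gradients passed under the
    # "grad_tensors" key:
    #
    #     X = torch.zeros(2, 4, requires_grad=True)
    #     outputs, backprop = shim.begin_update(ArgsKwargs(args=(X,), kwargs={}))
    #     grads = ArgsKwargs(
    #         args=(outputs,), kwargs={"grad_tensors": torch.ones_like(outputs)}
    #     )
    #     d_inputs = backprop(grads)  # inputs with tensors replaced by their .grad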

    def finish_update(self, optimizer: Optimizer):
        for name, torch_data in self._model.named_parameters():
            if torch_data.grad is not None:
                if (
                    not self._grad_scaler.found_inf
                ):  # Skip weight update if any gradient overflowed.
                    param, grad = optimizer(
                        (self.id, name),
                        cast(FloatsXd, torch2xp(torch_data.data)),
                        cast(FloatsXd, torch2xp(torch_data.grad)),
                    )
                    torch_data.data = xp2torch(
                        param, requires_grad=True, device=torch_data.device
                    )
                torch_data.grad.zero_()

        self._grad_scaler.update()
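
    # Illustrative optimizer step, continuing the sketch above; `Adam` is the
    # standard Thinc optimizer exposed via `thinc.api`:
    #
    #     from thinc.api import Adam
    #     shim.finish_update(Adam(0.001))  # update wrapped parameters, zero grads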

    @contextlib.contextmanager
    def use_params(self, params):
        key_prefix = f"pytorch_{self.id}_"
        state_dict = {}
        for k, v in params.items():
            if hasattr(k, "startswith") and k.startswith(key_prefix):
                state_dict[k.replace(key_prefix, "")] = xp2torch(v, device=self.device)
        if state_dict:
            backup = {k: v.clone() for k, v in self._model.state_dict().items()}
            self._model.load_state_dict(state_dict)
            yield
            self._model.load_state_dict(backup)
        else:
            yield

    def to_device(self, device_type: str, device_id: int):  # pragma: no cover
        if device_type == "cpu":
            self._model.cpu()
        elif device_type == "gpu":
            self._model.cuda(device_id)
        else:
            msg = f"Invalid device_type: {device_type}. Try 'cpu' or 'gpu'"
            raise ValueError(msg)

    def to_bytes(self):
        model_bytes = self._serialize_model(self._model)
        msg = {"config": self.cfg, "state": model_bytes}
        return srsly.msgpack_dumps(msg)

    def from_bytes(self, bytes_data):
        device = get_torch_default_device()
        msg = srsly.msgpack_loads(bytes_data)
        self.cfg = msg["config"]
        self._model = self._deserialize_model(self._model, msg["state"], device)
        self._grad_scaler.to_(device)
        return self


def default_serialize_torch_model(model: Any) -> bytes:
    """Serializes the parameters of the wrapped PyTorch model to bytes.

    model:
        Wrapped PyTorch model.

    Returns:
        A `bytes` object that encapsulates the serialized model parameters.
    """
    filelike = BytesIO()
    torch.save(model.state_dict(), filelike)
    filelike.seek(0)
    return filelike.getvalue()
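
# For illustration, assuming `module` is any torch.nn.Module instance:
#
#     state_bytes = default_serialize_torch_model(module)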


def default_deserialize_torch_model(
    model: Any, state_bytes: bytes, device: "torch.device"
) -> Any:
    """Deserializes the parameters of the wrapped PyTorch model and
    moves it to the specified device.

    model:
        Wrapped PyTorch model.
    state_bytes:
        Serialized parameters as a byte stream.
    device:
        PyTorch device to which the model is bound.

    Returns:
        The deserialized model.
    """
    filelike = BytesIO(state_bytes)
    filelike.seek(0)
    model.load_state_dict(torch.load(filelike, map_location=device))
    model.to(device)
    return model
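
# Continuing the sketch above, the serialized parameters can be restored onto
# the default Thinc device; both helpers are defined in this module:
#
#     restored = default_deserialize_torch_model(
#         module, state_bytes, get_torch_default_device()
#     )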