File: wrappers.py

package info (click to toggle)
python-thinc 6.12.1-1
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 1,388 kB
  • sloc: python: 7,120; ansic: 6,257; makefile: 19; sh: 11
file content (76 lines) | stat: -rw-r--r-- 2,346 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import contextlib

from ..compat import BytesIO
from ..neural._classes.model import Model

try:
    import torch.autograd
    import torch
except ImportError:
    pass


class PytorchWrapper(Model):
    '''Wrap a PyTorch model, so that it has the same API as Thinc models.
    To optimize the model, you'll need to create a PyTorch optimizer and call
    optimizer.step() after each batch --- see examples/wrap_pytorch.py.
    Alternatively, pass that PyTorch optimizer as the ``sgd`` argument of the
    backward callback returned by begin_update(), which then steps it.
    '''
    def __init__(self, model):
        '''Wrap ``model``, a PyTorch module that is called as ``model(x)``.'''
        Model.__init__(self)
        self._model = model

    def begin_update(self, x_data, drop=0.):
        '''Return the output of the wrapped PyTorch model for the given input,
        along with a callback to handle the backward pass.

        x_data (array): Input batch; converted with ``torch.Tensor``.
        drop (float): Dropout rate. Currently ignored by this wrapper.
        RETURNS (tuple): ``(y, backward)``. ``y`` is the model output converted
            with ``self.ops.asarray``. ``backward(dy_data, sgd=None)``
            returns the gradient with respect to ``x_data``.
        '''
        # requires_grad=True so that the input gradient (x_var.grad) is
        # populated by the backward pass and can be returned to the caller.
        x_var = torch.autograd.Variable(torch.Tensor(x_data),
                                        requires_grad=True)
        # Make prediction
        y_var = self._model(x_var)

        def backward_pytorch(dy_data, sgd=None):
            '''Backpropagate ``dy_data`` (gradient of the loss with respect to
            the output) and return the gradient with respect to the input.

            If ``sgd`` is a ``torch.optim.Optimizer``, it is stepped and its
            gradients are then cleared. Any other non-None ``sgd`` (e.g. a
            Thinc optimizer) cannot update PyTorch parameters and is not
            applied. As the class docstring says, the caller must then step a
            PyTorch optimizer themselves.
            '''
            dy_var = torch.autograd.Variable(torch.Tensor(dy_data))
            # Pass the output gradients positionally. The keyword was renamed
            # from ``grad_variables`` to ``grad_tensors`` in PyTorch 0.4, so
            # neither keyword works on every version.
            torch.autograd.backward((y_var,), (dy_var,))
            dX = self.ops.asarray(x_var.grad.data)
            if sgd is not None and isinstance(sgd, torch.optim.Optimizer):
                sgd.step()
                # Clear the parameter gradients so they do not accumulate
                # into the next batch's update.
                sgd.zero_grad()
            return dX

        return self.ops.asarray(y_var.data), backward_pytorch

    def to_disk(self, path):
        '''Save the wrapped model's state_dict to ``path``.'''
        # TODO: Untested
        torch.save(self._model.state_dict(), str(path))

    def from_disk(self, path):
        '''Load a state_dict saved by to_disk() from ``path`` into the model.'''
        # TODO: Untested
        # str(path) mirrors to_disk(), so pathlib.Path objects also work here.
        self._model.load_state_dict(torch.load(str(path)))

    def to_bytes(self):
        '''Serialize the wrapped model's state_dict and return it as bytes.'''
        # TODO: Untested
        filelike = BytesIO()
        torch.save(self._model.state_dict(), filelike)
        # getvalue() returns the whole buffer. read() would start at the
        # current position (the end of what was just written) and return b''.
        return filelike.getvalue()

    def from_bytes(self, data):
        '''Load a state_dict produced by to_bytes() into the wrapped model.'''
        # TODO: Untested
        filelike = BytesIO(data)
        self._model.load_state_dict(torch.load(filelike))

    def to_gpu(self, device_num):
        # TODO: Implement
        raise NotImplementedError

    def to_cpu(self):
        # TODO: Implement
        raise NotImplementedError

    def resize_output(self):
        # TODO: Required for spaCy add label
        raise NotImplementedError

    def resize_input(self):
        # TODO: Not required yet, but should be useful
        raise NotImplementedError

    @contextlib.contextmanager
    def use_params(self, params): # pragma: no cover
        # TODO: Implement
        raise NotImplementedError