| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 
import contextlib

from ..compat import BytesIO
from ..neural._classes.model import Model

try:
    import torch.autograd
    import torch
except ImportError:
    pass
class PytorchWrapper(Model):
    '''Wrap a PyTorch model, so that it has the same API as Thinc models.

    To optimize the model, you'll need to create a PyTorch optimizer and call
    optimizer.step() after each batch --- see examples/wrap_pytorch.py.
    Alternatively, pass that PyTorch optimizer as the ``sgd`` argument of the
    backward callback returned by ``begin_update``; it will then be stepped
    (and its gradients reset) for you.
    '''
    def __init__(self, model):
        Model.__init__(self)
        # The wrapped torch module; all computation is delegated to it.
        self._model = model

    def begin_update(self, x_data, drop=0.):
        '''Return the output of the wrapped PyTorch model for the given input,
        along with a callback to handle the backward pass.

        x_data: input data, converted with ``torch.Tensor`` (so it is cast to
            torch's default tensor type).
        drop: dropout rate. Currently ignored by this wrapper.
        RETURNS (tuple): ``(y, backward)``, where ``y`` is the model output
            converted with ``self.ops.asarray`` and ``backward(dy, sgd=None)``
            back-propagates ``dy`` and returns the gradient w.r.t. the input.
        '''
        # requires_grad=True so that x_var.grad is populated by the backward
        # pass and can be returned as the input gradient.
        x_var = torch.autograd.Variable(torch.Tensor(x_data),
                                        requires_grad=True)
        # Make prediction
        y_var = self._model(x_var)

        def backward_pytorch(dy_data, sgd=None):
            '''Back-propagate ``dy_data`` through the wrapped model.

            If ``sgd`` has a ``step()`` method (e.g. a ``torch.optim``
            optimizer), it is stepped and its gradients zeroed after the
            backward pass. Any other ``sgd`` value is ignored: this wrapper
            does not map a Thinc optimizer onto the torch parameters.

            RETURNS: the gradient with respect to the input, converted with
                ``self.ops.asarray``.
            '''
            dy_var = torch.autograd.Variable(torch.Tensor(dy_data))
            # Pass the output gradient positionally: its keyword name changed
            # from ``grad_variables`` to ``grad_tensors`` across torch
            # versions.
            torch.autograd.backward((y_var,), (dy_var,))
            dX = self.ops.asarray(x_var.grad.data)
            if sgd is not None and hasattr(sgd, 'step'):
                sgd.step()
                # Reset so parameter gradients don't accumulate across batches.
                sgd.zero_grad()
            return dX
        return self.ops.asarray(y_var.data), backward_pytorch

    def to_disk(self, path):
        '''Save the wrapped model's state_dict to ``path`` (str or path-like).'''
        # TODO: Untested
        torch.save(self._model.state_dict(), str(path))

    def from_disk(self, path):
        '''Load a state_dict written by ``to_disk`` into the wrapped model.

        RETURNS: self, to allow chaining.
        '''
        # TODO: Untested
        # str(path) mirrors to_disk, so pathlib.Path arguments work as well.
        self._model.load_state_dict(torch.load(str(path)))
        return self

    def to_bytes(self):
        '''Serialize the wrapped model's state_dict to a bytes object.'''
        # TODO: Untested
        filelike = BytesIO()
        torch.save(self._model.state_dict(), filelike)
        # getvalue() returns the whole buffer. read() would start from the
        # current position, which torch.save left at the end (-> b'').
        return filelike.getvalue()

    def from_bytes(self, data):
        '''Load a state_dict produced by ``to_bytes`` into the wrapped model.

        RETURNS: self, to allow chaining.
        '''
        # TODO: Untested
        filelike = BytesIO(data)
        self._model.load_state_dict(torch.load(filelike))
        return self

    def to_gpu(self, device_num):
        # TODO: Implement
        raise NotImplementedError

    def to_cpu(self):
        # TODO: Implement
        raise NotImplementedError

    def resize_output(self):
        # TODO: Required for spaCy add label
        raise NotImplementedError

    def resize_input(self):
        # TODO: Not required yet, but should be useful
        raise NotImplementedError

    @contextlib.contextmanager
    def use_params(self, params): # pragma: no cover
        # TODO: Implement
        raise NotImplementedError
 |