File: test_torchscriptwrapper.py

package info (click to toggle)
python-thinc 8.1.7-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 5,804 kB
  • sloc: python: 15,818; javascript: 1,554; ansic: 342; makefile: 20; sh: 13
file content (25 lines) | stat: -rw-r--r-- 810 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import pytest
import numpy

from thinc.api import PyTorchWrapper_v2, TorchScriptWrapper_v1
from thinc.api import pytorch_to_torchscript_wrapper
from thinc.compat import has_torch, torch


@pytest.mark.skipif(not has_torch, reason="needs PyTorch")
@pytest.mark.parametrize("nN,nI,nO", [(2, 3, 4)])
def test_pytorch_script(nN, nI, nO):
    """Check that a TorchScript conversion reproduces the PyTorch model's output.

    A ``torch.nn.Linear(nI, nO)`` layer is wrapped with ``PyTorchWrapper_v2``
    and converted with ``pytorch_to_torchscript_wrapper``. Both models predict
    on the same random float32 batch of shape ``(nN, nI)``. The scripted model
    is then serialized with ``to_bytes`` and loaded into a freshly constructed
    ``TorchScriptWrapper_v1``, which must give the same predictions again.
    """
    linear_layer = torch.nn.Linear(nI, nO)
    torch_model = PyTorchWrapper_v2(linear_layer).initialize()
    scripted = pytorch_to_torchscript_wrapper(torch_model)

    inputs = numpy.random.randn(nN, nI).astype("f")
    expected = torch_model.predict(inputs)
    numpy.testing.assert_allclose(expected, scripted.predict(inputs))

    # Serialize first, then restore into an empty wrapper: the restored model
    # must behave exactly like the one it was saved from.
    payload = scripted.to_bytes()
    restored = TorchScriptWrapper_v1()
    restored.from_bytes(payload)
    numpy.testing.assert_allclose(expected, restored.predict(inputs))