File: test_custom_backend.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (56 lines) | stat: -rw-r--r-- 1,701 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# Owner(s): ["module: unknown"]

import os
import tempfile
import torch

from backend import Model, to_custom_backend, get_custom_backend_library_path
from torch.testing._internal.common_utils import TestCase, run_tests


class TestCustomBackend(TestCase):
    """Tests for a TorchScript module lowered to the custom test backend.

    ``setUp`` loads the backend library and lowers a fresh ``Model`` for
    every test, so each test starts from its own lowered module.
    """

    def setUp(self):
        # Load the library containing the custom backend.
        self.library_path = get_custom_backend_library_path()
        torch.ops.load_library(self.library_path)
        # Create an instance of the test Module and lower it for
        # the custom backend.
        self.model = to_custom_backend(torch.jit.script(Model()))

    def _check_execute(self, model):
        """Run ``model`` on random inputs and verify the backend's result.

        The custom backend is hardcoded to compute f(a, b) = (a + b, a - b),
        so the expected pair is computed eagerly and compared against what
        ``model`` returns.

        Args:
            model: a lowered module (possibly saved and reloaded) that takes
                two tensors and returns a pair of tensors.
        """
        a = torch.randn(4)
        b = torch.randn(4)
        expected = (a + b, a - b)
        out = model(a, b)
        # assert_close compares the sequences element-wise. It also checks
        # that the number of outputs matches and that dtype and device agree.
        # On failure it reports the mismatching values, which
        # assertTrue(allclose(...)) does not. atol/rtol are the torch.allclose
        # defaults, so the numeric tolerance is unchanged.
        torch.testing.assert_close(out, expected, atol=1e-8, rtol=1e-5)

    def test_execute(self):
        """
        Test execution using the custom backend.
        """
        self._check_execute(self.model)

    def test_save_load(self):
        """
        Test that a lowered module can be executed correctly
        after saving and loading.
        """
        # Test execution before saving and loading to make sure
        # the lowered module works in the first place.
        self._check_execute(self.model)

        # Save and load through a temporary directory. The directory and the
        # saved file are removed automatically when the block exits, even if
        # save/load raises.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "model.pt")
            torch.jit.save(self.model, path)
            loaded = torch.jit.load(path)

        # Test execution again, this time on the reloaded module itself,
        # without overwriting the self.model fixture.
        self._check_execute(loaded)


if __name__ == "__main__":
    # Run this module's test suite only when the file is executed
    # directly, never when it is merely imported.
    run_tests()