File: test_custom_classes.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (85 lines) | stat: -rw-r--r-- 2,650 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Owner(s): ["module: unknown"]

import unittest
import torch
from torch import ops
import torch.jit as jit
import glob
import os

from torch.testing._internal.common_utils import TestCase, run_tests


def get_custom_class_library_path():
    library_filename = glob.glob("build/*custom_class*")
    assert (len(library_filename) == 1)
    library_filename = library_filename[0]
    path = os.path.abspath(library_filename)
    assert os.path.exists(path), path
    return path

def test_equality(f, cmp_key):
    obj1 = f()
    obj2 = jit.script(f)()
    return (cmp_key(obj1), cmp_key(obj2))

class TestCustomOperators(TestCase):
    def setUp(self):
        ops.load_library(get_custom_class_library_path())

    def test_no_return_class(self):
        def f():
            val = torch.classes._TorchScriptTesting._Foo(5, 3)
            return val.info()
        self.assertEqual(*test_equality(f, lambda x: x))

    def test_constructor_with_args(self):
        def f():
            val = torch.classes._TorchScriptTesting._Foo(5, 3)
            return val
        self.assertEqual(*test_equality(f, lambda x: x.info()))

    def test_function_call_with_args(self):
        def f():
            val = torch.classes._TorchScriptTesting._Foo(5, 3)
            val.increment(1)
            return val

        self.assertEqual(*test_equality(f, lambda x: x.info()))

    def test_function_method_wrong_type(self):
        def f():
            val = torch.classes._TorchScriptTesting._Foo(5, 3)
            val.increment("asdf")
            return val

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            jit.script(f)()

    @unittest.skip("We currently don't support passing custom classes to custom methods.")
    def test_input_class_type(self):
        def f():
            val = torch.classes._TorchScriptTesting._Foo(1, 2)
            val2 = torch.classes._TorchScriptTesting._Foo(2, 3)
            val.combine(val2)
            return val

        self.assertEqual(*test_equality(f, lambda x: x.info()))

    def test_stack_string(self):
        def f():
            val = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
            return val.pop()
        self.assertEqual(*test_equality(f, lambda x: x))

    def test_stack_push_pop(self):
        def f():
            val = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
            val2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
            val.push(val2.pop())
            return val.pop() + val2.pop()
        self.assertEqual(*test_equality(f, lambda x: x))


if __name__ == "__main__":
    run_tests()