File: test_alias_analysis.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (156 lines) | stat: -rw-r--r-- 5,698 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Owner(s): ["oncall: jit"]

import contextlib

import torch
from torch._C import parse_ir
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    # This module only defines test cases; refuse direct execution and point
    # the user at the aggregate JIT test runner instead.
    raise RuntimeError(
        "\n\n".join(
            [
                "This test file is not meant to be run directly, use:",
                "\tpython test/test_jit.py TESTNAME",
                "instead.",
            ]
        )
    )


class TestAliasAnalysis(JitTestCase):
    """Tests for the alias-analysis database returned by ``Graph.alias_db()``."""

    def _assert_only_first_arg_written(self, alias_db, call_nodes):
        """Assert every call node writes its first argument but not its second.

        Input 0 of a ``prim::CallFunction`` / ``prim::CallMethod`` node is the
        callee (the function value or ``self``), so inputs 1 and 2 are the
        ``x`` and ``y`` arguments of the called function; only ``x`` is
        mutated (via ``add_``) by the callees in ``test_recursive_calls``.
        """
        for node in call_nodes:
            inps = list(node.inputs())
            self.assertTrue(alias_db.has_writers(inps[1]))
            self.assertFalse(alias_db.has_writers(inps[2]))

    def test_becomes_wildcard_annotations(self):
        """A tensor passed to ``aten::split`` joins the wildcard set."""
        graph_str = """
        graph(%a.1 : Tensor, %b.1 : Tensor):
            %11 : NoneType = prim::Constant()
            %8 : int = prim::Constant[value=0]()
            %7 : int = prim::Constant[value=1]()
            %x.1 : Tensor = aten::add(%a.1, %b.1, %7)
            %y.1 : Tensor[] = aten::split(%x.1, %7, %8)
            return ()
        """
        graph = parse_ir(graph_str)
        alias_db = graph.alias_db()
        split_node = graph.findNode("aten::split")
        split_input = next(split_node.inputs())
        # split input enters wildcard set, list initialized as containing
        # wildcard set
        self.assertTrue(alias_db.may_contain_alias(split_input, split_node.output()))
        # because %x.1 enters wildcard set, it now aliases other members of
        # the wildcard set (graph inputs)
        self.assertTrue(alias_db.may_contain_alias(split_input, next(graph.inputs())))

    def test_nested_list_construct_not_wildcard(self):
        """A fresh tensor returned inside a list does not alias the graph input."""

        @torch.jit.script
        def foo(x):
            y = torch.rand([2, 2])
            return [y]

        graph = foo.graph
        alias_db = graph.alias_db()
        ten_construct = graph.findNode("aten::rand").output()
        output = next(graph.outputs())
        # the returned list contains the tensor produced by aten::rand ...
        self.assertTrue(alias_db.may_contain_alias(ten_construct, output))
        # ... but that tensor must not be lumped in with the input x
        self.assertFalse(
            alias_db.may_contain_alias(next(graph.inputs()), ten_construct)
        )

    def test_recursive_calls(self):
        """With ``descend_function_calls`` the db sees writes inside callees.

        Covers both free-function calls (``prim::CallFunction``) and module
        method calls (``prim::CallMethod``).
        """

        @torch.jit.script
        def foo(x, y):
            x.add_(1)
            return x + y

        @torch.jit.script
        def caller():
            a = torch.rand([2, 2])
            b = torch.ones([2, 2])
            out1 = foo(a, b)
            c = torch.rand([1])
            d = torch.ones([2])
            out2 = foo(d, c)
            return out1, out2

        is_frozen = False
        descend_function_calls = True
        alias_db = caller.graph.alias_db(is_frozen, descend_function_calls)
        func_calls = caller.graph.findAllNodes("prim::CallFunction")
        self.assertEqual(len(func_calls), 2)
        self._assert_only_first_arg_written(alias_db, func_calls)

        class Mod(torch.nn.Module):
            def forward(self):
                a = torch.rand([2, 2])
                b = torch.ones([2, 2])
                out1 = self.foo2(a, b)
                c = torch.rand([1])
                d = torch.ones([2])
                out2 = self.foo2(d, c)
                return out1, out2

            def foo2(self, x, y):
                x.add_(1)
                return x + y

        mod = torch.jit.script(Mod())
        alias_db = mod.graph.alias_db(is_frozen, descend_function_calls)
        method_calls = mod.graph.findAllNodes("prim::CallMethod")
        self.assertEqual(len(method_calls), 2)
        self._assert_only_first_arg_written(alias_db, method_calls)

    def test_multiple_compilation_units(self):
        # This is a repro of an internal issue we saw.
        # Here, we have a large number (40) of modules each with the same name (MyModuleCUTest).
        # AliasDB uses some hash tables that hash on types; each of these 40 modules are not
        # identical because they have different compilation units, but they have the same name.
        # Therefore, if we hash only on the module name (which we previously did), we will have
        # hash collisions for all of these module types.
        #
        # flat_hash_map has very bad performance (exponential) for this hash collision behavior.
        # This OOMs prior to the fix.
        N = 40

        class ModuleWrapper(torch.nn.Module):
            def __init__(self, module_list):
                super().__init__()
                self.module_list = module_list

            def forward(self, x):
                for mod in self.module_list:
                    x = mod(x)
                return x

        # ExitStack enters all N temp-file contexts, exits every one of them
        # even if an earlier exit fails, and lets exceptions from the body
        # propagate. A context manager whose __exit__ returns a truthy value
        # would swallow them and let this test pass vacuously.
        with contextlib.ExitStack() as stack:
            fnames = [
                stack.enter_context(TemporaryFileName(mode="w", suffix=".py"))
                for _ in range(N)
            ]
            module_list = torch.nn.ModuleList()
            # Module-level binding so each save/load round trip yields a
            # distinct compilation unit for the same qualified class name.
            global MyModuleCUTest

            class MyModuleCUTest(torch.nn.Module):
                def forward(self, x):
                    return x + 2

            for fname in fnames:
                scripted = torch.jit.script(MyModuleCUTest())
                torch.jit.save(scripted, fname)
                module_list.append(torch.jit.load(fname))

            mod = torch.jit.script(ModuleWrapper(module_list))
            mod(torch.zeros((2, 2)))