File: test_fx_traceback.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (172 lines) | stat: -rw-r--r-- 6,018 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# Owner(s): ["module: fx"]
import json

import torch
from torch._inductor.compile_fx import aot_export_module
from torch.fx.traceback import get_graph_provenance_json, NodeSource, NodeSourceAction
from torch.testing._internal.common_utils import TestCase


class TestFXNodeSource(TestCase):
    def test_node_source(self):
        node_source = NodeSource(
            node=None, pass_name="test_pass", action=NodeSourceAction.CREATE
        )
        self.assertExpectedInline(
            node_source.print_readable().strip(),
            """(name=, pass_name=test_pass, action=create, graph_id=-1)""",
        )
        dummy_source_dict = {
            "name": "",
            "target": "",
            "pass_name": "test_pass",
            "action": NodeSourceAction.CREATE,
            "graph_id": -1,
            "from_node": [],
        }
        self.assertEqual(
            node_source.to_dict(),
            dummy_source_dict,
        )

        # Dummy node
        node = torch.fx.Node(
            graph=torch.fx.Graph(),
            name="add",
            op="call_function",
            target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
            args=(torch.tensor(3), torch.tensor(4)),
            kwargs={},
        )
        node.meta["from_node"] = [node_source]

        graph_id = id(node.graph)
        node_source = NodeSource(
            node=node, pass_name="test_pass", action=NodeSourceAction.CREATE
        )
        self.assertExpectedInline(
            node_source.print_readable().strip(),
            f"""\
(name=add, pass_name=test_pass, action=create, graph_id={graph_id})
    (name=, pass_name=test_pass, action=create, graph_id=-1)""",
        )
        self.assertEqual(
            node_source.to_dict(),
            {
                "name": "add",
                "target": "aten.add.Tensor",
                "pass_name": "test_pass",
                "action": NodeSourceAction.CREATE,
                "graph_id": graph_id,
                "from_node": [dummy_source_dict],
            },
        )

    def test_graph_provenance(self):
        def check_node_source(node_source_dict, name, pass_name, action):
            self.assertEqual(node_source_dict["name"], name)
            self.assertEqual(node_source_dict["pass_name"], pass_name)
            self.assertEqual(node_source_dict["action"], action)

        def get_first_node_source_and_check(node_source_dict):
            """
            Get the first node source from the from_node list.
            """
            self.assertEqual(len(node_source_dict["from_node"]), 1)
            return node_source_dict["from_node"][0]

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(10, 16)
                self.relu = torch.nn.ReLU()
                self.fc2 = torch.nn.Linear(16, 1)
                self.sigmoid = torch.nn.Sigmoid()

            def forward(self, x):
                x = self.fc1(x)
                x = self.relu(x)
                x = self.fc2(x)
                x = self.sigmoid(x)
                return (x,)

        model = Model()
        example_inputs = (torch.randn(8, 10),)
        ep = torch.export.export(
            model,
            example_inputs,
        )
        gm = ep.module()
        provenance = get_graph_provenance_json(gm.graph)
        provenance = json.loads(provenance)
        self.assertEqual(
            set(provenance.keys()), {"relu", "linear", "sigmoid", "linear_1"}
        )

        # Check node "linear" is created from node "x" in PropagateUnbackedSymInts
        key_provenance = provenance["linear"]
        self.assertEqual(len(key_provenance), 1)
        key_provenance = key_provenance[0]
        check_node_source(
            key_provenance,
            "x",
            "Interpreter_PropagateUnbackedSymInts",
            NodeSourceAction.CREATE,
        )

        # Check node "x" is then created from another node "x" in FlattenInputOutputSignature
        key_provenance = get_first_node_source_and_check(key_provenance)
        check_node_source(
            key_provenance,
            "x",
            "Interpreter_FlattenInputOutputSignature",
            NodeSourceAction.CREATE,
        )

        gm, graph_signature = aot_export_module(
            gm,
            example_inputs,
            trace_joint=False,
        )

        provenance = get_graph_provenance_json(gm.graph)
        provenance = json.loads(provenance)

        self.assertEqual(
            set(provenance.keys()), {"t", "addmm", "relu", "t_1", "addmm_1", "sigmoid"}
        )
        for key in ["t", "addmm"]:
            # The node provenance hierarchy should be:
            # t -> linear -> x -> x
            #
            # x -> y means x is created from y

            key_provenance = provenance[key]
            self.assertEqual(len(key_provenance), 1)
            key_provenance = key_provenance[0]

            # Check node "t" and "addmm" is created from node "linear" in PropagateUnbackedSymInts
            check_node_source(
                key_provenance,
                "linear",
                "Interpreter_PropagateUnbackedSymInts",
                NodeSourceAction.CREATE,
            )

            # Check node "linear" is then created from node "x" in PropagateUnbackedSymInts
            key_provenance = get_first_node_source_and_check(key_provenance)
            check_node_source(
                key_provenance,
                "x",
                "Interpreter_PropagateUnbackedSymInts",
                NodeSourceAction.CREATE,
            )

            # Check node "x" is then created from another node "x" in FlattenInputOutputSignature
            key_provenance = get_first_node_source_and_check(key_provenance)
            check_node_source(
                key_provenance,
                "x",
                "Interpreter_FlattenInputOutputSignature",
                NodeSourceAction.CREATE,
            )