File: test_numeric_debugger.py

package info (click to toggle)
pytorch 2.9.1%2Bdfsg-1~exp2
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 180,096 kB
  • sloc: python: 1,473,255; cpp: 942,030; ansic: 79,796; asm: 7,754; javascript: 2,502; java: 1,962; sh: 1,809; makefile: 628; xml: 8
file content (357 lines) | stat: -rw-r--r-- 14,888 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
# Owner(s): ["oncall: quantization"]

import copy
import unittest
from collections import Counter

import torch
from torch.ao.quantization import (
    compare_results,
    CUSTOM_KEY,
    extract_results_from_loggers,
    generate_numeric_debug_handle,
    NUMERIC_DEBUG_HANDLE_KEY,
    prepare_for_propagation_comparison,
)
from torch.ao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import export_for_training
from torch.testing._internal.common_quantization import TestHelperModules
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    raise_on_run_directly,
    skipIfCrossRef,
    TestCase,
)


@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
class TestNumericDebugger(TestCase):
    def _assert_each_node_has_debug_handle(self, model) -> None:
        def _assert_node_has_debug_handle(node):
            self.assertTrue(
                CUSTOM_KEY in node.meta
                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY],
                f"Node {node} doesn't have debug handle",
            )

        bfs_trace_with_node_process(model, _assert_node_has_debug_handle)

    def _extract_debug_handles(self, model) -> dict[str, int]:
        debug_handle_map: dict[str, int] = {}

        def _extract_debug_handles_from_node(node):
            nonlocal debug_handle_map
            if (
                CUSTOM_KEY in node.meta
                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
            ):
                debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][
                    NUMERIC_DEBUG_HANDLE_KEY
                ]

        bfs_trace_with_node_process(model, _extract_debug_handles_from_node)

        return debug_handle_map

    def _extract_debug_handles_with_prev_decomp_op(self, model) -> dict[str, int]:
        prev_decomp_op_to_debug_handle_map: dict[str, int] = {}

        def _extract_debug_handles_with_prev_decomp_op_from_node(node):
            nonlocal prev_decomp_op_to_debug_handle_map
            if (
                CUSTOM_KEY in node.meta
                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
            ):
                prev_decomp_op = str(node.meta.get("nn_module_stack"))
                debug_handle = node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
                if prev_decomp_op not in prev_decomp_op_to_debug_handle_map:
                    prev_decomp_op_to_debug_handle_map[prev_decomp_op] = debug_handle
                else:
                    assert (
                        prev_decomp_op_to_debug_handle_map[prev_decomp_op]
                        == debug_handle
                    ), f"Node {node} has different debug handle {debug_handle}"
                    "than previous node sharing the same decomp op {prev_decomp_op}"

        bfs_trace_with_node_process(
            model, _extract_debug_handles_with_prev_decomp_op_from_node
        )
        return prev_decomp_op_to_debug_handle_map

    def test_simple(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs, strict=True)
        generate_numeric_debug_handle(ep)
        self._assert_each_node_has_debug_handle(ep)
        debug_handle_map = self._extract_debug_handles(ep)

        self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map))

    def test_control_flow(self):
        m = TestHelperModules.ControlFlow()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs, strict=True)
        generate_numeric_debug_handle(ep)

        self._assert_each_node_has_debug_handle(ep)
        debug_handle_map = self._extract_debug_handles(ep)

        self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map))

    def test_quantize_pt2e_preserve_handle(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs, strict=True)
        generate_numeric_debug_handle(ep)
        m = ep.module()

        quantizer = XNNPACKQuantizer().set_global(
            get_symmetric_quantization_config(is_per_channel=False)
        )
        m = prepare_pt2e(m, quantizer)
        debug_handle_map = self._extract_debug_handles(m)
        res_counter = Counter(debug_handle_map.values())
        repeated_debug_handle_ids = [1, 2, 3]
        # 3 ids were repeated because we copy over the id from node to its output observer
        # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default
        for dh_id in repeated_debug_handle_ids:
            self.assertEqual(res_counter[dh_id], 2)

        m(*example_inputs)
        m = convert_pt2e(m)
        self._assert_each_node_has_debug_handle(ep)
        debug_handle_map = self._extract_debug_handles(m)
        res_counter = Counter(debug_handle_map.values())
        # same set of ids where repeated, because we copy over the id from observer/fake_quant to
        # dequantize node
        repeated_debug_handle_ids = [1, 2, 3]
        for dh_id in repeated_debug_handle_ids:
            self.assertEqual(res_counter[dh_id], 2)

    def test_copy_preserve_handle(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = torch.export.export(m, example_inputs, strict=True)
        generate_numeric_debug_handle(ep)

        self._assert_each_node_has_debug_handle(ep)
        debug_handle_map_ref = self._extract_debug_handles(ep)

        ep_copy = copy.copy(ep)
        debug_handle_map = self._extract_debug_handles(ep_copy)

        self._assert_each_node_has_debug_handle(ep)
        self.assertEqual(debug_handle_map, debug_handle_map_ref)

    def test_deepcopy_preserve_handle(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = torch.export.export(m, example_inputs, strict=True)
        generate_numeric_debug_handle(ep)

        debug_handle_map_ref = self._extract_debug_handles(ep)
        ep_copy = copy.deepcopy(ep)
        debug_handle_map = self._extract_debug_handles(ep_copy)

        self._assert_each_node_has_debug_handle(ep)
        self.assertEqual(debug_handle_map, debug_handle_map_ref)

    @skipIfCrossRef  # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack
    # trace of the mode torch function impl doesn't match the traced graph stored lineno.
    def test_re_export_preserve_handle(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs, strict=True)
        generate_numeric_debug_handle(ep)
        m = ep.module()

        self._assert_each_node_has_debug_handle(ep)
        debug_handle_map_ref = self._extract_debug_handles(ep)

        ep_reexport = export_for_training(m, example_inputs, strict=True)

        self._assert_each_node_has_debug_handle(ep_reexport)
        debug_handle_map = self._extract_debug_handles(ep_reexport)

        self.assertEqual(debug_handle_map, debug_handle_map_ref)

    def test_run_decompositions_same_handle_id(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs, strict=True)
        generate_numeric_debug_handle(ep)

        self._assert_each_node_has_debug_handle(ep)
        debug_handle_map_ref = self._extract_debug_handles(ep)

        ep_copy = copy.copy(ep)
        ep_copy = ep_copy.run_decompositions()

        self._assert_each_node_has_debug_handle(ep_copy)
        debug_handle_map = self._extract_debug_handles(ep_copy)

        # checking the map still has the same ids, the node may change
        self.assertEqual(
            set(debug_handle_map.values()), set(debug_handle_map_ref.values())
        )

    def test_run_decompositions_map_handle_to_new_nodes(self):
        test_models = [
            TestHelperModules.TwoLinearModule(),
            TestHelperModules.Conv2dThenConv1d(),
        ]

        for m in test_models:
            example_inputs = m.example_inputs()
            ep = export_for_training(m, example_inputs, strict=True)
            generate_numeric_debug_handle(ep)

            self._assert_each_node_has_debug_handle(ep)
            pre_decomp_to_debug_handle_map_ref = (
                self._extract_debug_handles_with_prev_decomp_op(ep)
            )

            ep_copy = copy.copy(ep)
            ep_copy = ep_copy.run_decompositions()
            self._assert_each_node_has_debug_handle(ep_copy)
            pre_decomp_to_debug_handle_map = (
                self._extract_debug_handles_with_prev_decomp_op(ep_copy)
            )

            # checking the map still has the same ids, the node may change
            self.assertEqual(
                pre_decomp_to_debug_handle_map, pre_decomp_to_debug_handle_map_ref
            )

    def test_prepare_for_propagation_comparison(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs, strict=True)
        generate_numeric_debug_handle(ep)
        m = ep.module()
        m_logger = prepare_for_propagation_comparison(m)
        ref = m(*example_inputs)
        res = m_logger(*example_inputs)

        from torch.ao.quantization.pt2e._numeric_debugger import OutputLogger

        loggers = [m for m in m_logger.modules() if isinstance(m, OutputLogger)]
        self.assertEqual(len(loggers), 3)
        self.assertTrue("conv2d" in [logger.node_name for logger in loggers])
        self.assertEqual(res, ref)

    def test_extract_results_from_loggers(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs, strict=True)
        generate_numeric_debug_handle(ep)
        m = ep.module()
        m_ref_logger = prepare_for_propagation_comparison(m)

        quantizer = XNNPACKQuantizer().set_global(
            get_symmetric_quantization_config(is_per_channel=False)
        )
        m = prepare_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        m_quant_logger = prepare_for_propagation_comparison(m)

        m_ref_logger(*example_inputs)
        m_quant_logger(*example_inputs)
        ref_results = extract_results_from_loggers(m_ref_logger)
        quant_results = extract_results_from_loggers(m_quant_logger)
        comparison_results = compare_results(ref_results, quant_results)
        for node_summary in comparison_results.values():
            if len(node_summary.results) > 0:
                self.assertGreaterEqual(node_summary.results[0].sqnr, 35)

    def test_extract_results_from_loggers_list_output(self):
        m = TestHelperModules.Conv2dWithSplit()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs, strict=True)
        generate_numeric_debug_handle(ep)
        m = ep.module()
        m_ref_logger = prepare_for_propagation_comparison(m)

        quantizer = XNNPACKQuantizer().set_global(
            get_symmetric_quantization_config(is_per_channel=False)
        )
        m = prepare_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        m_quant_logger = prepare_for_propagation_comparison(m)

        m_ref_logger(*example_inputs)
        m_quant_logger(*example_inputs)
        ref_results = extract_results_from_loggers(m_ref_logger)
        quant_results = extract_results_from_loggers(m_quant_logger)
        comparison_results = compare_results(ref_results, quant_results)
        for node_summary in comparison_results.values():
            if len(node_summary.results) > 0:
                sqnr = node_summary.results[0].sqnr
                if isinstance(sqnr, list):
                    for sqnr_i in sqnr:
                        self.assertGreaterEqual(sqnr_i, 35)
                else:
                    self.assertGreaterEqual(sqnr, 35)

    def test_added_node_gets_unique_id(self) -> None:
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs, strict=True)
        generate_numeric_debug_handle(ep)
        ref_handles = self._extract_debug_handles(ep)
        ref_counter = Counter(ref_handles.values())
        for k, v in ref_counter.items():
            self.assertEqual(
                v,
                1,
                msg=f"For handle {k}, there were {v} nodes with that handle, but expected only 1",
            )

        # Now that we have unique ids, add a new node into the graph and re-generate
        # to make sure that the new node gets a unique id.
        last_node = next(iter(reversed(ep.graph.nodes)))
        with ep.graph.inserting_before(last_node):
            arg = last_node.args[0]
            self.assertIsInstance(arg, (list, tuple))
            arg = arg[0]
            # Add a function that only requires a single tensor input.
            n = ep.graph.call_function(torch.ops.aten.relu.default, args=(arg,))
            arg.replace_all_uses_with(n, lambda x: x != n)
        ep.graph_module.recompile()

        # Regenerate handles, make sure only the new relu node has a new id, and
        # it doesn't clash with any of the existing ids.
        generate_numeric_debug_handle(ep)

        self._assert_each_node_has_debug_handle(ep)
        handles_after_modification = self._extract_debug_handles(ep)
        handles_counter = Counter(handles_after_modification.values())
        for name, handle in ref_handles.items():
            self.assertIn(name, handles_after_modification)
            # Check that handle was unchanged.
            self.assertEqual(handles_after_modification[name], handle)
            # Check that total count was unchanged.
            ref_count = ref_counter[handle]
            after_count = handles_counter[handle]
            self.assertEqual(
                after_count,
                ref_count,
                msg=f"For handle {handle}, there were {after_count} nodes with that handle, but expected only {ref_count}",
            )

        # Check for relu specifically. Avoid hardcoding the handle id since it
        # may change with future node ordering changes.
        self.assertNotEqual(handles_after_modification["relu_default"], 0)
        self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1)


# Script entry guard: when this file is executed directly, call
# raise_on_run_directly with the path "test/test_quantization.py".
# NOTE(review): presumably the helper raises and points the user at that
# file as the supported way to run these tests -- confirm against the helper.
if __name__ == "__main__":
    raise_on_run_directly("test/test_quantization.py")