File: test_profiler.py

# Owner(s): ["module: inductor"]
import json
import os
import tempfile
import unittest
from typing import Callable, Optional

import torch
import torch._inductor.test_case
import torch._inductor.utils
from torch import _dynamo as torchdynamo
from torch._inductor import config
from torch.profiler import ProfilerActivity
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.utils._triton import has_triton


HAS_TRITON = has_triton()


class DynamoProfilerTests(torch._inductor.test_case.TestCase):
    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_triton_launch(self):
        # Verify that we get some sort of CPU-side indication of triton kernel launches
        # in the profile traces. Currently, those appear as `cuLaunchKernel` (or
        # `hipModuleLaunchKernel` on ROCm). If this detail changes, the test can be
        # updated or removed.
        @torch.compile
        def fn(x, y):
            return (x + y).sin().cos()

        x, y = (torch.rand((4, 4), device="cuda") for _ in range(2))

        with torch.profiler.profile() as prof:
            fn(x, y)

        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                trace_json = json.load(f)

        self.assertTrue("traceEvents" in trace_json)
        events = trace_json["traceEvents"]
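        # Each entry in `events` is a dict; a kernel launch event looks roughly like
        # {"name": "cuLaunchKernel", "ph": "X", "ts": ..., "dur": ..., "args": {...}}
        # (illustrative only - the exact fields depend on the PyTorch/Kineto version).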

        kernel_name = "hipModuleLaunchKernel" if torch.version.hip else "cuLaunchKernel"

        def nameMatchesLaunchKernel(event_name):
            return kernel_name in event_name

        self.assertTrue(
            any(
                ("name" in event and nameMatchesLaunchKernel(event["name"]))
                for event in events
            )
        )

    def _test_profiling_kernel_names(
        self, fn, args, kernel_name_str: str, check_fn: Optional[Callable] = None
    ):
        """
        We expect a record_function event to be added on the CPU side, surrounding
        the launch of each triton kernel.
        """
        fn_opt = torch.compile(fn)

        for _ in range(2):
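            # Warm-up runs: trigger dynamo/inductor compilation (and any autotuning)
            # before the profiled invocation below.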
            fn_opt(*args)

        if check_fn is not None:
            check_fn()

        with torch.profiler.profile(
            activities=[ProfilerActivity.CPU], record_shapes=True
        ) as prof:
            fn_opt(*args)

        # The kernel name recorded in the profile is expected to match the kernel name
        # used in debug files etc. The exact name could change in the future, but it
        # seems reasonable that it should always contain "triton" and kernel_name_str -
        # e.g. if the kernel contains a sin op, its name should probably contain "sin".
        # If this changes in the future, feel free to update the assertion here.
        # Debugging tip: you can add prof.export_chrome_trace("test.json") inline in
        # this test, and then view test.json in chrome://tracing to see the trace.
        self.assertTrue(
            any(
                (
                    hasattr(event, "name")
                    and kernel_name_str in event.name
                    and "triton" in event.name
                )
                for event in prof.events()
            )
        )
        return prof.events()

    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_kernel_names_pointwise(self):
        def fn(x, y):
            return (x + y).sin().cos()

        args = [torch.rand((4, 4), device="cuda") for _ in range(2)]

        events = self._test_profiling_kernel_names(fn, args, "sin")
        event_found = False
        for event in events:
            if event.name == "triton_poi_fused_add_cos_sin_0":
                event_found = True
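                # Expected args: the two inputs, the output buffer, and (presumably)
                # a scalar launch argument such as the element count, hence the trailing [].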
                self.assertTrue(event.input_shapes == [[4, 4], [4, 4], [4, 4], []])
        self.assertTrue(event_found)

    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_kernel_names_template(self):
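        # Force the matmul to be lowered to a Triton template kernel (rather than a
        # cuBLAS/extern call) so that a triton_tem_* kernel shows up in the profile.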
        with config.patch(
            {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
        ):

            def fn(x, y):
                return x @ y

            args = [torch.rand((4, 4), device="cuda") for _ in range(2)]

            def check_fn():
                # _test_profiling_kernel_names calls this before asserting that mm is in the trace.
                # Reason: tests sometimes run on machines without enough SMs, in which case autotuning is skipped.
                if (
                    torch._dynamo.utils.counters["inductor"][
                        "select_algorithm_autotune"
                    ]
                    == 0
                ):
                    raise unittest.SkipTest(
                        "select_algorithm didn't run, we probably won't get profiling data. GPU might not have enough SMs."
                    )

            events = self._test_profiling_kernel_names(fn, args, "mm", check_fn)

            event_found = False
            for event in events:
                if event.name == "triton_tem_fused_mm_0":
                    event_found = True
                    self.assertTrue(event.input_shapes == [[4, 4], [4, 4], [4, 4]])
            self.assertTrue(event_found)

    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_kernel_names_foreach(self):
        with config.patch(
            {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
        ):

            def fn(x, y):
                return torch._foreach_add(x, y)

            x = [torch.rand((4, 4), device="cuda") for _ in range(3)]
            y = [torch.rand((4, 4), device="cuda") for _ in range(3)]

            args = (x, y)

            events = self._test_profiling_kernel_names(fn, args, "_for_")
            event_found = False
            for event in events:
                if event.name == "triton_for_fused_0":
                    event_found = True
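                    # Presumably the three x tensors, three y tensors, and three output
                    # buffers handled by the single fused foreach kernel - nine [4, 4] shapes.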
                    self.assertTrue(
                        event.input_shapes
                        == [
                            [4, 4],
                            [4, 4],
                            [4, 4],
                            [4, 4],
                            [4, 4],
                            [4, 4],
                            [4, 4],
                            [4, 4],
                            [4, 4],
                        ]
                    )
            self.assertTrue(event_found)

    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_triton_hooks(self):
        from triton.compiler import CompiledKernel  # @manual

        hooks_called = {"enter": False, "exit": False}

        def launch_enter_hook(lazy_dict):
            hooks_called["enter"] = True

        def launch_exit_hook(lazy_dict):
            hooks_called["exit"] = True

        CompiledKernel.launch_enter_hook = launch_enter_hook
        CompiledKernel.launch_exit_hook = launch_exit_hook
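        # Note: these are class-level attributes, so (in Triton versions exposing this
        # API) the hooks apply to every CompiledKernel launch in this process.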

        def fn(x, y):
            return torch._foreach_add(x, y)

        x = [torch.rand((4, 4), device="cuda") for _ in range(3)]
        y = [torch.rand((4, 4), device="cuda") for _ in range(3)]

        args = (x, y)
        fn_opt = torch.compile(fn)
        fn_opt(*args)

        self.assertTrue(hooks_called["enter"])
        self.assertTrue(hooks_called["exit"])

    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_pt2_triton_attributes(self):
        from torch._inductor.codecache import code_hash

        device = "cuda"
        debug = False  # set to True to get output file

        @torchdynamo.optimize("inductor")
        def fn(a, b, c):
            x = torch.nn.functional.linear(a, b)
            x = x + c
            return x.cos()

        a, b, c = (torch.randn(4, 4, requires_grad=True).to(device) for _ in range(3))

        inputs = [a, b, c]
        with config.patch(compile_threads=1):
            fn(*inputs)

        fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=not debug)
        fp.close()
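        # The temp file only reserves a path for the chrome trace written below;
        # with debug=True it is created with delete=False so it survives for inspection.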

        with torch.profiler.profile(
            activities=torch.profiler.supported_activities(),
            record_shapes=True,
            schedule=torch.profiler.schedule(
                skip_first=3, wait=1, warmup=1, active=2, repeat=1
            ),
        ) as prof:
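            # With skip_first=3, wait=1, warmup=1, active=2, repeat=1, only two of the
            # ten iterations below fall in the "active" phase and are actually recorded.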
            for idx in range(10):
                fn(*inputs)
                prof.step()

        prof.export_chrome_trace(fp.name)
        print("Trace written to {fp.name}, set debug=True to retain file.")

        triton_events = []
        with open(fp.name) as f:
            trace_json = json.load(f)
            triton_events = [
                event
                for event in trace_json["traceEvents"]
                if "kernel_backend" in event.get("args", {}).keys()
            ]

        print(triton_events)
        self.assertEqual(len(triton_events), 2)

        def get_hash(kernel_file: str) -> str:
            with open(kernel_file) as f:
                kernel_src = f.read()
            return code_hash(kernel_src.strip())
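        # Every captured Triton launch event is expected to carry kernel metadata in
        # "args": backend, stream, grid, the generated kernel file, and its code hash.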

        def check_triton_event(e) -> None:
            args = e.get("args", {})
            self.assertNotEqual(args, {}, msg=f"event = {e}")

            self.assertEqual(args["kernel_backend"], "triton", msg=f"event = {e}")

            self.assertTrue("stream" in args, msg=f"event = {e}")
            self.assertTrue("grid" in args, msg=f"event = {e}")
            self.assertTrue(args["grid"].startswith("grid"), msg=f"event = {e}")

            self.assertTrue("kernel_file" in args, msg=f"event = {e}")
            kernel_file = args["kernel_file"]
            self.assertTrue(os.path.isfile(kernel_file), msg=f"event = {e}")

            self.assertTrue("kernel_hash" in args, msg=f"event = {e}")
            self.assertEqual(
                args["kernel_hash"], get_hash(kernel_file), msg=f"event = {e}"
            )

        for e in triton_events:
            check_triton_event(e)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CUDA:
        run_tests()