File: test_autoheuristic.py

# Owner(s): ["module: inductor"]
import os
import unittest

import torch
import torch._inductor.config as inductor_config
from torch._dynamo.device_interface import get_interface_for_device
from torch._inductor.autoheuristic.autoheuristic import AutoHeuristic, LocalFeedback
from torch._inductor.autoheuristic.autoheuristic_utils import AHContext
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import get_gpu_shared_memory
from torch.testing._internal.common_utils import skipIfXpu
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, IS_A100, IS_H100


@skipIfXpu(msg="AutoHeuristic doesn't currently work on the XPU stack")
class AutoHeuristicTest(TestCase):
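    """Tests for Inductor's AutoHeuristic: data collection, log format, and heuristic use."""
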
    def count_lines_in_file(self, file_path):
        with open(file_path) as file:
            return sum(1 for _ in file)

    def run_mm(self):
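        # Compile a small fp16 matmul; the unaligned 2047 dim presumably makes
        # padding (pad_mm) a non-trivial decision worth logging.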
        def f(a, b):
            return torch.mm(a, b)

        cf = torch.compile(f)
        a = torch.randn(2047, 2048, device=GPU_TYPE, dtype=torch.float16)
        b = torch.randn(2048, 2048, device=GPU_TYPE, dtype=torch.float16)
        cf(a, b)

    def get_path_to_autoheuristic_log(self, name):
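        # AutoHeuristic writes its logs to <cache_dir>/autoheuristic/<device_identifier>/<name>.txt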
        device_name = AutoHeuristic.get_device_identifier()
        path = os.path.join(cache_dir(), "autoheuristic", device_name, f"{name}.txt")
        return path

    def test_autoheuristic_pad_mm_default(self):
        # this test ensures that data is not collected for pad_mm when autoheuristic config is set to its default value
        self.run_mm()
        self.assertFalse(os.path.exists(self.get_path_to_autoheuristic_log("pad_mm")))

    @inductor_config.patch(autoheuristic_collect="foo")
    def test_autoheuristic_pad_mm_off(self):
        # this test ensures that data is not collected for pad_mm when autoheuristic_collect does not contain "pad_mm"
        self.run_mm()
        self.assertFalse(os.path.exists(self.get_path_to_autoheuristic_log("pad_mm")))

    def assert_autoheuristic_collected_data(self):
        self.run_mm()
        device_name = AutoHeuristic.get_device_identifier()
        path = self.get_path_to_autoheuristic_log("pad_mm")
        self.assertTrue(os.path.exists(path))
        num_lines = self.count_lines_in_file(path)

        # 1 line for metadata, 1 line for header, 1 line per choice (orig, padded)
        self.assertEqual(num_lines, 4)

    @inductor_config.patch(autoheuristic_collect="pad_mm")
    def test_autoheuristic_pad_mm_collect_data(self):
        # this test ensures that data is collected for pad_mm when autoheuristic_collect="pad_mm"
        self.assert_autoheuristic_collected_data()

    @inductor_config.patch(autoheuristic_collect="foo,pad_mm")
    def test_autoheuristic_pad_mm_collect_data2(self):
        # this test ensures that data is collected for "pad_mm" when autoheuristic_collect contains "pad_mm"
        self.assert_autoheuristic_collected_data()

    @inductor_config.patch(autoheuristic_collect="test")
    def test_autoheuristic(self):
        # test basic functionality of autoheuristic
        def fallback():
            return "fallback"

        choices = ["a", "b", "c"]

        def feedback_fn(choice):
            if choice == "a":
                return 1
            elif choice == "b":
                return 2
            elif choice == "c":
                return 3
            else:
                raise RuntimeError("unexpected choice")

        feedback = LocalFeedback(feedback_fn)
        context = AHContext()
        context.add_feature("fa", 5)
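        # "fa" is the single numerical feature; it becomes the first column in the
        # collected log, as the header/row assertions below verify.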
        name = "test"
        autoheuristic = AutoHeuristic(fallback, choices, feedback, context, name)

        # when autoheuristic is configured to only collect data, we always return fallback
        self.assertEqual(autoheuristic.get_choice(), "fallback")
        self.assertEqual(autoheuristic.get_collected_feedback("a"), 1)
        self.assertEqual(autoheuristic.get_collected_feedback("b"), 2)
        self.assertEqual(autoheuristic.get_collected_feedback("c"), 3)

        path = self.get_path_to_autoheuristic_log(name)
        self.assertTrue(os.path.exists(path))
        num_lines = self.count_lines_in_file(path)
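        # 1 line for metadata, 1 line for header, 1 line per choice (a, b, c)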
        self.assertEqual(num_lines, 5)

        shared_memory = get_gpu_shared_memory()
        (fst, snd) = get_interface_for_device(GPU_TYPE).get_device_capability()

        with open(path) as file:
            lines = file.readlines()
            self.assertTrue('"numerical_features": ["fa"]' in lines[0])
            self.assertTrue('"categorical_features": []' in lines[0])
            self.assertTrue(f'"shared_memory": {shared_memory}' in lines[0])
            self.assertTrue(f'"device_capa": [{fst}, {snd}]' in lines[0])
            self.assertTrue('"name": "test"' in lines[0])
            self.assertEqual("fa,choice,feedback", lines[1].rstrip())
            self.assertEqual("5,a,1", lines[2].rstrip())
            self.assertEqual("5,b,2", lines[3].rstrip())
            self.assertEqual("5,c,3", lines[4].rstrip())

    @unittest.skipIf(not IS_A100, "heuristic is only run on A100")
    @inductor_config.patch(autoheuristic_use="pad_mm")
    def test_autoheuristic_a100(self):
        # Make sure heuristic does not break anything
        # TODO (AlnisM): Find a way to check whether heuristic is used
        self.run_mm()

    @unittest.skipIf(not IS_H100, "heuristic is only run on H100")
    @inductor_config.patch(autoheuristic_use="pad_mm")
    def test_autoheuristic_h100(self):
        # Make sure heuristic does not break anything
        # TODO (AlnisM): Find a way to check whether heuristic is used
        self.run_mm()

    def run_mixed_mm(self):
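        # "mixed_mm" is a matmul where one operand is cast to the other's dtype
        # (fp16 x int8 here); max-autotune makes select_algorithm benchmark
        # multiple kernel choices, which is what global feedback records.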
        def fn(a, b):
            return torch.mm(a, b.to(a.dtype))

        a = torch.randn(8, 1024, device=GPU_TYPE, dtype=torch.float16)
        b = torch.randint(
            -128, 127, (1024, 1024), dtype=torch.int8, device=GPU_TYPE
        ).t()
        torch.compile(fn, mode="max-autotune-no-cudagraphs")(a, b)

    # We have to set autoheuristic_use="" because, with autoheuristic_use="mixed_mm",
    # autoheuristic creates a precompile key and puts it into the registry. A choice
    # made by the heuristic might then be added to the list of choices, so if
    # select_algorithm later creates a new precompile key, it would differ from the
    # one created by autoheuristic.
    @inductor_config.patch(
        autoheuristic_collect="mixed_mm",
        autoheuristic_use="",
        fx_graph_cache=False,
        fx_graph_remote_cache=False,
    )
    def test_global_feedback(self):
        self.run_mixed_mm()
        path = self.get_path_to_autoheuristic_log("mixed_mm")
        self.assertTrue(os.path.exists(path))
        num_lines = self.count_lines_in_file(path)

        # 1 line for metadata, 1 line for header
        # 1 line for fallback + at least 1 config
        self.assertGreater(num_lines, 4)

    @inductor_config.patch(autoheuristic_use="mixed_mm")
    @unittest.skipIf(not IS_A100, "heuristic is only run on A100")
    def test_mixed_mm_a100(self):
        self.run_mixed_mm()
        # TODO (AlnisM): Find a way to check whether heuristic is used


if __name__ == "__main__":
    if HAS_GPU:
        run_tests()