# Owner(s): ["module: inductor"]
import os
import unittest

import torch
import torch._inductor.config as inductor_config
from torch._dynamo.device_interface import get_interface_for_device
from torch._inductor.autoheuristic.autoheuristic import AutoHeuristic, LocalFeedback
from torch._inductor.autoheuristic.autoheuristic_utils import AHContext
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import get_gpu_shared_memory
from torch.testing._internal.common_utils import skipIfXpu
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, IS_A100, IS_H100


@skipIfXpu(msg="AutoHeuristic doesn't currently work on the XPU stack")
class AutoHeuristicTest(TestCase):
    def count_lines_in_file(self, file_path):
        with open(file_path) as file:
            line_count = sum(1 for line in file)
        return line_count

    def run_mm(self):
        def f(a, b):
            return torch.mm(a, b)

        cf = torch.compile(f)
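        # Note: the odd leading dim (2047) presumably makes this mm a
        # candidate for pad_mm's pad-vs-orig decision; this is an inference,
        # not something stated in this file.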
        a = torch.randn(2047, 2048, device=GPU_TYPE, dtype=torch.float16)
        b = torch.randn(2048, 2048, device=GPU_TYPE, dtype=torch.float16)
        cf(a, b)

    def get_path_to_autoheuristic_log(self, name):
        device_name = AutoHeuristic.get_device_identifier()
        path = os.path.join(cache_dir(), "autoheuristic", device_name, name + ".txt")
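        # The resulting path looks something like
        #   <cache_dir>/autoheuristic/NVIDIA_A100-SXM4-40GB/pad_mm.txt
        # (the device segment is illustrative; it comes from
        # get_device_identifier()).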
        return path

    def test_autoheuristic_pad_mm_default(self):
        # Ensure that no data is collected for pad_mm when the autoheuristic
        # config is left at its default value.
        self.run_mm()
        self.assertFalse(os.path.exists(self.get_path_to_autoheuristic_log("pad_mm")))
@inductor_config.patch(autoheuristic_collect="foo")
def test_autoheuristic_pad_mm_off(self):
# this test ensures that data is not collected for pad_mm when autoheuristic_collect does not contain "pad_mm"
self.run_mm()
self.assertFalse(os.path.exists(self.get_path_to_autoheuristic_log("pad_mm")))

    def assert_autoheuristic_collected_data(self):
        self.run_mm()
        path = self.get_path_to_autoheuristic_log("pad_mm")
        self.assertTrue(os.path.exists(path))
        num_lines = self.count_lines_in_file(path)
        # 1 line for metadata, 1 line for header, 1 line per choice (orig, padded)
        self.assertEqual(num_lines, 4)
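        # For reference, the four lines of a collected pad_mm log look roughly
        # like this (field names and values are illustrative, not exact):
        #   {"numerical_features": [...], "categorical_features": [...], ...}  # metadata
        #   <feature1>,<feature2>,...,choice,feedback                           # header
        #   <values>,orig,<time>                                                # unpadded mm
        #   <values>,pad,<time>                                                 # padded mm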
@inductor_config.patch(autoheuristic_collect="pad_mm")
def test_autoheuristic_pad_mm_collect_data(self):
# this test ensures that data is collected for pad_mm when autoheuristic_collect="pad_mm"
self.assert_autoheuristic_collected_data()
@inductor_config.patch(autoheuristic_collect="foo,pad_mm")
def test_autoheuristic_pad_mm_collect_data2(self):
# this test ensures that data is collected for "pad_mm" when autoheuristic_collect contains "pad_mm"
self.assert_autoheuristic_collected_data()
@inductor_config.patch(autoheuristic_collect="test")
def test_autoheuristic(self):
# test basic functionality of autoheuristic
def fallback():
return "fallback"
choices = ["a", "b", "c"]
def feedback_fn(choice):
if choice == "a":
return 1
elif choice == "b":
return 2
elif choice == "c":
return 3
else:
raise RuntimeError("unexpected choice")
feedback = LocalFeedback(feedback_fn)
context = AHContext()
context.add_feature("fa", 5)
name = "test"
autoheuristic = AutoHeuristic(fallback, choices, feedback, context, name)
# when autoheuristic is configured to only collect data, we always return fallback
self.assertEqual(autoheuristic.get_choice(), "fallback")
self.assertEqual(autoheuristic.get_collected_feedback("a"), 1)
self.assertEqual(autoheuristic.get_collected_feedback("b"), 2)
self.assertEqual(autoheuristic.get_collected_feedback("c"), 3)
path = self.get_path_to_autoheuristic_log(name)
self.assertTrue(os.path.exists(path))
num_lines = self.count_lines_in_file(path)
self.assertEqual(num_lines, 5)
shared_memory = get_gpu_shared_memory()
(fst, snd) = get_interface_for_device(GPU_TYPE).get_device_capability()
with open(path) as file:
lines = file.readlines()
self.assertTrue('"numerical_features": ["fa"]' in lines[0])
self.assertTrue('"categorical_features": []' in lines[0])
self.assertTrue(f'"shared_memory": {shared_memory}' in lines[0])
self.assertTrue(f'"device_capa": [{fst}, {snd}]' in lines[0])
self.assertTrue('"name": "test"' in lines[0])
self.assertEqual("fa,choice,feedback", lines[1].rstrip())
self.assertEqual("5,a,1", lines[2].rstrip())
self.assertEqual("5,b,2", lines[3].rstrip())
self.assertEqual("5,c,3", lines[4].rstrip())

    @unittest.skipIf(not IS_A100, "heuristic only run on A100")
    @inductor_config.patch(autoheuristic_use="pad_mm")
    def test_autoheuristic_a100(self):
        # Make sure the heuristic does not break anything.
        # TODO (AlnisM): Find a way to check whether the heuristic is used.
        self.run_mm()

    @unittest.skipIf(not IS_H100, "heuristic only run on H100")
    @inductor_config.patch(autoheuristic_use="pad_mm")
    def test_autoheuristic_h100(self):
        # Make sure the heuristic does not break anything.
        # TODO (AlnisM): Find a way to check whether the heuristic is used.
        self.run_mm()

    def run_mixed_mm(self):
        def fn(a, b):
            return torch.mm(a, b.to(a.dtype))
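
        # An mm whose second operand is cast to the first operand's dtype
        # inside the graph (here fp16 @ int8) is what inductor treats as a
        # "mixed_mm"; max-autotune then tunes over multiple configs.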
        a = torch.randn(8, 1024, device=GPU_TYPE, dtype=torch.float16)
        b = torch.randint(
            -128, 127, (1024, 1024), dtype=torch.int8, device=GPU_TYPE
        ).t()
        torch.compile(fn, mode="max-autotune-no-cudagraphs")(a, b)

    # We have to set autoheuristic_use="" because, if autoheuristic_use="mixed_mm",
    # autoheuristic creates a precompile key, puts it into the registry, and a
    # choice made by the heuristic might then be added to the list of choices.
    # If select_algorithm now creates a new precompile key, it will be
    # different from the precompile key created by autoheuristic.
    @inductor_config.patch(
        autoheuristic_collect="mixed_mm",
        autoheuristic_use="",
        fx_graph_cache=False,
        fx_graph_remote_cache=False,
    )
    def test_global_feedback(self):
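        # Note: unlike the LocalFeedback used in test_autoheuristic above,
        # feedback for mixed_mm appears to be registered globally and only
        # arrives once select_algorithm finishes autotuning, so the exact
        # number of logged configs isn't known up front (hence the ">" below).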
        self.run_mixed_mm()
        path = self.get_path_to_autoheuristic_log("mixed_mm")
        self.assertTrue(os.path.exists(path))
        num_lines = self.count_lines_in_file(path)
        # 1 line for metadata, 1 line for header,
        # 1 line for fallback + at least 1 config
        self.assertTrue(num_lines > 4)
@inductor_config.patch(autoheuristic_use="mixed_mm")
@unittest.skipIf(not IS_A100, "heuristic only run on A100")
def test_mixed_mm_a100(self):
self.run_mixed_mm()
# TODO (AlnisM): Find a way to check whether heuristic is used


if __name__ == "__main__":
    if HAS_GPU:
        run_tests()