1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508
|
# Owner(s): ["module: inductor"]
import contextlib
import os
import subprocess
import sys
from unittest.mock import patch
import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
from torch._dynamo.testing import rand_strided
from torch._inductor import config
from torch._inductor.codecache import PyCodeCache
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import xfailIfSM89
from torch.testing._internal.common_device_type import expectedFailureXPU
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
class TestKernelBenchmark(TestCase):
device_type = GPU_TYPE
# to make sure the subprocess runs on the exact same path as the parent process
# we augment the PYTHONPATH env var
python_path = ""
@classmethod
def setUpClass(cls):
cls.exit_stack = contextlib.ExitStack()
cls.exit_stack.enter_context(patch.object(config, "benchmark_kernel", True))
# setup the augmented PYTHONPATH to pass to the subprocess calls
augmented_pp = ":".join(sys.path)
if os.environ.get("PYTHONPATH"):
augmented_pp = f"{os.environ.get('PYTHONPATH')}:{augmented_pp}"
cls.python_path = augmented_pp
@classmethod
def tearDownClass(cls):
cls.exit_stack.close()
def setUp(self):
super().setUp()
PyCodeCache.cache_clear()
def get_compiled_module(self):
compiled_module = None
for v in PyCodeCache.modules:
if hasattr(v, "benchmark_compiled_module"):
self.assertTrue(
compiled_module is None, "Found multiple compiled modules"
)
compiled_module = v
self.assertTrue(compiled_module is not None)
return compiled_module
def verify_compiled_kernels(self, GB_count=1):
compiled_module = self.get_compiled_module()
# now run the compiled module in subprocess and check its output
bench_out = subprocess.check_output(
f"{sys.executable} {compiled_module.__file__} -kc".split(),
stderr=subprocess.STDOUT,
env={**os.environ, "PYTHONPATH": self.python_path},
).decode()
# make sure we have the bandwidth information in the output
FileCheck().check_count(
"GB/s",
GB_count,
exactly=1,
).run(bench_out)
def verify_remove_inductor_deps(self, compiled_module):
try:
out = subprocess.check_output(
f"{sys.executable} {compiled_module.__file__}".split(),
env={
**os.environ.copy(),
"TORCHINDUCTOR_DUMP_LAUNCH_PARAMS": "1",
"PYTHONPATH": self.python_path,
},
stderr=subprocess.STDOUT,
)
except subprocess.CalledProcessError as e:
print(
"Failed when runinng triton code with TORCHINDUCTOR_DUMP_LAUNCH_PARAMS=1",
e,
)
print(e.output.decode())
raise e
from torch.utils._get_clean_triton import get_clean_triton
cleaned_triton = get_clean_triton(
compiled_module.__file__, f"{compiled_module.__file__}.cleaned"
)
self.assertTrue("@triton_heuristics" not in cleaned_triton)
self.assertTrue(".run(" not in cleaned_triton)
try:
out = subprocess.check_output(
f"{sys.executable} {compiled_module.__file__}.cleaned".split(),
stderr=subprocess.STDOUT,
env={**os.environ, "PYTHONPATH": self.python_path},
)
except subprocess.CalledProcessError as e:
print("Failed when when running cleaned triton", e)
print(e.output.decode())
print(cleaned_triton)
raise e
return cleaned_triton
def check_bandwidth(self, compiled_module, num_gb):
# now run the compiled module in subprocess and check its output
bench_out = subprocess.check_output(
f"{sys.executable} {compiled_module.__file__} -k".split(),
stderr=subprocess.STDOUT,
env={**os.environ, "PYTHONPATH": self.python_path},
).decode()
# make sure we have the bandwidth information in the output
FileCheck().check_count(
f"{num_gb} GB ",
1,
exactly=1,
).run(bench_out)
def test_pw_kernel_benchmark(self):
@torch.compile
def f(x):
return torch.sin(x) + torch.cos(x)
inp = torch.rand(2, 3).to(device=GPU_TYPE)
out = f(inp)
self.verify_compiled_kernels()
@config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
@fresh_inductor_cache()
def test_matmul_triton_kernel_benchmark(self):
M = 12544
N = 256
K = 64
a = torch.rand(M, K, dtype=torch.float16, device=GPU_TYPE)
b = torch.rand(N, K, dtype=torch.float16, device=GPU_TYPE).t()
@torch.compile
def f(a, b):
return torch.relu(a @ b)
f(a, b)
self.verify_compiled_kernels()
@expectedFailureXPU
@config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
@fresh_inductor_cache()
def test_mm_triton_kernel_benchmark(self):
M = 2048
N = 2432
K = 1949
K_2 = 3581
a = rand_strided((M, K_2), (K_2, 1), device="cuda", dtype=torch.float16)
b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float16)
@torch.compile
def f(a, b):
a_1 = torch.narrow(a, 1, 0, K)
c = torch.mm(a_1, b)
return c
f(a, b)
self.verify_compiled_kernels(GB_count=3)
# make sure we correctly generate the grid info
compiled_module = self.get_compiled_module()
with open(compiled_module.__file__) as f:
source_code = f.read()
lines = source_code.split("\n")
meta = [l for l in lines if "meta0 = {" in l]
scope = {}
from torch._inductor.kernel.mm_common import mm_grid
exec(meta[0], scope)
grid = mm_grid(M, N, scope["meta0"])
FileCheck().check_count(
f"grid={grid}",
2,
exactly=1,
).run(source_code)
def test_matmul_bandwidth_computation(self):
"""
The test does a matmul and then mul. Without max-autotune, we use
the matmul in aten. So there is a single triton kernel for mul.
The kernel we generated is like:
@triton.jit
def triton_(in_out_ptr0, xnumel, XBLOCK : tl.constexpr):
Note the in_out_ptr0 argument. It's for a 1000x1000 tensor, but it's
inplace udpated, so when computing the bandwidth, we should count
the total memory access as 2 * 1000 * 1000 * 4 = 8MB. This amount is
what this test asserts.
"""
torch.set_float32_matmul_precision("high") # suggested by a warning
@torch.compile
def f(x, y):
z = x @ y
w = z * z
return w
M, N, K = 1000, 1000, 10
x = torch.rand(M, K).to(device=GPU_TYPE)
y = torch.rand(K, N).to(device=GPU_TYPE)
out = f(x, y)
compiled_module = self.get_compiled_module()
self.check_bandwidth(compiled_module, 0.008)
def test_unused_input_bandwidth_computation(self):
M, N = 5, 1000000
@torch.compile
def f(a, b, c):
return a + c
a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
b = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
c = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(b, 0)
torch._dynamo.mark_dynamic(c, 0)
inputs = (a, b, c)
out = f(*inputs)
compiled_module = self.get_compiled_module()
# num_gb = size_a + size_c + size_out
# num_gb = (5 * 1000000 + 5 * 1000000 + 5 * 1000000) * 2 / 1e9
# = 0.030
self.check_bandwidth(compiled_module, "0.030")
def test_reduction_bandwidth_computation(self):
@torch.compile
def f(a):
return torch.sum(a, dim=1)
a = torch.rand(1000, 20, 1000, dtype=torch.float16, device=GPU_TYPE)
inputs = (a,)
out = f(*inputs)
compiled_module = self.get_compiled_module()
# num_gb = size_a + size_out
# num_gb = (1000 * 20 * 1000 + 1000 * 1000) * 2 / 1e9
# = 0.042
self.check_bandwidth(compiled_module, "0.042")
@config.patch(max_autotune=True)
def test_fused_layernorm_bandwidth_computation(self):
M, N = 10, 1000000
@torch.compile
def f(a, b, c, d):
x0 = a + b
x1 = torch.nn.functional.layer_norm(
x0, normalized_shape=(N,), weight=c, bias=d, eps=1e-05
)
x2 = torch.sigmoid(x1)
return x0 * x2
a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
b = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
c = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
d = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
inputs = (a, b, c, d)
out = f(*inputs)
compiled_module = self.get_compiled_module()
# num_gb = size_a + size_b + size_c + size_d + size_out
# num_gb = (10 * 1000000 + 1000000 + 1000000 + 1000000 + 10 * 1000000) * 2 / 1e9
# = 0.046
self.check_bandwidth(compiled_module, "0.046")
def test_slice_add_cat_bandwidth_computation(self):
M, N = 5, 1000000
@torch.compile
def f(a, b, c):
x0 = torch.narrow(b, 1, N, N)
# broadcasting
x1 = x0 + c
return torch.cat([a, x1], dim=1)
a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
b = torch.rand(M, N * 5, dtype=torch.float16, device=GPU_TYPE)
c = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(b, 0)
inputs = (a, b, c)
out = f(*inputs)
compiled_module = self.get_compiled_module()
# we overestimate the size of "slice_b" due to torch.cat
# num_gp = size_a + size_slice_b + size_c + size_out
# num_gb = (5 * 1000000 + 5 * 2000000 + 1000000 + 5 * 2000000) * 2 / 1e9
# = 0.052
self.check_bandwidth(compiled_module, "0.052")
def test_slice_add_bandwidth_computation(self):
M, N = 5, 1000000
@torch.compile
def f(a, b, c):
x0 = torch.narrow(b, 1, N, N)
return a + x0 + c
a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
b = torch.rand(M, N * 5, dtype=torch.float16, device=GPU_TYPE)
c = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(b, 0)
inputs = (a, b, c)
out = f(*inputs)
compiled_module = self.get_compiled_module()
# num_gb = size_a + size_slice_b + size_c + out_size
# num_gb = (5 * 1000000 + 5 * 1000000 + 1000000 + 5 * 1000000) * 2 / 1e9
# = 0.032
self.check_bandwidth(compiled_module, "0.032")
def test_mm_slice_add_bandwidth_computation(self):
M, N, K = 1000, 1000, 30
@torch.compile
def f(a, b, c):
x0 = torch.mm(a, b)
x1 = torch.narrow(c, 1, 20 * N, N)
x2 = torch.narrow(c, 1, 21 * N, N)
return x0 + x1 + x2
a = torch.rand(M, K, dtype=torch.float16, device=GPU_TYPE)
b = torch.rand(K, N, dtype=torch.float16, device=GPU_TYPE)
c = torch.rand(N, N * 100, dtype=torch.float16, device=GPU_TYPE)
inputs = (a, b, c)
out = f(*inputs)
compiled_module = self.get_compiled_module()
# torch.mm becomes an extern kernel, so we measure the nbytes
# for the pointwise add kernel:
# num_gb = x0 + 2 * size_slice_c + size_out
# num_gb = (1000 * 1000 + 2 * 1000 * 1000 + 1000 * 1000) * 2/ 1e9
# = 0.008
num_gb = "0.008"
if GPU_TYPE == "xpu":
# In XPU backend, mm + add + add will be fused as admm + add
# And CUDA prefer not fuse add + mm, please check in function
# `should_prefer_unfused_addmm` in torch/_inductor/fx_passes/post_grad.py
num_gb = "0.006"
self.check_bandwidth(compiled_module, num_gb)
def test_mm_slice_add_bandwidth_computation_2(self):
M, N, K = 1000, 1000, 30
@torch.compile
def f(a, b, c):
x0 = torch.mm(a, b)
x1 = torch.narrow(c, 1, 20 * N, N)
x2 = torch.narrow(c, 1, 20 * N, N)
return x0 + x1 + x2
a = torch.rand(M, K, dtype=torch.float16, device=GPU_TYPE)
b = torch.rand(K, N, dtype=torch.float16, device=GPU_TYPE)
c = torch.rand(N, N * 100, dtype=torch.float16, device=GPU_TYPE)
inputs = (a, b, c)
out = f(*inputs)
compiled_module = self.get_compiled_module()
# torch.mm becomes an extern kernel, so we measure the nbytes
# for the pointwise add kernel:
# num_gb = x0 + size_slice_c + size_out
# num_gb = (1000 * 1000 + 1000 * 1000 + 1000 * 1000) * 2 / 1e9
# = 0.006
# note that we only count one size_slice_c because two accesses
# have the same index.
self.check_bandwidth(compiled_module, "0.006")
@expectedFailureXPU
@xfailIfSM89
@config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
def test_slice_mm_bandwidth_computation(self):
M, N, K = 1000, 2000, 3000
@torch.compile
def f(a, b):
x = torch.narrow(a, 1, K, K)
return torch.mm(x, b)
a = torch.rand(M, 3 * K, dtype=torch.float16, device=GPU_TYPE)
b = torch.rand(K, N, dtype=torch.float16, device=GPU_TYPE)
torch._dynamo.mark_dynamic(a, 0)
inputs = (a, b)
out = f(*inputs)
compiled_module = self.get_compiled_module()
# c[1000, 2000] = x[1000, 3000] @ b[3000, 2000]
# num_gb = (1000 * 2000 + 1000 * 3000 + 3000 * 2000) * 2 / 1e9
# = 0.022
self.check_bandwidth(compiled_module, "0.022")
def test_star_dep(self):
"""
Test the bandwidth estimation for StarDep
"""
@torch.compile
def f(a, b):
a[b] = 3.0
a = torch.rand(10000, 5000, device=GPU_TYPE)
b = torch.randint(
0, 10000, [20000], device=GPU_TYPE, dtype=torch.int32
).unsqueeze(1)
f(a, b)
compiled_module = self.get_compiled_module()
# 20000 * 4 = 80KB for b
# 20000 * 5000 * 4 = 200MB for a
self.check_bandwidth(compiled_module, "0.200")
def test_split_scan(self):
@torch.compile
def f(a):
return a.cumsum(-1)
a = torch.rand(10000, 5000, device=GPU_TYPE)
f(a.reshape(-1))
compiled_module = self.get_compiled_module()
# 10000 * 5000 * 4 = 200 MB for a
# Double that for output as well
self.check_bandwidth(compiled_module, "0.400")
@config.patch("triton.unique_kernel_names", True)
@config.patch(benchmark_kernel=False)
@config.patch(compile_threads=1)
def test_remove_inductor_deps(self):
@torch.compile
def f(a):
return a.cos().sin()
a = torch.randn(5, device=GPU_TYPE)
f(a)
compiled_module = self.get_compiled_module()
cleaned_triton = self.verify_remove_inductor_deps(compiled_module)
@config.patch("triton.unique_kernel_names", True)
@config.patch(benchmark_kernel=False)
@config.patch(compile_threads=1)
def test_remove_inductor_deps_multiple_kernels(self):
@torch.compile
def f(a):
a = torch.mm(a, a)
a = a.cos().sin()
a = torch.mm(a, a)
a = torch.softmax(a, dim=-1)
return a
a = torch.randn(5, 5, device=GPU_TYPE)
f(a)
compiled_module = self.get_compiled_module()
self.verify_remove_inductor_deps(compiled_module)
@config.patch("triton.unique_kernel_names", True)
@config.patch("triton.unique_kernel_names", True)
@config.patch(benchmark_kernel=False)
@config.patch(compile_threads=1)
@config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
def test_remove_inductor_deps_templates(self):
@torch.compile
def f(a):
a = torch.mm(a, a)
a = a.cos()
a = torch.mm(a, a)
a = a.sin()
return a
a = torch.randn(128, 128, device=GPU_TYPE)
f(a)
compiled_module = self.get_compiled_module()
self.verify_remove_inductor_deps(compiled_module)
@config.patch("triton.unique_kernel_names", True)
@config.patch(benchmark_kernel=False)
@config.patch(compile_threads=1)
def test_remove_inductor_deps_scalar(self):
@torch.compile
def f(a, b):
return a + b
a = torch.tensor(1.0, device=GPU_TYPE)
b = torch.tensor(2.0, device=GPU_TYPE)
f(a, b)
compiled_module = self.get_compiled_module()
self.verify_remove_inductor_deps(compiled_module)
# Run the suite only when executed as a script and a GPU backend is present.
if __name__ == "__main__" and HAS_GPU:
    run_tests()
|