1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460
|
# Owner(s): ["module: inductor"]
import logging
import os
import unittest
import unittest.mock


try:
    from .test_aot_inductor_utils import AOTIRunnerUtil
except ImportError:
    from test_aot_inductor_utils import AOTIRunnerUtil

import torch
from torch._inductor import config
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA


try:
    from .test_fp8 import _quantize_rowwise, _quantize_tensorwise
except ImportError:
    from test_fp8 import _quantize_rowwise, _quantize_tensorwise
torch.set_float32_matmul_precision("high")
if HAS_CUDA:
torch.cuda.memory._set_allocator_settings("expandable_segments:False")
log = logging.getLogger(__name__)
def _get_path_without_sccache() -> str:
    """
    Return the current ``PATH`` with sccache wrapper directories removed.

    Any ``PATH`` entry containing the substring ``/opt/cache/bin`` is dropped;
    all other entries are kept in their original order. Entries are split and
    re-joined with :data:`os.pathsep` (rather than a hard-coded ``":"``) so
    the result stays well-formed on every platform.

    Returns:
        The filtered ``PATH`` string (``""`` when ``PATH`` is unset).
    """
    path_entries = os.environ.get("PATH", "").split(os.pathsep)
    # NOTE(review): "/opt/cache/bin" is presumably where the CI image installs
    # the sccache compiler wrappers — confirm against the CI setup.
    kept = [entry for entry in path_entries if "/opt/cache/bin" not in entry]
    return os.pathsep.join(kept)
@instantiate_parametrized_tests
class TestCKBackend(TestCase):
    """
    End-to-end tests for inductor's Composable Kernel (CK) backend.

    Each test compiles a small matmul / addmm / scaled_mm / bmm / conv2d
    workload with max-autotune enabled and CK among the allowed autotune
    backends, then compares the compiled result against eager execution.
    Every test is skipped unless ``torch.version.hip`` is set, and each test
    is skipped from ``setUp`` when the ``ck4inductor`` package is missing.
    """

    def setUp(self):
        """
        Locate the CK library, pin the RNG seed and initialise the parent case.

        Sets ``self.ck_dir`` (passed to the ``rocm.ck_dir`` config by every
        test) and exports it as ``TORCHINDUCTOR_CK_DIR``.

        Raises:
            unittest.SkipTest: if ``ck4inductor`` cannot be imported.
        """
        # The new inductor cache refresh mechanism
        # introduced with https://github.com/pytorch/pytorch/pull/122661
        # interacts badly with persistent subprocesses during
        # autotuning. So we need to disable automatic cache refresh
        # before calling setUp() on the parent class.
        # ``None`` records that the variable was unset, so it can be removed
        # again afterwards instead of being left behind as an empty string.
        old_disable_fresh_cache_envvar = os.environ.get(
            "INDUCTOR_TEST_DISABLE_FRESH_CACHE"
        )

        torch.random.manual_seed(1234)
        try:
            import ck4inductor  # @manual

            self.ck_dir = os.path.dirname(ck4inductor.__file__)
            os.environ["TORCHINDUCTOR_CK_DIR"] = self.ck_dir
        except ImportError as e:
            raise unittest.SkipTest("Composable Kernel library not installed") from e

        # Most tests flip this process-global flag to False; restore the
        # pre-test value afterwards so the change does not leak into other
        # tests running in the same process. Cleanups run even if the rest of
        # setUp fails.
        self.addCleanup(
            setattr,
            torch.backends.cuda.matmul,
            "allow_fp16_reduced_precision_reduction",
            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction,
        )

        try:
            os.environ["INDUCTOR_TEST_DISABLE_FRESH_CACHE"] = "1"
            super().setUp()
        finally:
            if old_disable_fresh_cache_envvar is None:
                os.environ.pop("INDUCTOR_TEST_DISABLE_FRESH_CACHE", None)
            else:
                os.environ[
                    "INDUCTOR_TEST_DISABLE_FRESH_CACHE"
                ] = old_disable_fresh_cache_envvar

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK"))
    @parametrize("autotune_in_subproc", (True, False))
    @parametrize("use_aoti", (True, False))
    def test_max_autotune_precompile_matmul(
        self, max_autotune_gemm_backends, autotune_in_subproc, use_aoti
    ):
        """
        Make sure autotuning mm doesn't crash.

        Runs the bf16 (2240x256) @ (256x2048) product either through
        ``torch.compile`` or through AOTInductor (``use_aoti``) and checks the
        result against eager.
        """

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b):
            return a @ b

        tensor_options = {"device": "cuda", "dtype": torch.bfloat16}

        a = torch.randn(2240, 256, **tensor_options)
        b = torch.randn(256, 2048, **tensor_options)

        self.assertIn("rocm", dir(config))

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": autotune_in_subproc,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "compile_threads": 2,
                "rocm.n_max_profiling_configs": 2,
                "rocm.ck_dir": self.ck_dir,
            }
        ):
            if use_aoti:
                Y_compiled = AOTIRunnerUtil.run(
                    device="cuda",
                    model=mm,
                    example_inputs=(a, b),
                )
            else:

                @torch.compile(dynamic=False)
                def compiled_mm(x, w):
                    return mm(x, w)

                Y_compiled = compiled_mm(a, b)

            Y = mm(a=a, b=b)
            torch.testing.assert_close(Y_compiled, Y)

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    @parametrize("max_autotune_gemm_backends", ("CK",))
    @parametrize("autotune_in_subproc", (True,))
    def test_max_autotune_precompile_matmul_dynamic(
        self, max_autotune_gemm_backends, autotune_in_subproc
    ):
        """
        Test matmul with dynamic shapes.

        Dimension 0 of ``a`` is marked dynamic; the compiled function is then
        called with two different row counts (2240 and 1024) and both results
        are compared against eager.
        """

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        tensor_options = {"device": "cuda", "dtype": torch.bfloat16}

        a = torch.randn(2240, 256, **tensor_options)
        b = torch.randn(256, 2048, **tensor_options)

        torch._dynamo.mark_dynamic(a, 0)

        self.assertIn("rocm", dir(config))

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": autotune_in_subproc,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "compile_threads": 2,
                "rocm.n_max_profiling_configs": 2,
                "rocm.ck_dir": self.ck_dir,
            }
        ):

            @torch.compile(dynamic=True)
            def compiled_mm(a, b):
                return a @ b

            Y_compiled = compiled_mm(a, b)
            Y = a @ b
            torch.testing.assert_close(Y_compiled, Y)

            # Second call with a different leading dimension, still under the
            # same config patch.
            a1 = torch.randn(1024, 256, **tensor_options)
            Y1_compiled = compiled_mm(a1, b)
            Y1 = a1 @ b
            torch.testing.assert_close(Y1_compiled, Y1)

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK"))
    def test_max_autotune_precompile_preselected(self, max_autotune_gemm_backends):
        """
        End to end test for picking preselected ck instances.

        Uses fp16 inputs with ``b`` built as a transposed view, and enables
        ``rocm.use_preselected_instances``.
        """

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b):
            return a @ b

        tensor_options = {"device": "cuda", "dtype": torch.float16}

        a = torch.randn(2240, 256, **tensor_options)
        b = torch.randn(2048, 256, **tensor_options).transpose(0, 1)

        self.assertIn("rocm", dir(config))

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "compile_threads": 12,
                "rocm.ck_dir": self.ck_dir,
                "rocm.use_preselected_instances": True,
            }
        ):
            Y_compiled = torch.compile(mm, dynamic=False)(a, b)
            Y = mm(a, b)
            torch.testing.assert_close(Y_compiled, Y)

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK"))
    def test_max_autotune_precompile_non_contiguous(self, max_autotune_gemm_backends):
        """
        Make sure the ck template can work with non-contiguous inputs.

        ``a`` is laid out with strides (1, 50304): unit stride along dim 0 and
        a leading dimension padded past its 50257 rows.
        """

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        tensor_options = {"device": "cuda", "dtype": torch.float16}

        # NOTE(review): empty_strided leaves the memory uninitialised, so the
        # compared values are whatever the allocator hands back; if that ever
        # contains NaN/inf bit patterns assert_close could fail spuriously.
        a = torch.empty_strided((50257, 32768), (1, 50304), **tensor_options)
        b = torch.empty_strided((32768, 768), (768, 1), **tensor_options)

        self.assertIn("rocm", dir(config))

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "compile_threads": 2,
                "rocm.ck_dir": self.ck_dir,
                "rocm.n_max_profiling_configs": 2,
            }
        ):

            @torch.compile(dynamic=False)
            def mm(a, b):
                return a @ b

            Y_compiled = mm(a, b)
            Y_eager = a @ b
            torch.testing.assert_close(Y_compiled, Y_eager)

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK"))
    @parametrize("x_shape", ([4096, 2048], [2048], [4096, 1]))
    def test_max_autotune_addmm(self, max_autotune_gemm_backends, x_shape):
        """
        Autotune ``torch.addmm`` with CK and compare against eager.

        With (m, k, n) = (4096, 224, 2048), ``x_shape`` covers a full (m, n)
        bias, an (n,) bias broadcast over rows, and an (m, 1) bias broadcast
        over columns.
        """
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        m, k, n = 4096, 224, 2048
        alpha, beta = 1.0, 1.0

        tensor_options = {"device": "cuda", "dtype": torch.float16}
        x = torch.ones(x_shape, **tensor_options)
        a = torch.randn(m, k, **tensor_options)
        b = torch.randn(k, n, **tensor_options)

        self.assertIn("rocm", dir(config))

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "compile_threads": 2,
                "rocm.ck_dir": self.ck_dir,
                "rocm.n_max_profiling_configs": 2,
            }
        ):

            @torch.compile(dynamic=False)
            def addmm(x, a, b, alpha, beta):
                return torch.addmm(x, a, b, alpha=alpha, beta=beta)

            Y_compiled = addmm(x, a, b, alpha, beta)
            Y_eager = torch.addmm(x, a, b, alpha=alpha, beta=beta)

            torch.testing.assert_close(Y_compiled, Y_eager)

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK"))
    @parametrize("dtype", (torch.bfloat16,))
    @parametrize("use_fast_accum", (True,))
    @parametrize("quantize_type", ("tensorwise", "rowwise"))
    @parametrize("has_bias", (True, False))
    def test_max_autotune_scaled_mm(
        self, max_autotune_gemm_backends, dtype, use_fast_accum, quantize_type, has_bias
    ):
        """
        Autotune ``torch._scaled_mm`` on float8_e4m3fnuz inputs with CK.

        The compiled path uses the requested quantization (tensorwise or
        rowwise). The eager reference always uses tensorwise quantization
        (see the FIXME below), hence the loose rtol/atol in the final check.
        """
        tensor_options = {"device": "cuda", "dtype": dtype}

        M = 2240
        N = 2048
        K = 256

        x = torch.randn(M, K, **tensor_options)
        w = torch.randn(N, K, **tensor_options)

        bias = None
        if has_bias:
            bias = torch.randn(N, **tensor_options)

        dtype_float8 = torch.float8_e4m3fnuz

        f_quantize = (
            _quantize_tensorwise if quantize_type == "tensorwise" else _quantize_rowwise
        )

        # quantize weight (prior to inference)
        w_fp8, w_inverse_scale = f_quantize(w, dtype_float8)
        w_t_fp8 = w_fp8.t()
        w_inverse_scale_t = w_inverse_scale.t()

        # quantize input x
        x_fp8, x_inverse_scale = f_quantize(x, dtype_float8)

        self.assertIn("rocm", dir(config))

        def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
            y = torch._scaled_mm(
                x_fp8,
                w_t_fp8,
                x_inverse_scale,
                w_inverse_scale,
                bias,
                out_dtype=dtype,
                use_fast_accum=use_fast_accum,
            )
            return y

        if quantize_type == "tensorwise":
            y_eager = linear(
                x_fp8,
                x_inverse_scale,
                w_t_fp8,
                w_inverse_scale_t,
                bias,
            )
        else:
            # FIXME when rowwise quantize is supported by pt eager on ROCm
            w_fp8_tw, w_inverse_scale_tw = _quantize_tensorwise(w, dtype_float8)
            w_fp8_tw_t = w_fp8_tw.t()
            w_inverse_scale_tw_t = w_inverse_scale_tw.t()
            x_fp8_tw, x_inverse_scale_tw = _quantize_tensorwise(x, dtype_float8)
            y_eager = linear(
                x_fp8_tw,
                x_inverse_scale_tw,
                w_fp8_tw_t,
                w_inverse_scale_tw_t,
                bias,
            )

        with config.patch(
            {
                "max_autotune": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "compile_threads": 24,
                "rocm.n_max_profiling_configs": 24,
                "rocm.ck_dir": self.ck_dir,
            }
        ):
            linear_compiled = torch.compile(
                linear, backend="inductor", mode="max-autotune"
            )
            y_compiled = linear_compiled(
                x_fp8,
                x_inverse_scale,
                w_t_fp8,
                w_inverse_scale_t,
                bias,
            )

        self.assertEqual(y_eager.dtype, dtype)
        self.assertEqual(y_compiled.dtype, dtype)

        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @unittest.mock.patch.dict(
        os.environ,
        {"PATH": _get_path_without_sccache(), "PYTORCH_MIOPEN_SUGGEST_NHWC": "1"},
    )
    @parametrize("max_autotune_conv_backends", ("CK", "ATEN,CK,TRITON"))
    def test_max_autotune_conv2d(self, max_autotune_conv_backends):
        """
        Autotune an fp32 channels-last ``torch.conv2d`` with CK conv backends.

        Input is (1, 8, 224, 224) and the weight is (64, 8, 7, 7), both
        converted to ``torch.channels_last``.
        """
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        tensor_options = {"device": "cuda", "dtype": torch.float32}

        x = torch.randn(1, 8, 224, 224, **tensor_options)
        w = torch.randn(64, 8, 7, 7, **tensor_options)
        x_cl = x.to(memory_format=torch.channels_last)
        w_cl = w.to(memory_format=torch.channels_last)

        self.assertIn("rocm", dir(config))

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": False,
                "max_autotune_conv_backends": max_autotune_conv_backends,
                "compile_threads": 4,
                "rocm.ck_dir": self.ck_dir,
                "rocm.n_max_profiling_configs": 4,
            }
        ):

            @torch.compile(dynamic=False)
            def conv2d(x, w):
                return torch.conv2d(x, w)

            Y_eager = torch.conv2d(x_cl, w_cl)
            Y_compiled = conv2d(x_cl, w_cl)

            torch.testing.assert_close(Y_compiled, Y_eager, atol=2e-4, rtol=2e-4)

    @unittest.skipIf(not torch.version.hip, "ROCM only")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK"))
    def test_max_autotune_precompile_bmm(
        self,
        max_autotune_gemm_backends,
    ):
        """
        Test gemm-max-autotune torch.bmm with CK backend.

        Uses a batch of 16 bf16 products with ``b`` built as a transposed view.
        """

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def bmm(a, b):
            return torch.bmm(a, b)

        tensor_options = {"device": "cuda", "dtype": torch.bfloat16}

        a = torch.randn(16, 2240, 256, **tensor_options)
        b = torch.randn(16, 2048, 256, **tensor_options).transpose(1, 2)

        self.assertIn("rocm", dir(config))

        with config.patch(
            {
                "max_autotune": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "compile_threads": 2,
                "rocm.n_max_profiling_configs": 2,
                "rocm.ck_dir": self.ck_dir,
            }
        ):

            @torch.compile(dynamic=False)
            def compiled_bmm(x, w):
                return bmm(x, w)

            Y_compiled = compiled_bmm(a, b)

            Y_eager = bmm(a=a, b=b)
            torch.testing.assert_close(Y_compiled, Y_eager)
if __name__ == "__main__":
    from torch._inductor.utils import is_big_gpu

    # Only launch the suite when both the CPU and CUDA/HIP backends are
    # available and inductor's is_big_gpu() check passes; otherwise the
    # script exits without running any tests.
    if HAS_CPU and HAS_CUDA and is_big_gpu():
        run_tests()
|