1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752
|
# Owner(s): ["module: inductor"]
import functools
import unittest
import torch
from torch import Tensor
from torch._inductor import config, utils
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM90OrLater
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
TEST_WITH_ROCM,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.utils._triton import has_triton_tma_device
# Let float32 matmuls in this module use the "high" (reduced-precision
# internal math where available) setting rather than full fp32.
torch.set_float32_matmul_precision("high")
# Skip reason shared by the tests gated on PLATFORM_SUPPORTS_FP8.
f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"
# Largest finite magnitude of each float8 variant. The quantization helpers
# below use these as saturation bounds and as scale numerators.
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
E4M3FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
E5M2FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
# Largest finite float16 value; scales for float16 inputs are clamped to it.
FP16_MAX_POS: float = torch.finfo(torch.float16).max
# Lower bound applied to amax before dividing by it (avoids division by zero).
EPS: float = 1e-12
def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor:
    """Clamp ``x`` to the finite range of ``float8_dtype``, then cast to it.

    PyTorch's plain ``.to(float8_e4m3fn)`` / ``.to(float8_e5m2)`` does not
    saturate, so out-of-range values are clamped to ``[-max, max]`` first.
    Saturation matters e.g. under delayed scaling, where the amax used to
    derive the scale (from history) can be smaller than the current amax.

    Args:
        x: tensor to convert.
        float8_dtype: one of ``float8_e4m3fn``, ``float8_e5m2``,
            ``float8_e4m3fnuz`` or ``float8_e5m2fnuz``.

    Returns:
        ``x`` clamped and converted to ``float8_dtype``.

    Raises:
        TypeError: if ``float8_dtype`` is not one of the supported dtypes.
    """
    # Ordered (dtype, largest finite value) pairs; checked with ``==`` in turn.
    saturation_bounds = (
        (torch.float8_e4m3fn, E4M3_MAX_POS),
        (torch.float8_e5m2, E5M2_MAX_POS),
        (torch.float8_e4m3fnuz, E4M3FNUZ_MAX_POS),
        (torch.float8_e5m2fnuz, E5M2FNUZ_MAX_POS),
    )
    for candidate, bound in saturation_bounds:
        if float8_dtype == candidate:
            clamped = x.clamp(min=-1 * bound, max=bound)
            return clamped.to(float8_dtype)
    raise TypeError(f"Unsupported float8_dtype: {float8_dtype}")
@torch.no_grad()
def _amax_to_scale(
    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
) -> torch.Tensor:
    """Compute the scale that maps ``amax`` onto the largest float8 value.

    Args:
        amax: absolute maximum of the tensor being quantized (a scalar tensor
            for tensor-wise scaling, or a keepdim reduction for row-wise).
        float8_dtype: target float8 dtype; selects the numerator.
        orig_dtype: dtype of the tensor being quantized. For float16 the
            scale is clamped so that it stays representable in float16.

    Returns:
        A float32 tensor ``max(float8_dtype) / max(amax, EPS)``, shaped like
        ``amax``.

    Raises:
        TypeError: if ``float8_dtype`` is not a supported float8 dtype.
    """
    # Handle every float8 variant explicitly. Falling back to the e5m2 max for
    # anything that is not e4m3fn would give e4m3fnuz a scale far beyond its
    # finite range, so nearly every quantized value would then saturate.
    if float8_dtype == torch.float8_e4m3fn:
        fp8_max = E4M3_MAX_POS
    elif float8_dtype == torch.float8_e5m2:
        fp8_max = E5M2_MAX_POS
    elif float8_dtype == torch.float8_e4m3fnuz:
        fp8_max = E4M3FNUZ_MAX_POS
    elif float8_dtype == torch.float8_e5m2fnuz:
        fp8_max = E5M2FNUZ_MAX_POS
    else:
        raise TypeError(f"Unsupported float8_dtype: {float8_dtype}")
    # To make scale dtype to be fp32 for accuracy
    amax = amax.float()
    # Clamp amax away from zero so an all-zero input does not divide by zero.
    res = fp8_max / torch.clamp(amax, min=EPS)
    # Ensure that the scale is representable in float16,
    # this helps when amax is small. We are assuming that we don't need
    # to care about this for float32/bfloat16.
    if orig_dtype is torch.float16:
        res = torch.clamp(res, max=FP16_MAX_POS)
    return res
def _quantize_tensorwise(x: Tensor, float8_dtype: torch.dtype):
    """Quantize ``x`` to ``float8_dtype`` using one scale for the whole tensor.

    Returns:
        ``(x_fp8, inverse_scale)``. ``x_fp8`` is ``x * scale``, saturated and
        cast to ``float8_dtype``. ``inverse_scale`` is ``1 / scale``, which is
        the form the tests pass as ``_scaled_mm`` scale arguments.
    """
    scale = _amax_to_scale(x.abs().max(), float8_dtype, x.dtype)
    return _to_fp8_saturated(x * scale, float8_dtype), scale.reciprocal()
def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
    """Quantize ``x`` to ``float8_dtype`` with a separate scale per slice of dim 1.

    The amax is reduced over dim 1 with ``keepdim=True``. For a 2-D ``x``
    this gives one scale per row, shaped ``(rows, 1)``, which broadcasts
    against ``x``.

    Returns:
        ``(x_fp8, inverse_scale)``, where ``inverse_scale`` is ``1 / scale``
        and has the keepdim shape.
    """
    row_amax, _ = torch.max(torch.abs(x), dim=1, keepdim=True)
    row_scale = _amax_to_scale(row_amax, float8_dtype, x.dtype)
    quantized = _to_fp8_saturated(x * row_scale, float8_dtype)
    return quantized, row_scale.reciprocal()
@instantiate_parametrized_tests
class TestFP8Types(TestCase):
    """Inductor tests for float8 casts and fused amax / float8 quantization."""

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported yet")
    @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
    def test_xblock_for_small_numel(self, float8_dtype: torch.dtype):
        """
        TritonOverrides.to_dtype sets min_elem_per_thread to 2 or 4
        depending on the variant of fp8 type.
        This caused triton_heuristics.triton_config to pick an XBLOCK larger
        than numel and fail the config sanity check.
        We should not pick an XBLOCK larger than xnumel.
        """

        def f(x):
            return x.to(dtype=float8_dtype)

        # A single element is the smallest possible numel.
        x = torch.randn(1, device="cuda")
        expected = f(x)
        actual = torch.compile(f)(x)
        torch.testing.assert_close(expected.half(), actual.half(), rtol=1e-2, atol=1e-2)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported yet")
    @parametrize("dtype", (torch.float16, torch.bfloat16))
    def test_eager_fallback(self, dtype: torch.dtype):
        """Compile ``_scaled_mm`` with ``dynamic=True`` and call it with two row counts."""
        weight_shape = (32, 16)
        e4m3_type = (
            torch.float8_e4m3fn if torch.version.hip is None else torch.float8_e4m3fnuz
        )

        def fp8_matmul_unwrapped(x):
            a_scale = torch.Tensor([1.0]).to(device="cuda")
            b_scale = torch.Tensor([1.0]).to(device="cuda")
            output_scale = None
            input_bias = torch.rand(32, device="cuda", dtype=dtype)
            # (32, 16).T -> a (16, 32) column-major float8 weight.
            weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T.to(
                e4m3_type
            )
            a_inverse_scale = 1 / a_scale
            b_inverse_scale = 1 / b_scale
            output = torch._scaled_mm(
                x,
                weight,
                bias=input_bias,
                out_dtype=dtype,
                scale_a=a_inverse_scale,
                scale_b=b_inverse_scale,
                scale_result=output_scale,
            )
            return output

        compiled_fp8_matmul = torch.compile(
            fp8_matmul_unwrapped, backend="inductor", dynamic=True
        )
        # The second call changes M (16 -> 15) against the same compiled fn.
        for x_shape in ((16, 16), (15, 16)):
            x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
            y_fp8 = compiled_fp8_matmul(x)
            # x is (M, 16) and the weight is (16, 32), so the result is (M, 32).
            self.assertEqual(y_fp8.shape, (x_shape[0], weight_shape[0]))
            self.assertEqual(y_fp8.dtype, dtype)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
    @parametrize("shape", ("15,3,13", "4,2048,4096"))
    @parametrize(
        "dst_types",
        [(torch.float8_e4m3fn, torch.float8_e5m2)]
        if torch.version.hip is None
        else [(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)],
    )
    def test_valid_cast(self, dtype: torch.dtype, shape: str, dst_types: tuple):
        """Round-trip ``dtype`` -> float8 -> ``dtype`` and compare loosely to the input."""
        e4m3, e5m2 = dst_types

        def fp8_cast(x):
            y0 = x.to(dtype=e4m3).to(dtype)
            y1 = x.to(dtype=e5m2).to(dtype)
            return y0, y1

        compiled_fp8_cast = torch.compile(fp8_cast, backend="inductor", dynamic=True)
        dims = [int(dim) for dim in shape.split(",")]
        x = torch.rand(*dims, device="cuda", dtype=dtype)
        y0_fp8, y1_fp8 = compiled_fp8_cast(x)
        torch.testing.assert_close(y0_fp8, x, rtol=5e-1, atol=5e-1)
        torch.testing.assert_close(y1_fp8, x, rtol=5e-1, atol=5e-1)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_bad_cast(self):
        """Direct e4m3fn <-> e5m2 conversion must fail to compile with a clear message."""

        def fp8_cast(x, dtype):
            return x.to(dtype=dtype)

        compiled_fp8_cast = torch.compile(fp8_cast, backend="inductor", dynamic=True)
        x_shape = (16, 16, 16)
        with self.assertRaisesRegex(
            torch._dynamo.exc.BackendCompilerFailed,
            "Conversions between float8_e5m2 and float8_e4m3fn is not supported!",
        ):
            x = torch.rand(*x_shape, device="cuda").to(dtype=torch.float8_e4m3fn)
            compiled_fp8_cast(x, torch.float8_e5m2)
        with self.assertRaisesRegex(
            torch._dynamo.exc.BackendCompilerFailed,
            "Conversions between float8_e5m2 and float8_e4m3fn is not supported!",
        ):
            x = torch.rand(*x_shape, device="cuda").to(dtype=torch.float8_e5m2)
            compiled_fp8_cast(x, torch.float8_e4m3fn)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float))
    @parametrize(
        "dst_dtype",
        (torch.float8_e4m3fn, torch.float8_e5m2)
        if torch.version.hip is None
        else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
    )
    @parametrize("shape", ("16,16,16", "4,2048,4096"))
    def test_to_fp8_saturated(
        self, src_dtype: torch.dtype, dst_dtype: torch.dtype, shape: str
    ):
        """Compiled ``_to_fp8_saturated`` must match the eager result."""

        def fp8_saturated(x, dtype):
            return _to_fp8_saturated(x, dtype)

        compiled_fp8_cast = torch.compile(
            fp8_saturated, backend="inductor", dynamic=True
        )
        dims = [int(dim) for dim in shape.split(",")]
        x = torch.rand(*dims, device="cuda", dtype=src_dtype)
        y_compiled = compiled_fp8_cast(x, dst_dtype)
        y = fp8_saturated(x, dst_dtype)
        torch.testing.assert_close(y_compiled.half(), y.half(), rtol=5e-1, atol=5e-1)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    @parametrize(
        "float8_dtype",
        (torch.float8_e4m3fn, torch.float8_e5m2)
        if torch.version.hip is None
        else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
    )
    @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
    def test_amax_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
        """Quantize the scaled global amax to float8; compiled must match eager."""
        batch_size, sequence_length, hidden_size = [
            int(dim) for dim in shape.split(",")
        ]

        def amax_fp8(x: Tensor, scale: Tensor):
            y = torch.amax(torch.abs(x))
            y_scaled = y.to(dtype=torch.float) * scale
            bits_fp8 = _to_fp8_saturated(y_scaled, float8_dtype)
            return bits_fp8

        compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")
        x_shape = (batch_size, sequence_length, hidden_size)
        x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
        scale = torch.tensor(0.2, device="cuda", dtype=torch.float)
        y_compiled = compiled_amax_fp8_quant(x, scale)
        y = amax_fp8(x, scale)
        torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-2, atol=1e-2)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize(
        "float8_dtype",
        (torch.float8_e4m3fn, torch.float8_e5m2)
        if torch.version.hip is None
        else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
    )
    @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
    def test_amax_along_with_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
        """Record amax into a buffer while quantizing the full tensor to float8."""
        batch_size, sequence_length, hidden_size = [
            int(dim) for dim in shape.split(",")
        ]

        def amax_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
            # In-place side effect: the compiled graph must also write the buffer.
            amax_buffer.fill_(torch.amax(torch.abs(x)))
            x_scaled = x.to(dtype=torch.float) * scale
            bits_fp8 = _to_fp8_saturated(x_scaled, float8_dtype)
            return bits_fp8

        compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")
        x_shape = (batch_size, sequence_length, hidden_size)
        x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
        scale = torch.tensor(1.0, device="cuda", dtype=torch.float)
        amax_buffer_compiled = torch.zeros((1,), device="cuda", dtype=torch.half)
        y_compiled = compiled_amax_fp8_quant(x, scale, amax_buffer_compiled)
        amax_buffer = torch.zeros((1,), device="cuda", dtype=torch.half)
        y = amax_fp8(x, scale, amax_buffer)
        torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
        torch.testing.assert_close(
            amax_buffer_compiled, amax_buffer, rtol=1e-2, atol=1e-2
        )

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    @parametrize(
        "float8_dtype",
        (torch.float8_e4m3fn, torch.float8_e5m2)
        if torch.version.hip is None
        else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
    )
    @parametrize("amax_keep_dim", (True, False))
    @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
    def test_layernorm_fp8_quant(
        self, float8_dtype: torch.dtype, amax_keep_dim: bool, shape: str
    ):
        """LayerNorm, record amax into a buffer, then quantize to float8."""
        batch_size, sequence_length, hidden_size = [
            int(dim) for dim in shape.split(",")
        ]

        def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
            x = torch.nn.functional.layer_norm(
                x.to(dtype=torch.float),
                [hidden_size],
                weight=None,
                bias=None,
                eps=1e-05,
            )
            # reshape(-1)[0] extracts the scalar regardless of keepdim.
            amax_buffer.fill_(
                torch.amax(torch.abs(x), keepdim=amax_keep_dim).reshape(-1)[0]
            )
            x_scaled = x * scale
            bits_fp8 = _to_fp8_saturated(x_scaled, float8_dtype)
            return bits_fp8

        compiled_ln_fp8_quant = torch.compile(ln_fp8, backend="inductor")
        x_shape = (batch_size, sequence_length, hidden_size)
        x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
        scale = torch.tensor(0.2, device="cuda", dtype=torch.float)
        amax_buffer_compiled = torch.zeros((1,), device="cuda", dtype=torch.half)
        y_compiled = compiled_ln_fp8_quant(x, scale, amax_buffer_compiled)
        amax_buffer = torch.zeros((1,), device="cuda", dtype=torch.half)
        y = ln_fp8(x, scale, amax_buffer)
        torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
        torch.testing.assert_close(
            amax_buffer_compiled, amax_buffer, rtol=1e-2, atol=1e-2
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize(
        "float8_dtype",
        (torch.float8_e4m3fn, torch.float8_e5m2)
        if torch.version.hip is None
        else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
    )
    @parametrize("shape", ("4,2048,4096",))
    @parametrize("keepdim", (False, True))
    def test_layernorm_fp8_quant_benchmark(
        self,
        float8_dtype: torch.dtype,
        shape: str,
        keepdim: bool,
    ):
        """Print latencies of compiled vs eager LN+fp8 quant and of compiled LN alone."""
        batch_size, sequence_length, hidden_size = [
            int(dim) for dim in shape.split(",")
        ]

        def ln(x: Tensor):
            x = torch.nn.functional.layer_norm(
                x.to(dtype=torch.float),
                [hidden_size],
                weight=None,
                bias=None,
                eps=1e-05,
            )
            return x

        def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
            x = torch.nn.functional.layer_norm(
                x.to(dtype=torch.float),
                [hidden_size],
                weight=None,
                bias=None,
                eps=1e-05,
            )
            amax = torch.amax(torch.abs(x), keepdim=keepdim)
            # amax has one element either way; view the 1-elem buffer like it.
            amax_buffer.view_as(amax).copy_(amax)
            x_scaled = x * scale
            bits_fp8 = _to_fp8_saturated(x_scaled, float8_dtype)
            return bits_fp8

        compiled_ln_fp8_quant = torch.compile(ln_fp8, backend="inductor")
        x_shape = (batch_size, sequence_length, hidden_size)
        x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
        scale = torch.tensor(0.2, device="cuda", dtype=torch.float)
        amax_buffer_compiled = torch.zeros((1,), device="cuda", dtype=torch.half)
        amax_buffer = torch.zeros((1,), device="cuda", dtype=torch.half)
        # Warm-up call so compilation is not included in the measurement.
        _ = compiled_ln_fp8_quant(x, scale, amax_buffer_compiled)
        compiled_latency = utils.do_bench_using_profiling(
            functools.partial(compiled_ln_fp8_quant, x, scale, amax_buffer_compiled)
        )
        eager_latency = utils.do_bench_using_profiling(
            functools.partial(ln_fp8, x, scale, amax_buffer)
        )
        compiled_ln = torch.compile(ln, backend="inductor")
        _ = compiled_ln(x)
        ln_latency = utils.do_bench_using_profiling(functools.partial(compiled_ln, x))
        print(
            f"Config: {float8_dtype=}, {shape=}, {keepdim=}. "
            f"Benchmark results: Inductor: {compiled_latency}ms, Eager: {eager_latency}ms, "
            f"LN only Inductor: {ln_latency}ms."
        )
@instantiate_parametrized_tests
class TestFP8Lowering(TestCase):
    """Inductor max-autotune lowering tests for ``torch._scaled_mm`` with float8 inputs."""

    def _assert_compiled_matches_eager(
        self,
        x_fp8: Tensor,
        x_inverse_scale: Tensor,
        w_t_fp8: Tensor,
        w_inverse_scale: Tensor,
        bias,
        *,
        out_dtype: torch.dtype,
        use_fast_accum: bool,
        persistent_matmul: bool,
        atol: float,
    ) -> None:
        """Run ``_scaled_mm`` eagerly and under max-autotune; compare dtype and values.

        Args:
            x_fp8, x_inverse_scale: quantized activation and its scale_a.
            w_t_fp8, w_inverse_scale: transposed quantized weight and its scale_b.
            bias: optional bias tensor (``None`` for no bias).
            out_dtype: ``out_dtype`` passed to ``_scaled_mm``; both results
                are asserted to have it.
            use_fast_accum: forwarded to ``_scaled_mm``.
            persistent_matmul: value for ``triton.enable_persistent_tma_matmul``
                while compiling, so each parametrized variant actually
                exercises its own config.
            atol: absolute tolerance for the eager vs compiled comparison.
        """

        def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
            y = torch._scaled_mm(
                x_fp8,
                w_t_fp8,
                x_inverse_scale,
                w_inverse_scale,
                bias,
                out_dtype=out_dtype,
                use_fast_accum=use_fast_accum,
            )
            return y

        args = (x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias)
        y_eager = linear(*args)
        with config.patch({"triton.enable_persistent_tma_matmul": persistent_matmul}):
            linear_compiled = torch.compile(
                linear, backend="inductor", mode="max-autotune"
            )
            y_compiled = linear_compiled(*args)
        self.assertEqual(y_eager.dtype, out_dtype)
        self.assertEqual(y_compiled.dtype, out_dtype)
        # depending on the kernel config (BLOCK_M size, etc) selected during Inductor
        # autotuning for the compiled case, the results can be different because of
        # the way blocks of results are accumulated (float addition not associative), so
        # setting a small absolute tolerance in these tests
        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=atol)

    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    @parametrize("dtype", (torch.bfloat16, torch.float32))
    @parametrize("shape", ("16,16,32", "1024,1024,512"))
    @parametrize("has_bias", (False, True))
    @parametrize("use_fast_accum", (False, True))
    @parametrize(
        "persistent_matmul", [False, True] if has_triton_tma_device() else [False]
    )
    def test_tensorwise_scaling(
        self,
        dtype: torch.dtype,
        shape: str,
        has_bias: bool,
        use_fast_accum: bool,
        persistent_matmul: bool,
    ):
        if dtype is torch.float32 and has_bias:
            self.skipTest("bias is not supported when output dtype is float32")
        device = "cuda"
        dtype_float8 = torch.float8_e4m3fn
        # Matmul Y = X [M, K] x W [N, K]
        M, K, N = [int(dim) for dim in shape.split(",")]
        # input and output dtypes of _scaled_mm do not need to be the same, but
        # typically in a model they are
        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(N, K, dtype=dtype, device=device)
        bias = None
        if has_bias:
            bias = torch.randn(N, device=device, dtype=torch.bfloat16)
        # quantize weight (prior to inference)
        w_fp8, w_inverse_scale = _quantize_tensorwise(w, dtype_float8)
        w_t_fp8 = w_fp8.t()
        # quantize input x
        x_fp8, x_inverse_scale = _quantize_tensorwise(x, dtype_float8)
        self._assert_compiled_matches_eager(
            x_fp8,
            x_inverse_scale,
            w_t_fp8,
            w_inverse_scale,
            bias,
            out_dtype=dtype,
            use_fast_accum=use_fast_accum,
            persistent_matmul=persistent_matmul,
            atol=0.05,
        )

    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    @parametrize("shape", ("16,16,32", "1024,1024,512"))
    @parametrize("has_bias", (False, True))
    @parametrize("use_fast_accum", (False, True))
    @parametrize(
        "persistent_matmul", [False, True] if has_triton_tma_device() else [False]
    )
    def test_rowwise_scaling(
        self, shape: str, has_bias: bool, use_fast_accum: bool, persistent_matmul: bool
    ):
        # Only bf16 output type is supported for row-wise scaling, not fp32
        dtype: torch.dtype = torch.bfloat16
        device = "cuda"
        dtype_float8 = torch.float8_e4m3fn
        # Matmul Y = X [M, K] x W [N, K]
        M, K, N = [int(dim) for dim in shape.split(",")]
        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(N, K, dtype=dtype, device=device)
        bias = None
        if has_bias:
            bias = torch.randn(N, device=device, dtype=torch.bfloat16)
        # quantize weight (prior to inference)
        w_fp8, w_inverse_scale = _quantize_rowwise(w, dtype_float8)
        w_t_fp8 = w_fp8.t()
        w_inverse_scale = w_inverse_scale.t()  # scale_b should be (1, N)
        # quantize input x
        x_fp8, x_inverse_scale = _quantize_rowwise(x, dtype_float8)
        self._assert_compiled_matches_eager(
            x_fp8,
            x_inverse_scale,
            w_t_fp8,
            w_inverse_scale,
            bias,
            out_dtype=dtype,
            use_fast_accum=use_fast_accum,
            persistent_matmul=persistent_matmul,
            atol=0.05,
        )

    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    @parametrize("M", (1, 3, 33, 257, 1024))
    @parametrize("K", (16, 1024))
    @parametrize("N", (16, 2048))
    @parametrize(
        "persistent_matmul", [False, True] if has_triton_tma_device() else [False]
    )
    def test_tensorwise_scaling_acceptable_input_dims(
        self, M: int, K: int, N: int, persistent_matmul: bool
    ):
        # alignment requirements: K and N divisible by 16
        dtype: torch.dtype = torch.bfloat16
        device = "cuda"
        dtype_float8 = torch.float8_e4m3fn
        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(N, K, dtype=dtype, device=device)
        w_fp8, w_inverse_scale = _quantize_tensorwise(w, dtype_float8)
        w_t_fp8 = w_fp8.t()
        x_fp8, x_inverse_scale = _quantize_tensorwise(x, dtype_float8)
        self._assert_compiled_matches_eager(
            x_fp8,
            x_inverse_scale,
            w_t_fp8,
            w_inverse_scale,
            None,
            out_dtype=dtype,
            use_fast_accum=True,
            persistent_matmul=persistent_matmul,
            atol=0.07,
        )

    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    @parametrize("M", (1, 3, 33, 257, 1024))
    @parametrize("K", (16, 1024))
    @parametrize("N", (16, 2048))
    @parametrize(
        "persistent_matmul", [False, True] if has_triton_tma_device() else [False]
    )
    def test_rowwise_scaling_acceptable_input_dims(
        self, M: int, K: int, N: int, persistent_matmul: bool
    ):
        dtype: torch.dtype = torch.bfloat16
        device = "cuda"
        dtype_float8 = torch.float8_e4m3fn
        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(N, K, dtype=dtype, device=device)
        bias = torch.randn(N, device=device, dtype=torch.bfloat16)
        w_fp8, w_inverse_scale = _quantize_rowwise(w, dtype_float8)
        w_t_fp8 = w_fp8.t()
        w_inverse_scale = w_inverse_scale.t()  # scale_b should be (1, N)
        x_fp8, x_inverse_scale = _quantize_rowwise(x, dtype_float8)
        self._assert_compiled_matches_eager(
            x_fp8,
            x_inverse_scale,
            w_t_fp8,
            w_inverse_scale,
            bias,
            out_dtype=dtype,
            use_fast_accum=True,
            persistent_matmul=persistent_matmul,
            atol=0.07,
        )

    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    def test_unacceptable_input_dims(self):
        # for compiled ops, type checking is in torch/_meta_registrations.py
        dtype: torch.dtype = torch.bfloat16
        device = "cuda"
        dtype_float8 = torch.float8_e4m3fn
        M, K, N = 64, 15, 2048  # K needs to be a multiple of 16
        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(N, K, dtype=dtype, device=device)
        bias = torch.randn(N, device=device, dtype=torch.bfloat16)
        w_fp8, w_inverse_scale = _quantize_tensorwise(w, dtype_float8)
        w_t_fp8 = w_fp8.t()

        def linear(x, w_t_fp8, w_inverse_scale, bias):
            x_fp8, x_inverse_scale = _quantize_tensorwise(x, dtype_float8)
            y = torch._scaled_mm(
                x_fp8,
                w_t_fp8,
                x_inverse_scale,
                w_inverse_scale,
                bias,
                out_dtype=dtype,
                use_fast_accum=True,
            )
            return y

        linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
        with self.assertRaises(torch._dynamo.exc.TorchRuntimeError) as cm:
            linear_compiled(
                x,
                w_t_fp8,
                w_inverse_scale,
                bias,
            )
        self.assertIn(
            f"Expected self.size(1) to be divisible by 16, but got self.size(1)={K}",
            str(cm.exception),
        )

    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    def test_unacceptable_scale_dims_rowwise_scaling(self):
        dtype: torch.dtype = torch.bfloat16
        device = "cuda"
        dtype_float8 = torch.float8_e4m3fn
        M, K, N = 233, 32, 128
        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(N, K, dtype=dtype, device=device)
        bias = torch.randn(N, device=device, dtype=torch.bfloat16)
        w_fp8, w_inverse_scale = _quantize_rowwise(w, dtype_float8)
        w_t_fp8 = w_fp8.t()

        def linear(x, w_t_fp8, w_inverse_scale, bias):
            x_fp8, x_inverse_scale = _quantize_rowwise(x, dtype_float8)
            y = torch._scaled_mm(
                x_fp8,
                w_t_fp8,
                w_inverse_scale.t(),  # testing with w and x scales switched
                x_inverse_scale,
                bias,
                out_dtype=dtype,
                use_fast_accum=True,
            )
            return y

        linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
        with self.assertRaises(torch._dynamo.exc.TorchRuntimeError) as cm:
            linear_compiled(
                x,
                w_t_fp8,
                w_inverse_scale,
                bias,
            )
        self.assertIn("Invalid scaling configuration.", str(cm.exception))
# Only run the suite when executed directly and a CUDA device is available.
if __name__ == "__main__" and HAS_CUDA:
    run_tests()
|