1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914
|
# mypy: allow-untyped-defs
import functools
import inspect
import logging
import math
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend
from ..._dynamo.utils import counters
from ..pattern_matcher import (
filter_nodes,
fwd_only,
gen_register_replacement,
joint_fwd_bwd,
)
log = logging.getLogger(__name__)
aten = torch.ops.aten
if torch.version.hip:
def _scaled_dot_product_attention(*args, **kwargs):
with sdpa_kernel(backends=[SDPBackend.MATH, SDPBackend.FLASH_ATTENTION]):
return aten.scaled_dot_product_attention(*args, **kwargs)
else:
_scaled_dot_product_attention = aten.scaled_dot_product_attention
def _sfdp_pattern_1(query, key, value, inv_scale):
    """Search pattern: ``softmax((query @ key^T) / inv_scale) @ value``.

    No attention mask and no dropout. The op chain below is traced to build the
    graph that is searched for, so its exact sequence of ops matters.
    """
    return (
        torch.matmul(query, key.transpose(-2, -1))
        .div(inv_scale)
        .softmax(dim=-1)
        .matmul(value)
    )
def _sfdp_replacement_1(query, key, value, inv_scale):
    """Fused replacement for pattern 1: SDPA with ``scale = 1 / inv_scale``."""
    counters["inductor"]["fuse_attention"] += 1
    softmax_scale = 1.0 / inv_scale
    return _scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=softmax_scale,
    )
def _sfdp_pattern_2(query, key, value, scale_factor):
    """Search pattern: ``softmax((query @ key^T) * scale_factor) @ value``.

    Same as pattern 1 but the scores are multiplied rather than divided.
    No attention mask and no dropout.
    """
    return (
        torch.matmul(query, key.transpose(-2, -1))
        .mul(scale_factor)
        .softmax(dim=-1)
        .matmul(value)
    )
def _sfdp_replacement_2(query, key, value, scale_factor):
    """Fused replacement for pattern 2: SDPA with ``scale = scale_factor``."""
    counters["inductor"]["fuse_attention"] += 1
    sdpa_out = _scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=scale_factor,
    )
    return sdpa_out
def _sfdp_pattern_3(query, key, value, inv_scale_factor, dropout_p):
    """Search pattern: pattern 1 with dropout applied to the attention weights.

    Computes ``dropout(softmax((query @ key^T) / inv_scale_factor), p=dropout_p) @ value``
    using ``torch.nn.functional.dropout``.
    """
    return torch.nn.functional.dropout(
        torch.matmul(query, key.transpose(-2, -1))
        .div(inv_scale_factor)
        .softmax(dim=-1),
        p=dropout_p,
    ).matmul(value)
def _sfdp_replacement_3(query, key, value, inv_scale_factor, dropout_p):
    """Fused replacement for pattern 3: SDPA with dropout and ``scale = 1 / inv_scale_factor``."""
    counters["inductor"]["fuse_attention"] += 1
    softmax_scale = 1.0 / inv_scale_factor
    return _scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=False,
        scale=softmax_scale,
    )
def _sfdp_pattern_4(query, key, value, scale_factor, dropout_p):
    """Search pattern: pattern 2 with dropout applied to the attention weights.

    Computes ``dropout(softmax((query @ key^T) * scale_factor), p=dropout_p) @ value``
    using ``torch.nn.functional.dropout``.
    """
    return torch.nn.functional.dropout(
        torch.matmul(query, key.transpose(-2, -1)).mul(scale_factor).softmax(dim=-1),
        p=dropout_p,
    ).matmul(value)
def _sfdp_replacement_4(query, key, value, scale_factor, dropout_p):
    """Fused replacement for pattern 4: SDPA with dropout and ``scale = scale_factor``."""
    counters["inductor"]["fuse_attention"] += 1
    sdpa_out = _scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=False,
        scale=scale_factor,
    )
    return sdpa_out
def _sfdp_pattern_5(query, key, value, attn_mask):
    """Search pattern: ``softmax(query @ key^T / sqrt(d) + attn_mask) @ value``.

    ``d`` is ``query.size(-1)`` and ``attn_mask`` is added to the scores.
    There is no dropout step; pattern 6 is the variant with ``torch.dropout``.
    """
    attn_weight = torch.softmax(
        (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1
    )
    # Intentionally no dropout in this pattern (see pattern 6 for that variant).
    return attn_weight @ value
def _sfdp_replacement_5(query, key, value, attn_mask):
    """Fused replacement for pattern 5: SDPA with an additive mask cast to ``query.dtype``.

    No ``scale`` is passed, so SDPA uses its default scaling.
    """
    counters["inductor"]["fuse_attention"] += 1
    mask = attn_mask.to(dtype=query.dtype)
    return _scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=mask,
        dropout_p=0.0,
        is_causal=False,
    )
def _sfdp_pattern_6(query, key, value, attn_mask, dropout_p):
    """Search pattern: pattern 5 plus dropout on the attention weights.

    Computes ``softmax(query @ key^T / sqrt(query.size(-1)) + attn_mask)``, then
    ``torch.dropout(..., dropout_p, True)``, then a matmul with ``value``.
    """
    attn_weight = torch.softmax(
        (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1
    )
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    return attn_weight @ value
def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p):
    """Fused replacement for pattern 6: SDPA with dropout and a mask cast to ``query.dtype``.

    No ``scale`` is passed, so SDPA uses its default scaling.
    """
    counters["inductor"]["fuse_attention"] += 1
    mask = attn_mask.to(dtype=query.dtype)
    return _scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=mask,
        dropout_p=dropout_p,
        is_causal=False,
    )
def _sfdp_pattern_7(query, key, value, dropout_p):
    """Search pattern: permuted inputs, fp32 softmax with dropout, fp16 output matmul.

    Each input is permuted with ``(0, 2, 1, 3)``. The scores are scaled by
    ``1 / sqrt(q.size(-1))`` and cast to float32 for the softmax. They then pass
    through ``torch.dropout(..., dropout_p, True)`` and are cast to
    ``torch.float16`` (a hard-coded dtype) before the matmul with ``v``.
    """
    # in real workloads inputs to matmul are permuted
    # causing matmul to expand to a series of expand and clone calls
    # we want the same to happen during pattern tracing
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v
def _sfdp_replacement_7(query, key, value, dropout_p):
    """Fused replacement for pattern 7: SDPA with dropout on ``(0, 2, 1, 3)``-permuted inputs.

    The inputs are permuted exactly as the pattern does before calling SDPA.
    Per the original author, SDPA copies inputs that are not already in this
    layout, so permuting here keeps the replacement efficient.
    """
    counters["inductor"]["fuse_attention"] += 1
    q, k, v = (t.permute(0, 2, 1, 3) for t in (query, key, value))
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=False,
    )
def _sfdp_pattern_8(query, key, value):
    """Search pattern: pattern 7 without the dropout step.

    Uses the same ``(0, 2, 1, 3)`` permutes and ``1 / sqrt(q.size(-1))`` scaling,
    an fp32 softmax, and a cast to ``torch.float16`` before the matmul with ``v``.
    """
    # no dropout version of pattern 7
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v
def _sfdp_replacement_8(query, key, value):
    """Fused replacement for pattern 8: SDPA (no dropout) on ``(0, 2, 1, 3)``-permuted inputs."""
    counters["inductor"]["fuse_attention"] += 1
    q, k, v = (t.permute(0, 2, 1, 3) for t in (query, key, value))
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
    )
def _sfdp_pattern_9(query, key, value, dropout_p):
    """Search pattern: like pattern 7, but the query is scaled before the matmul.

    After the ``(0, 2, 1, 3)`` permutes, ``q`` is divided by ``sqrt(q.size(-1))``.
    Then ``q @ k^T`` is cast to float32, softmaxed, passed through
    ``torch.dropout(..., dropout_p, True)`` and cast to ``torch.float16`` before
    the matmul with ``v``.
    """
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    q = q / math.sqrt(q.size(-1))
    div = q @ k.transpose(-2, -1)
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v
def _sfdp_replacement_9(query, key, value, dropout_p):
    """Fused replacement for pattern 9: SDPA with dropout on ``(0, 2, 1, 3)``-permuted inputs."""
    counters["inductor"]["fuse_attention"] += 1
    q, k, v = (t.permute(0, 2, 1, 3) for t in (query, key, value))
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=False,
    )
def _sfdp_pattern_10(query, key, value):
    """Search pattern: pattern 9 without the dropout step.

    Uses ``(0, 2, 1, 3)`` permutes, pre-scales ``q`` by ``1 / sqrt(q.size(-1))``,
    applies an fp32 softmax and casts to ``torch.float16`` before the matmul with ``v``.
    """
    # no dropout version of 9
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    q = q / math.sqrt(q.size(-1))
    div = q @ k.transpose(-2, -1)
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v
def _sfdp_replacement_10(query, key, value):
    """Fused replacement for pattern 10: SDPA (no dropout) on ``(0, 2, 1, 3)``-permuted inputs."""
    counters["inductor"]["fuse_attention"] += 1
    q, k, v = (t.permute(0, 2, 1, 3) for t in (query, key, value))
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
    )
def _sfdp_pattern_11(query, key, value, inv_scale):
    """Search pattern: pattern 1 applied to ``(0, 2, 1, 3)``-permuted q/k/v.

    Computes ``softmax((q @ k^T) / inv_scale) @ v`` with no mask or dropout.
    """
    # Mainly for huggingface models
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return torch.matmul(q, k.transpose(-2, -1)).div(inv_scale).softmax(dim=-1).matmul(v)
def _sfdp_replacement_11(query, key, value, inv_scale):
    """Fused replacement for pattern 11: SDPA on inputs with dims 1 and 2 swapped.

    Uses ``scale = 1 / inv_scale``; no mask and no dropout.
    """
    counters["inductor"]["fuse_attention"] += 1
    q, k, v = (t.transpose(1, 2) for t in (query, key, value))
    softmax_scale = 1.0 / inv_scale
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=softmax_scale,
    )
def _sfdp_pattern_12(query, key, value, inv_scale_factor, dropout_p):
    """Search pattern: pattern 11 plus dropout on the attention weights.

    Permutes q/k/v with ``(0, 2, 1, 3)`` and computes
    ``dropout(softmax((q @ k^T) / inv_scale_factor), p=dropout_p) @ v``
    using ``torch.nn.functional.dropout``.
    """
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return torch.nn.functional.dropout(
        torch.matmul(q, k.transpose(-2, -1)).div(inv_scale_factor).softmax(dim=-1),
        p=dropout_p,
    ).matmul(v)
def _sfdp_replacement_12(query, key, value, inv_scale_factor, dropout_p):
    """Fused replacement for pattern 12: transposed-input SDPA with dropout.

    Uses ``scale = 1 / inv_scale_factor`` and no mask.
    """
    counters["inductor"]["fuse_attention"] += 1
    q, k, v = (t.transpose(1, 2) for t in (query, key, value))
    softmax_scale = 1.0 / inv_scale_factor
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=False,
        scale=softmax_scale,
    )
def _sfdp_pattern_13(query, key, value, dropout_p):
    """Search pattern: 3-D ``torch.bmm`` attention with no score scaling.

    Computes ``bmm(dropout(softmax(bmm(query, key^T)), p=dropout_p), value)``
    using ``torch.nn.functional.dropout``.
    """
    attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1)
    attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p)
    return torch.bmm(attn_weight, value)
def _sfdp_replacement_13(query, key, value, dropout_p):
    """Fused replacement for pattern 13 (3-D bmm attention).

    Each input gets a leading size-1 dim for SDPA, which is squeezed away again
    on the output. ``scale=1.0`` matches the unscaled scores of the pattern.
    """
    counters["inductor"]["fuse_attention"] += 1
    q, k, v = (t.unsqueeze(0) for t in (query, key, value))
    batched_out = _scaled_dot_product_attention(
        q,
        k,
        v,
        dropout_p=dropout_p,
        scale=1.0,
    )
    return batched_out.squeeze(0)
def _sfdp_pattern_14(query, key, value, attn_mask, inv_scale):
    """Search pattern: permuted q/k/v, scores divided by ``inv_scale`` plus an additive mask.

    Computes ``softmax((q @ k^T) / inv_scale + attn_mask) @ v`` after permuting
    each input with ``[0, 2, 1, 3]``. There is no dropout.
    """
    # for BertLarge
    # Permutations are needed to create clones in graph.
    q = query.permute([0, 2, 1, 3])
    k = key.permute([0, 2, 1, 3])
    v = value.permute([0, 2, 1, 3])
    return (
        (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask)
        .softmax(dim=-1)
        .matmul(v)
    )
def _sfdp_replacement_14(query, key, value, attn_mask, inv_scale):
    """Fused replacement for pattern 14: transposed-input SDPA with an additive mask.

    The mask is cast to ``query.dtype`` and ``scale = 1 / inv_scale``.
    """
    counters["inductor"]["fuse_attention"] += 1
    q, k, v = (t.transpose(1, 2) for t in (query, key, value))
    mask = attn_mask.to(dtype=query.dtype)
    softmax_scale = 1.0 / inv_scale
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=mask,
        dropout_p=0.0,
        is_causal=False,
        scale=softmax_scale,
    )
def _sfdp_pattern_15(query, key, value, attn_mask, inv_scale):
    """Search pattern: DistilBert-style attention with a masked_fill padding mask.

    After permuting q/k/v with ``[0, 2, 1, 3]``, the scores ``q @ k^T`` are
    divided by ``inv_scale``. Every position where ``attn_mask == 0`` is filled
    with ``-inf`` before the softmax. ``attn_mask`` is viewed as
    ``(bs, 1, 1, k_len)`` and expanded to the score shape, so it must have
    ``bs * k_len`` elements (presumably a ``(bs, k_len)`` padding mask; confirm).
    """
    # for DistilBert
    # Permutations are needed to create clones in graph.
    # Ref: https://github.com/pytorch/pytorch/issues/119911
    q = query.permute([0, 2, 1, 3])
    k = key.permute([0, 2, 1, 3])
    v = value.permute([0, 2, 1, 3])
    bs = q.size(0)
    k_len = k.size(-2)
    scores = q @ k.transpose(-2, -1)
    scores = scores.div(inv_scale)
    fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
    attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
    return torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1) @ v
def _sfdp_replacement_15(query, key, value, attn_mask, inv_scale):
    """Fused replacement for pattern 15 (DistilBert-style padding mask).

    Sizes are read as ``bs = query.size(0)``, ``q_len = query.size(1)``,
    ``n_head = query.size(2)`` and ``k_len = key.size(1)``. The inputs are
    transposed on dims 1 and 2 before SDPA.

    The pattern fills scores with ``-inf`` wherever ``attn_mask == 0``.
    SDPA's boolean mask marks the positions that *take part* in attention, so
    the equivalent mask is ``attn_mask != 0``. (The previous ``attn_mask == 1``
    only agreed with the pattern for masks that contain nothing but 0 and 1.)
    The mask is viewed as ``(bs, 1, 1, k_len)`` and broadcast over heads and
    query positions.
    """
    counters["inductor"]["fuse_attention"] += 1
    bs = query.size(0)
    n_head = query.size(2)
    q_len = query.size(1)
    k_len = key.size(1)
    # True = attend; exactly the positions the pattern leaves un-filled.
    attn_mask = (
        (attn_mask != 0).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len))
    )
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=attn_mask.to(dtype=torch.bool),
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
def _sfdp_pattern_16(query, key, value, attn_mask, inv_scale, dropout_p):
    """Search pattern: pattern 14 plus dropout, with the weights cast back to ``query.dtype``.

    Computes ``dropout(softmax((q @ k^T) / inv_scale + attn_mask), dropout_p)``
    on ``[0, 2, 1, 3]``-permuted inputs. The weights are cast to ``query.dtype``
    and multiplied with ``v``.
    """
    # for BertLarge with dropout
    q = query.permute([0, 2, 1, 3])
    k = key.permute([0, 2, 1, 3])
    v = value.permute([0, 2, 1, 3])
    return (
        torch.nn.functional.dropout(
            (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask).softmax(
                dim=-1
            ),
            dropout_p,
        )
        .to(dtype=query.dtype)
        .matmul(v)
    )
def _sfdp_replacement_16(query, key, value, attn_mask, inv_scale, dropout_p):
    """Fused replacement for pattern 16: transposed-input SDPA with additive mask and dropout.

    The mask is cast to ``query.dtype`` and ``scale = 1 / inv_scale``.
    """
    counters["inductor"]["fuse_attention"] += 1
    q, k, v = (t.transpose(1, 2) for t in (query, key, value))
    mask = attn_mask.to(dtype=query.dtype)
    softmax_scale = 1.0 / inv_scale
    return _scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=mask,
        dropout_p=dropout_p,
        is_causal=False,
        scale=softmax_scale,
    )
def _sfdp_pattern_17(query, key, value, attn_mask, inv_scale, dropout_p):
    """Search pattern: pattern 15 plus dropout on the attention weights.

    Positions where ``attn_mask == 0`` (viewed as ``(bs, 1, 1, k_len)`` and
    expanded to the score shape) are filled with ``-inf``. The result goes
    through softmax and ``torch.nn.functional.dropout(..., dropout_p)``, then
    a matmul with the permuted value.
    """
    # for DistilBert with dropout
    q = query.permute([0, 2, 1, 3])
    k = key.permute([0, 2, 1, 3])
    v = value.permute([0, 2, 1, 3])
    bs = q.size(0)
    k_len = k.size(-2)
    scores = q @ k.transpose(-2, -1)
    scores = scores.div(inv_scale)
    fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
    attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
    return (
        torch.nn.functional.dropout(
            torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1), dropout_p
        )
        @ v
    )
def _sfdp_replacement_17(query, key, value, attn_mask, inv_scale, dropout_p):
    """Fused replacement for pattern 17 (DistilBert-style padding mask with dropout).

    Sizes are read as ``bs = query.size(0)``, ``q_len = query.size(1)``,
    ``n_head = query.size(2)`` and ``k_len = key.size(1)``. The inputs are
    transposed on dims 1 and 2 before SDPA.

    The pattern fills scores with ``-inf`` wherever ``attn_mask == 0``.
    SDPA's boolean mask marks the positions that *take part* in attention, so
    the equivalent mask is ``attn_mask != 0``. (The previous ``attn_mask == 1``
    only agreed with the pattern for masks that contain nothing but 0 and 1.)
    """
    counters["inductor"]["fuse_attention"] += 1
    bs = query.size(0)
    n_head = query.size(2)
    q_len = query.size(1)
    k_len = key.size(1)
    # True = attend; exactly the positions the pattern leaves un-filled.
    attn_mask = (
        (attn_mask != 0).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len))
    )
    return _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=attn_mask.to(dtype=torch.bool),
        dropout_p=dropout_p,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
def _sfdp_pattern_18(query, key, value, causal_mask, dropout_p):
    """Search pattern: GPT-2-style causal attention with dropout and extra outputs.

    Returns ``(attention_output, key, value)``, where ``key`` and ``value`` are
    the ``[0, 2, 1, 3]``-permuted inputs.

    The scores are divided by ``value.size(-1) ** 0.5``, which is materialized
    as a 0-d tensor via ``torch.full``. Positions where ``causal_mask`` is False
    are replaced with ``torch.finfo(query.dtype).min``. The result then goes
    through softmax and ``torch.nn.functional.dropout(..., dropout_p)``.
    """
    # for hf_GPT2 with dropout (introduces clone node) for inference
    # it also returns permuted key & value
    query = query.permute([0, 2, 1, 3])
    key = key.permute([0, 2, 1, 3])
    value = value.permute([0, 2, 1, 3])
    attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2))
    inv_scale = torch.full(
        [],
        value.size(-1) ** 0.5,
        dtype=attn_weights.dtype,
        device=attn_weights.device,
    )
    attn_weights = attn_weights.div(inv_scale)
    causal_mask_value = torch.full(
        (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device
    )
    attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value)
    return (
        (
            torch.nn.functional.dropout(attn_weights.softmax(dim=-1), dropout_p).matmul(
                value
            )
        ),
        key,
        value,
    )
def _sfdp_replacement_18(query, key, value, causal_mask, dropout_p):
    """Fused replacement for pattern 18; also returns the transposed key and value.

    Returns ``(sdpa_output, key.transpose(1, 2), value.transpose(1, 2))``, which
    mirrors the three outputs of the pattern. ``causal_mask`` is passed to SDPA
    as is, and ``scale = 1 / sqrt(value.size(-1))``.
    """
    counters["inductor"]["fuse_attention"] += 1
    key_t = key.transpose(1, 2)
    value_t = value.transpose(1, 2)
    query_t = query.transpose(1, 2)
    softmax_scale = 1.0 / math.sqrt(value.size(-1))
    attn_out = _scaled_dot_product_attention(
        query_t,
        key_t,
        value_t,
        attn_mask=causal_mask,
        dropout_p=dropout_p,
        is_causal=False,
        scale=softmax_scale,
    )
    return attn_out, key_t, value_t
def _sfdp_pattern_19(query, key, value, causal_mask, attn_mask, dropout_p):
    """Search pattern: GPT-2-style causal attention with an extra additive mask.

    No input permutes. The scores ``query @ key^T`` are divided by
    ``value.size(-1) ** 0.5``, and positions where ``causal_mask`` is False are
    replaced with ``torch.finfo(query.dtype).min``. Then ``attn_mask`` is added,
    and the softmax result is cast to ``value.dtype`` and passed through
    ``torch.nn.functional.dropout(..., dropout_p)`` before the matmul with ``value``.
    """
    # for token-classification+gpt2 / text-generation+gpt2
    attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2))
    inv_scale = torch.full(
        [],
        value.size(-1) ** 0.5,
        dtype=attn_weights.dtype,
        device=attn_weights.device,
    )
    attn_weights = attn_weights.div(inv_scale)
    causal_mask_value = torch.full(
        (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device
    )
    attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value)
    attn_weights = attn_weights + attn_mask
    attn_weights = attn_weights.softmax(dim=-1).type(value.dtype)
    return torch.nn.functional.dropout(attn_weights, dropout_p).matmul(value)
def _sfdp_replacement_19(query, key, value, causal_mask, attn_mask, dropout_p):
    """Fused replacement for pattern 19: folds the causal mask into the additive mask.

    Where ``causal_mask`` is True the additive ``attn_mask`` value is kept, and
    elsewhere ``-inf`` is used. SDPA runs with ``scale = 1 / sqrt(value.size(-1))``.
    """
    counters["inductor"]["fuse_attention"] += 1
    neg_inf = torch.full((), -float("inf"), dtype=query.dtype, device=query.device)
    combined_mask = torch.where(causal_mask, attn_mask, neg_inf)
    softmax_scale = 1.0 / math.sqrt(value.size(-1))
    return _scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=combined_mask,
        dropout_p=dropout_p,
        is_causal=False,
        scale=softmax_scale,
    )
def _sfdp_params_check(match):
    """Return True if the tensors captured by ``match`` are suitable for SDPA.

    Query, key and value (read from ``match.kwargs[...].meta["val"]``) must
    share one dtype and one device.

    If the matched nodes include an ``aten.add.Tensor``, its second argument is
    treated as the additive attention mask. That mask must be a tensor on the
    query's device, must not be 0-d, and must have dtype equal to the query
    dtype, ``torch.bool`` or ``torch.float``.
    """
    assert all(k in match.kwargs for k in ("query", "key", "value"))
    q_val, k_val, v_val = (
        match.kwargs[arg].meta["val"] for arg in ("query", "key", "value")
    )
    if not (q_val.dtype == k_val.dtype == v_val.dtype):
        return False
    if not (q_val.device == k_val.device == v_val.device):
        return False

    add_nodes = filter_nodes(match.nodes, aten.add.Tensor)
    if not add_nodes:
        # No attn_mask add in this match; nothing further to validate.
        return True

    mask_node = add_nodes[0].args[1]
    # The added operand may be a plain float/int rather than an FX node.
    if not hasattr(mask_node, "meta"):
        return False
    mask = mask_node.meta["val"]  # type: ignore[union-attr]

    if not isinstance(mask, torch.Tensor):
        return False
    # torch.float masks are also accepted, e.g. for models like albert.
    if mask.dtype not in (q_val.dtype, torch.bool, torch.float):
        return False
    if mask.device != q_val.device:
        return False
    # Tensorified floats show up as 0-d scalar tensors. A 0-d attention mask
    # makes no sense, so rejecting it also makes tests that wrongly pass a
    # float mask fail as expected.
    return mask.dim() != 0
def _sfdp_extra_check(scale_factor_op=None, disable_cuda=False):
    """Build an ``extra_check`` callback for an SDPA pattern registration.

    Args:
        scale_factor_op: if not None, the first matched node whose target is
            this op is located. Its ``args[1]`` (the scale factor in all current
            patterns) must be a Python ``float`` or ``int``; anything else
            rejects the match.
        disable_cuda: if True, reject matches whose ``query`` value has a
            device whose string form contains ``"cuda"``.

    The returned callback ends by delegating to ``_sfdp_params_check``.
    """

    def fn(match):
        if disable_cuda and "query" in match.kwargs:
            query_device = match.kwargs["query"].meta["val"].device
            if "cuda" in str(query_device):
                return False

        if scale_factor_op is not None:
            # Note: args[1] of the scale_factor_node is always the scale_factor for the current patterns.
            scale_node = filter_nodes(match.nodes, scale_factor_op)[0]
            # Only plain Python numbers are accepted (not tensors; presumably not SymInts either).
            if not isinstance(scale_node.args[1], (float, int)):
                return False

        return _sfdp_params_check(match)

    return fn
def partialize_and_update_signature(func, **kwargs):
    """
    Equivalent to functools.partial but also updates the signature on returned function

    The keyword arguments in ``kwargs`` are bound permanently. They are removed
    from the wrapper's ``__signature__``, so ``inspect.signature`` reports only
    the parameters that are still free. The wrapper also copies ``__name__``,
    ``__qualname__`` and ``__doc__`` from ``func``.

    Args:
        func: the callable to partially apply.
        **kwargs: parameter values to bind by keyword.

    Returns:
        A wrapper that calls ``func`` with ``kwargs`` merged into each call.

    Raises:
        TypeError: if a key in ``kwargs`` is not a parameter of ``func`` and
            ``func`` does not accept ``**kwargs``. Without this check the
            mistake would only surface when the wrapper is first called.
    """
    original_sig = inspect.signature(func)
    parameters = original_sig.parameters

    accepts_var_keyword = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters.values()
    )
    unknown = [name for name in kwargs if name not in parameters]
    if unknown and not accepts_var_keyword:
        raise TypeError(
            f"{getattr(func, '__name__', func)!s} got unexpected keyword "
            f"argument(s) to bind: {', '.join(unknown)}"
        )

    new_parameters = {
        key: value for key, value in parameters.items() if key not in kwargs
    }
    new_sig = inspect.Signature(parameters=list(new_parameters.values()))
    partial_func = functools.partial(func, **kwargs)

    # `call_kwargs` (not `kwargs`) to avoid shadowing the bound arguments above.
    def wrapper(*args, **call_kwargs):
        return partial_func(*args, **call_kwargs)

    wrapper.__signature__ = new_sig  # type: ignore[attr-defined]
    wrapper.__name__ = func.__name__
    wrapper.__qualname__ = getattr(func, "__qualname__", func.__name__)
    wrapper.__doc__ = getattr(func, "__doc__", None)
    return wrapper
def _get_sfdp_patterns():
    """Yield ``(name, register_kwargs)`` pairs for every SDPA fusion replacement.

    For each dtype in ``(torch.float, torch.half)`` and each candidate
    ``(pattern, replacement, example_inputs, scalar_workaround, extra_check)``
    two entries are yielded:

    * ``<name>_training``: traced with ``joint_fwd_bwd``, with the dropout
      scalar workaround (if any) still in place;
    * ``<name>_inference``: traced with ``fwd_only`` after ``dropout_p`` has
      been bound to ``0.0`` on both the pattern and the replacement.

    ``<name>`` is ``pattern.__name__`` with suffixes appended:

    * ``_half`` for fp16 inputs;
    * ``_mask_fp32`` for fp16 q/k/v paired with an fp32 mask (patterns listed
      in ``mask_fp32_patterns`` only);
    * ``_bs1`` for batch-size-1 example inputs.

    The dicts hold the keyword arguments that ``_sfdp_init`` forwards to
    ``gen_register_replacement``.
    """
    from .joint_graph import patterns

    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"

    # sizes/values don't actually matter for initial trace
    # once we get a possible match we re-trace with the actual values and verify the match still holds
    g_inp = functools.partial(
        torch.empty, (2, 4, 8, 16), device=device, requires_grad=True
    )
    # attn_mask
    b_inp = functools.partial(torch.empty, (1, 1, 8, 8), device=device)
    m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device)
    # inv_scale
    c_inp = functools.partial(torch.tensor, 2.0, device=device)
    # workaround https://github.com/pytorch/pytorch/issues/97894
    # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
    d = {"dropout_p": 0.113377}

    # we could also generate all these patterns in 3d.. TODO
    g_3d_inp = functools.partial(
        torch.empty, (1024, 128, 128), device=device, requires_grad=True
    )

    # reshape in matmul decomposition generates a clone when batch_size>1 due to the memory layout change.
    # however when batch_size=1, reshape does not change the memory layout, so clone would not be generated.
    # here we need to trace with input of batch_size=1 to generate a pattern graph without clone.
    g_bs1_inp = functools.partial(
        torch.empty, (1, 4, 8, 16), device=device, requires_grad=True
    )
    m_bs1_inp = functools.partial(torch.empty, (1, 1, 1, 4), device=device)

    # Patterns that are also registered with fp16 q/k/v plus an fp32 mask (see
    # the `dtype == torch.half` branch below). Loop-invariant, so built once.
    mask_fp32_patterns = ("pattern_16",)

    # softmax will generate a dtype conversion on inputs if they are in half,
    # but will not in float, so we generate a pattern for both
    for dtype in [torch.float, torch.half]:
        g = functools.partial(g_inp, dtype=dtype)
        b = functools.partial(b_inp, dtype=dtype)
        b_float = functools.partial(b_inp, dtype=torch.float)
        b_bool = functools.partial(b_inp, dtype=torch.bool)
        m = functools.partial(m_inp, dtype=dtype)
        m_float = functools.partial(m_inp, dtype=torch.float)
        m_bool = functools.partial(m_inp, dtype=torch.bool)
        c = functools.partial(c_inp, dtype=dtype)
        g_3d = functools.partial(g_3d_inp, dtype=dtype)
        g_bs1 = functools.partial(g_bs1_inp, dtype=dtype)
        m_bs1 = functools.partial(m_bs1_inp, dtype=dtype)
        m_bs1_float = functools.partial(m_bs1_inp, dtype=torch.float)
        m_bs1_bool = functools.partial(m_bs1_inp, dtype=torch.bool)

        candidates = [
            (
                _sfdp_pattern_1,
                _sfdp_replacement_1,
                [g(), g(), g(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_2,
                _sfdp_replacement_2,
                [g(), g(), g(), c()],
                {},
                _sfdp_extra_check(aten.mul.Tensor),
            ),
            (
                _sfdp_pattern_3,
                _sfdp_replacement_3,
                [g(), g(), g(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_4,
                _sfdp_replacement_4,
                [g(), g(), g(), c()],
                d,
                _sfdp_extra_check(aten.mul.Tensor),
            ),
            (
                _sfdp_pattern_5,
                _sfdp_replacement_5,
                [g(), g(), g(), b()],
                {},
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_6,
                _sfdp_replacement_6,
                [g(), g(), g(), b()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_7,
                _sfdp_replacement_7,
                [g(), g(), g()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_8,
                _sfdp_replacement_8,
                [g(), g(), g()],
                {},
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_9,
                _sfdp_replacement_9,
                [g(), g(), g()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_10,
                _sfdp_replacement_10,
                [g(), g(), g()],
                {},
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_11,
                _sfdp_replacement_11,
                [g(), g(), g(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_12,
                _sfdp_replacement_12,
                [g(), g(), g(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_13,
                _sfdp_replacement_13,
                [g_3d(), g_3d(), g_3d()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_14,
                _sfdp_replacement_14,
                [g(), g(), g(), m(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_15,
                _sfdp_replacement_15,
                [g(), g(), g(), m(), c()],
                {},
                _sfdp_extra_check(aten.div.Tensor),
            ),
            # TODO: Enable CUDA after solving Bert accuracy issue of calling efficient attention
            (
                _sfdp_pattern_16,
                _sfdp_replacement_16,
                [g(), g(), g(), m(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
            ),
            (
                _sfdp_pattern_16,
                _sfdp_replacement_16,
                [g_bs1(), g_bs1(), g_bs1(), m_bs1(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
            ),
            (
                _sfdp_pattern_17,
                _sfdp_replacement_17,
                [g(), g(), g(), m(), c()],
                d,
                _sfdp_extra_check(aten.div.Tensor),
            ),
            (
                _sfdp_pattern_18,
                _sfdp_replacement_18,
                [g(), g(), g(), m_bool()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_18,
                _sfdp_replacement_18,
                [g_bs1(), g_bs1(), g_bs1(), m_bs1_bool()],
                d,
                _sfdp_params_check,
            ),
            (
                _sfdp_pattern_19,
                _sfdp_replacement_19,
                [g(), g(), g(), b_bool(), b_float()],
                d,
                _sfdp_params_check,
            ),
        ]
        if dtype == torch.half:
            # Add inputs of fp16 (torch.half) q/k/v and fp32 mask, for models like albert.
            candidates.append(
                (
                    _sfdp_pattern_16,
                    _sfdp_replacement_16,
                    [g(), g(), g(), m_float(), c()],
                    d,
                    _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
                )
            )
            candidates.append(
                (
                    _sfdp_pattern_16,
                    _sfdp_replacement_16,
                    [g_bs1(), g_bs1(), g_bs1(), m_bs1_float(), c()],
                    d,
                    _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
                )
            )

        for pattern, replacement, args, workaround, extra_check in candidates:
            # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
            # gets serialized to a python file and does not require tracing at runtime.
            assert isinstance(workaround, dict)
            name = pattern.__name__

            if dtype != torch.float:
                name += "_half"
                # Only in the half loop can the mask dtype (fp32) differ from
                # q/k/v. Tag those entries so they don't share a name with the
                # fp16-mask entries of the same pattern.
                if (
                    any(p in name for p in mask_fp32_patterns)
                    and args[3].dtype == torch.float32
                ):
                    name += "_mask_fp32"
            if args[0].size(0) == 1:
                name += "_bs1"

            training_name = name + "_training"
            yield training_name, {
                "search_fn": pattern,
                "replace_fn": replacement,
                "example_inputs": args,
                "trace_fn": joint_fwd_bwd,
                "pass_dicts": patterns,
                "extra_check": extra_check,
                "scalar_workaround": workaround,
            }

            if workaround:
                assert len(workaround) == 1 and "dropout_p" in workaround
                # functools.partial insufficient because we look at signature downstream
                pattern = partialize_and_update_signature(pattern, dropout_p=0.0)
                replacement = partialize_and_update_signature(
                    replacement, dropout_p=0.0
                )
                workaround = {}

            inference_name = name + "_inference"
            yield inference_name, {
                "search_fn": pattern,
                "replace_fn": replacement,
                "example_inputs": args,
                "trace_fn": fwd_only,
                "pass_dicts": patterns,
                "extra_check": extra_check,
                "scalar_workaround": workaround,
                # with dropout turned into clone, we end up with a number of
                # semantically identical graphs
                "skip_duplicates": True,
            }
@functools.lru_cache(maxsize=None)
def _sfdp_init():
    """Register every SDPA fusion replacement yielded by ``_get_sfdp_patterns``.

    Because of the ``lru_cache`` on this zero-argument function, only the first
    call that returns normally runs the loop. Later calls return the cached
    ``None``.
    """
    pattern_entries = _get_sfdp_patterns()
    for pattern_name, register_kwargs in pattern_entries:
        gen_register_replacement(pattern_name, **register_kwargs)
|