# mypy: allow-untyped-defs
import copy
import enum
import logging
import re
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union
from ... import ir
from ...config import cuda as inductor_cuda_config
from ...ir import (
Buffer,
ChoiceCaller,
CUDATemplateBuffer,
FixedLayout,
IRNode,
Layout,
ReinterpretView,
)
from ...utils import is_dynamic
from ...virtualized import V
from ..common import IndentedBuffer
from . import cutlass_utils
from .cuda_kernel import CUDATemplateKernel
from .cuda_template import CUTLASSTemplate
log = logging.getLogger(__name__)
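# Illustrative sketch (not used by this module): the kernel entry points
# rendered from the GEMM templates in this file follow a two-phase calling
# convention. If a workspace_size pointer is passed, the entry point only
# reports the required workspace size and returns 0; otherwise it runs the GEMM
# with the given workspace and returns 0 on success or -1 on error. The names
# `_sketch_gemm_entry` and `_sketch_run_gemm`, the 1024-byte requirement and the
# workspace length check are assumptions made for this example only.
from typing import List, Optional


def _sketch_gemm_entry(
    workspace_size: Optional[List[int]], workspace: Optional[bytearray]
) -> int:
    required = 1024  # hypothetical workspace requirement in bytes
    if workspace_size is not None:
        # Phase 1: report the workspace size; no data pointers are needed.
        workspace_size[0] = required
        return 0
    # Phase 2: run the kernel; here we only validate the workspace.
    if workspace is None or len(workspace) < required:
        return -1
    return 0


def _sketch_run_gemm() -> int:
    size = [0]
    status = _sketch_gemm_entry(size, None)  # query the size first
    if status != 0:
        return status
    workspace = bytearray(size[0])  # allocate exactly what was requested
    return _sketch_gemm_entry(None, workspace)  # then run the kernel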
# Jinja template for GEMM Kernel, used by the CUTLASS3xGemmTemplate class below.
GEMM_TEMPLATE_CUTLASS_3X = r"""
{{template.header().getvalue()}}
{{template.globals().getvalue()}}
{{instance_definition}}
// When workspace_size is not a nullptr, populates requested workspace_size and returns.
// Otherwise, computes the Gemm kernel using the given workspace ptr.
extern "C" {
PT_EXPORT {{kernel_call_signature}} {
try {
int64_t B = {{kernel.size(Y, 0, -3, default_value=1)}};
using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator;
using coord_t = cutlass::gemm::GemmCoord::Index;
static cutlass::KernelHardwareInfo hw_info;
if (hw_info.sm_count == 0) {
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
}
{{instance_type}}::Arguments arguments;
{{template.render_gemm_arguments(argument_template, epilogue_template, should_swap_xw,
X, W, Bias, Y, alpha, beta, kernel, epilogue_args)}}
{{instance_type}} gemm_op;
if (workspace_size) {
*workspace_size = gemm_op.get_workspace_size(arguments);
return 0;
}
// check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers
#ifndef CUTLASS_BACKEND_DISABLE_CHECKS
{
auto status = gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
}
#endif
#ifdef CUTLASS_DEBUG_TRACE_LEVEL
#if CUTLASS_DEBUG_TRACE_LEVEL == 1
{
// If CUTLASS_DEBUG_TRACE_LEVEL == 1, report the maximum number of active blocks per SM for the kernel.
// No explicit print statement is needed here; the printing happens inside the function.
gemm_op.maximum_active_blocks();
}
#endif
#endif
{
auto status = gemm_op.initialize(arguments, workspace, stream);
CUTLASS_CHECK(status);
}
{
auto status = gemm_op(stream);
CUTLASS_CHECK(status);
}
}
catch (std::exception& e) {
std::cerr << "Runtime error: " << e.what() << std::endl;
return -1;
}
catch (...) {
return -1;
}
return 0;
}
}
"""
# Jinja template for Cutlass 3.x GEMM Kernel arguments, used by the CUTLASSGemmTemplate class below.
GEMM_ARGS_CUTLASS_3X = r"""
// Initialize GemmUniversal3xInstance arguments.
arguments = {
{{template.gemm_mode()}}, // GemmUniversalMode mode
{
static_cast<coord_t>({{M}}),
static_cast<coord_t>({{N}}),
static_cast<coord_t>(K),
static_cast<coord_t>(B)
}, // ProblemShape problem_shape
{
{{template.cutlass_type_cast(X, kernel.ptr(X))}}, // ElementA const* ptr_A
{
{{template.cute_int(kernel.stride(X, -2), "stride_x0")}},
{{template.cute_int(kernel.stride(X, -1), "stride_x1")}},
{{template.cute_int(kernel.stride(X, -3), "batch_stride_x")}}
}, // StrideA dA
{{template.cutlass_type_cast(W, kernel.ptr(W))}}, // ElementB const* ptr_B
{
{{template.cute_int(kernel.stride(W, -1), "stride_w1")}},
{{template.cute_int(kernel.stride(W, -2), "stride_w0")}},
{{template.cute_int(kernel.stride(W, -3), "batch_stride_w")}}
}, // StrideB dB
}, // MainloopArguments mainloop
{{epilogue_arguments}},
hw_info
};
"""
# Jinja template for Cutlass 3.x GEMM Kernel arguments if epilogue fusion is applied,
# used by the CUTLASSGemmTemplate class below.
GEMM_ARGS_CUTLASS_3X_EPILOGUE = r"""
// see https://tinyurl.com/4rk89z48
{
{{epilogue_args}}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT )
{{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // ElementC const* ptr_C
{
{{template.cute_int(kernel.stride(Bias, -2, 1), "stride_bias0")}},
{{template.cute_int(kernel.stride(Bias, -1, 1), "stride_bias1")}},
{{template.cute_int(kernel.stride(Bias, -3), "batch_stride_bias")}}
}, // StrideC dC
{{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // ElementD const* ptr_D
{
{{template.cute_int(kernel.stride(Y, -2), "stride_y0")}},
{{template.cute_int(kernel.stride(Y, -1), "stride_y1")}},
{{template.cute_int(kernel.stride(Y, -3), "batch_stride_y")}}
}, // StrideD dD
}, // EpilogueArguments epilogue
"""
# Jinja template for GEMM Kernel, used by the CUTLASS2xGemmTemplate class below.
GEMM_TEMPLATE_CUTLASS_2X = r"""
{{template.header().getvalue()}}
{{template.globals().getvalue()}}
{{instance_definition}}
// When workspace_size is not a nullptr, populates requested workspace_size and returns.
// Otherwise, computes the Gemm kernel using the given workspace ptr.
extern "C" {
PT_EXPORT {{kernel_call_signature}} {
try {
int64_t B = {{kernel.size(Y, 0, -3, default_value=1)}};
int64_t M = {{kernel.size(X, -2)}};
int64_t K = {{kernel.size(W, -2)}};
int64_t N = {{kernel.size(W, -1)}};
using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator;
using coord_t = cutlass::gemm::GemmCoord::Index;
static cutlass::KernelHardwareInfo hw_info;
if (hw_info.sm_count == 0) {
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
}
{{instance_type}}::Arguments arguments;
{{template.render_gemm_arguments(instance_type, argument_template, epilogue_template, should_swap_xw,
X, W, Bias, Meta, Y, alpha, beta, kernel, epilogue_args)}}
{{instance_type}} gemm_op;
if (workspace_size) {
*workspace_size = gemm_op.get_workspace_size(arguments);
return 0;
}
// check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers
#ifndef CUTLASS_BACKEND_DISABLE_CHECKS
{{kernel.check_not_null(X)}}
{{kernel.check_not_null(W)}}
{{kernel.check_not_null(Bias)}}
{{kernel.check_not_null(Meta)}}
{{kernel.check_not_null(Y)}}
{
auto status = gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
}
#endif
#ifdef CUTLASS_DEBUG_TRACE_LEVEL
#if CUTLASS_DEBUG_TRACE_LEVEL == 1
{
// If CUTLASS_DEBUG_TRACE_LEVEL == 1, report the maximum number of active blocks per SM for the kernel.
// No explicit print statement is needed here; the printing happens inside the function.
gemm_op.maximum_active_blocks();
}
#endif
#endif
{
auto status = gemm_op.initialize(arguments, workspace, stream);
CUTLASS_CHECK(status);
}
{
auto status = gemm_op(stream);
CUTLASS_CHECK(status);
}
}
catch (std::exception& e) {
std::cerr << "Runtime error: " << e.what() << std::endl;
return -1;
}
catch (...) {
return -1;
}
return 0;
}
}
"""
# Jinja template for Cutlass 2.x GEMM Kernel arguments, used by the CUTLASS2xGemmTemplate class below.
GEMM_ARGS_CUTLASS_2X = r"""
int64_t batch_stride_x = {{kernel.stride(X, -3)}};
int64_t row_stride_x = {{kernel.row_or_column_stride(X)}};
int64_t batch_stride_w = {{kernel.stride(W, -3)}};
int64_t row_stride_w = {{kernel.row_or_column_stride(W)}};
int64_t batch_stride_bias = {{kernel.stride(Bias, -3)}};
int64_t row_stride_bias = {{kernel.row_or_column_stride(Bias)}};
int64_t batch_stride_y = {{kernel.stride(Y, -3)}};
int64_t row_stride_y = {{kernel.row_or_column_stride(Y)}};
// Initialize GemmUniversalInstance arguments.
arguments = {
{{template.gemm_mode()}}, // GemmUniversalMode mode
{
static_cast<coord_t>(M),
static_cast<coord_t>(N),
static_cast<coord_t>(K)
}, // GemmCoord problem_size
{{split_k if split_k > 1 else 'B'}}, // int batch_count
{ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename EpilogueOutputOp::Params epilogue
{{template.cutlass_type_cast(X, kernel.ptr(X))}}, // void const * ptr_A
{{template.cutlass_type_cast(W, kernel.ptr(W))}}, // void const * ptr_B
{{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // void const * ptr_C
{{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // void * ptr_D
batch_stride_x, // int64_t batch_stride_A
batch_stride_w, // int64_t batch_stride_B
batch_stride_bias, // int64_t batch_stride_C
batch_stride_y, // int64_t batch_stride_D
row_stride_x, // typename LayoutA::Stride::LongIndex lda
row_stride_w, // typename LayoutB::Stride::LongIndex ldb
row_stride_bias, // typename LayoutC::Stride::LongIndex ldc
row_stride_y, // typename LayoutC::Stride::LongIndex ldd
};
"""
GEMM_ARGS_SPARSE_CUTLASS_2X = r"""
using TensorRefA = cutlass::TensorRef<{{instance_type}}::ElementA,
{{instance_type}}::LayoutA>;
using TensorRefB = cutlass::TensorRef<{{instance_type}}::ElementB,
{{instance_type}}::LayoutB>;
using TensorRefC = cutlass::TensorRef<{{instance_type}}::ElementC,
{{instance_type}}::LayoutC>;
using TensorRefE = cutlass::TensorRef<{{instance_type}}::ElementE,
{{instance_type}}::LayoutE>;
// Note that the names "X" and "W" may be misleading here. For sparse
// GEMM, the first operand is always the sparse one, whereas in
// applications it is typically the weight matrix (suggested by the
// name "W") that is sparse. So here "X" refers to the first operand,
// which is sparse, and "W" to the second operand, which is dense.
TensorRefA X_ref({{template.cutlass_type_cast(X, kernel.ptr(X))}}, {{kernel.row_or_column_stride(X)}});
TensorRefB W_ref({{template.cutlass_type_cast(W, kernel.ptr(W))}}, {{kernel.row_or_column_stride(W)}});
TensorRefC Y_ref({{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, {{kernel.row_or_column_stride(Y)}});
TensorRefE Meta_ref({{template.cutlass_sparse_meta_type_cast(Meta, kernel.ptr(Meta))}},
TensorRefE::Layout::packed({ {{kernel.size(Meta, 0)}}, {{kernel.size(Meta, 1)}} }));
// Initialize GemmSparse arguments.
arguments = {
{
static_cast<coord_t>({{M}}),
static_cast<coord_t>({{N}}),
static_cast<coord_t>(K),
}, // GemmCoord problem_size
X_ref, // TensorRef<ElementA const, LayoutA> ref_A
W_ref, // TensorRef<ElementB const, LayoutB> ref_B
Y_ref, // TensorRef<ElementC const, LayoutC> ref_C
Y_ref, // TensorRef<ElementC, LayoutC> ref_D
Meta_ref, // TensorRef<ElementE const, LayoutE> ref_E
{ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename EpilogueOutputOp::Params epilogue,
};
"""
# Additional includes which are necessary if the standalone test / debug runner is generated as well
GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES = r"""
#ifdef GENERATE_STANDALONE_RUNNER
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include <iostream>
#endif
"""
# Jinja template for the standalone runner that may be generated as part of the code.
GEMM_STANDALONE_RUNNER_TEMPLATE = r"""
#ifdef GENERATE_STANDALONE_RUNNER
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
cutlass::DeviceAllocation<Element>& block,
uint64_t seed, float max=1.0, float min=-1.0) {
if (block.size()<=0) return false;
Element scope_max(static_cast<Element>(max)), scope_min(static_cast<Element>(min));
cutlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, scope_max, scope_min, 0);
return true;
}
extern "C" int run_standalone(uint64_t seed, int repetitions) {
std::cout << "Starting GEMM Standalone test run with seed " << seed << std::endl;
size_t workspace_size = 0;
size_t* workspace_size_ptr = &workspace_size;
using ElementA = {{kernel.cutlass_dtype(X)}};
using ElementB = {{kernel.cutlass_dtype(W)}};
using ElementC = {{kernel.cutlass_dtype(Bias, default_dtype='uint8_t')}}; // may not be void
using ElementD = {{kernel.cutlass_dtype(Y)}};
cutlass::DeviceAllocation<ElementA> X_data({{kernel.max_valid_index(X)+1}});
initialize_block(X_data, seed++);
cutlass::DeviceAllocation<ElementB> W_data({{kernel.max_valid_index(W)+1}});
initialize_block(W_data, seed++);
cutlass::DeviceAllocation<ElementC> Bias_data({{kernel.max_valid_index(Bias)+1}});
initialize_block(Bias_data, seed++);
cutlass::DeviceAllocation<ElementD> Y_data({{kernel.max_valid_index(Y)+1}});
cutlass::DeviceAllocation<uint8_t> workspace_data;
// Call once with workspace_size_ptr set to get workspace size
std::cout << "Calling once to get workspace size" << std::endl;
{{test_call_statement}};
// Allocate workspace if necessary
if (workspace_size > 0) {
workspace_data.reset(workspace_size);
std::cout << "Allocated workspace size of " << workspace_size << " bytes" << std::endl;
}
std::cout << "Calling Kernel as {{test_call_statement}};" << std::endl;
workspace_size_ptr = nullptr;
for (int i=0; i<repetitions; i++) {
{{test_call_statement}};
}
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "Device synchronize failed with error "
<< cudaGetErrorString(result) << std::endl;
return result;
}
return 0;
}
int main(int argc, char** argv) {
// warmup
run_standalone(1, 2);
// repeat
return run_standalone(2, 10);
}
#endif
""" # noqa: B950
class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
"""
CUTLASS GEMM Template, which is used to generate CUTLASS GEMM kernels
including those which allow flexible fusions with epilogues.
"""
def __init__(
self,
input_nodes: List[Buffer],
layout: Layout,
alpha: float,
beta: float,
input_reorder: Optional[List[int]] = None,
) -> None:
"""
Args:
input_nodes (List[Buffer]): List of input nodes of the GEMM kernel.
layout (Layout): Layout type of the resulting output node.
alpha (float): The scaling factor for the product of the inputs in the GEMM operation.
beta (float): The scaling factor applied to the additive operand C (the bias), i.e. D = alpha * (X @ W) + beta * C.
input_reorder (Optional[List[int]]): Specifies the reordering of the input nodes. If not provided,
no reordering is performed. Defaults to None.
"""
super().__init__("cutlass_gemm", input_nodes, layout, input_reorder)
self.alpha = alpha
self.beta = beta
assert len(input_nodes) == 2 or len(input_nodes) == 3
assert self._are_inputs_layout_compatible(
[node.get_layout() for node in input_nodes]
)
@staticmethod
@abstractmethod
def add_cutlass_gemm_choices(
choices: List[ChoiceCaller],
layout: ir.Layout,
input_nodes: List[Buffer],
alpha: Union[float, int] = 1,
beta: Union[float, int] = 0,
input_reorder: Optional[List[int]] = None,
**extra_kwargs,
) -> None:
raise NotImplementedError
@staticmethod
@abstractmethod
def _get_supported_ops() -> "List[cutlass_library.gemm_operation.GemmOperation]": # type: ignore[name-defined] # noqa: F821
raise NotImplementedError
@staticmethod
@abstractmethod
def _has_tma_epilogue(
op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined] # noqa: F821
) -> bool:
raise NotImplementedError
@abstractmethod
def _get_template(self) -> str:
raise NotImplementedError
@abstractmethod
def _get_template_args(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> Tuple[str, Optional[str]]:
raise NotImplementedError
@abstractmethod
def _are_inputs_layout_compatible(self, layouts: List[Layout]) -> bool:
raise NotImplementedError
@abstractmethod
def _shape_match(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> bool:
raise NotImplementedError
@abstractmethod
def _alignment_match(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> bool:
raise NotImplementedError
@abstractmethod
def _set_bias_layout_and_alignment(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> bool:
raise NotImplementedError
@abstractmethod
def _define_gemm_instance(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> Tuple[str, str]:
raise NotImplementedError
@abstractmethod
def _get_extra_inputs_and_names(
self,
op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821
) -> Tuple[Optional[Buffer], List[Optional[Buffer]], List[str]]:
raise NotImplementedError
def _add_cutlass_gemm_choices(
self,
choices: List[ChoiceCaller],
layout: ir.Layout,
input_nodes: List[Buffer],
alpha: Union[float, int] = 1,
beta: Union[float, int] = 0,
input_reorder: Optional[List[int]] = None,
**extra_kwargs,
) -> None:
"""
Adds Cutlass GEMM configuration choices to the auto-tuning list.
This function mutates the passed list of choices by appending the choices for Cutlass GEMM configs to it.
Args:
choices (list): The list to which choices are appended.
layout (ir.Layout): The layout configuration.
input_nodes (list): The list of input nodes.
alpha (float, int): Scaling factor for the product of the inputs, defaults to 1.
beta (float, int): Scaling factor for the additive (bias) operand, defaults to 0.
input_reorder (list, optional): Order of the inputs, defaults to None.
**extra_kwargs: Additional keyword arguments.
"""
ops = self.gen_ops()
for name, op in ops:
self.maybe_append_choice(
choices,
description=name,
op=op,
)
if len(ops) == 0:
input_layouts = [node.get_layout() for node in input_nodes]
input_strides = [node.get_stride() for node in input_nodes]
output_layout = layout
warning_msg = f"No suitable Cutlass GEMM configs found, fallbacks used ( {len(ops)=}, {output_layout=}, {input_layouts=}, {input_strides=} )" # noqa: B950
log.warning(warning_msg)
log.debug(
"Added %d Cutlass gemm configs.",
len(ops),
)
def header(self) -> IndentedBuffer:
"""
Returns a buffer containing CUDA C++ code for the header section of the CUTLASS GEMM template.
This section primarily includes the necessary header files.
Returns:
IndentedBuffer: An instance of IndentedBuffer that contains the generated CUDA C++ header code.
"""
res = super().header()
res.splice(
"""
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/device/gemm_sparse.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
"""
)
if inductor_cuda_config.generate_test_runner and not is_dynamic(
*self.input_nodes, self.output_node
):
res.splice(GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES)
return res
@staticmethod
def cutlass_layout(torch_layout: ir.Layout) -> "Optional[cutlass_lib.LayoutType]": # type: ignore[name-defined] # noqa: F821
"""
Converts an ir.Layout instance into the corresponding cutlass_library.LayoutType enum value
(RowMajor or ColumnMajor), or None if no matching value is found.
Args:
torch_layout (ir.Layout): The layout that needs to be looked up.
Returns:
cutlass_lib.LayoutType: The converted layout corresponding to the `torch_layout` or None if no matching
value is found.
"""
assert cutlass_utils.try_import_cutlass()
import cutlass_library.library as cutlass_lib
if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1):
return cutlass_lib.LayoutType.RowMajor
elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-2], 1):
return cutlass_lib.LayoutType.ColumnMajor
else:
return None
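# Illustrative sketch (not used by the template): the stride rule applied by
# cutlass_layout() above, restated for plain integer strides. The helper name
# `_sketch_layout_from_strides` and the string labels are assumptions made for
# this example; the real method uses V.graph.sizevars to compare possibly
# symbolic strides and returns cutlass_lib.LayoutType values.
from typing import Optional, Sequence


def _sketch_layout_from_strides(stride: Sequence[int]) -> Optional[str]:
    # A unit stride in the last dimension means rows are contiguous.
    if stride[-1] == 1:
        return "RowMajor"
    # A unit stride in the second-to-last dimension means columns are contiguous.
    if stride[-2] == 1:
        return "ColumnMajor"
    # Neither of the last two dimensions is contiguous: no layout matches.
    return None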
@staticmethod
def flip_cutlass_layout(
cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821
) -> "cutlass_lib.LayoutType": # type: ignore[name-defined] # noqa: F821
"""Helper method: Flips a given cutlass layout (cutlass_lib.LayoutType) from RowMajor
to ColumnMajor or vice versa"""
assert cutlass_utils.try_import_cutlass()
import cutlass_library.library as cutlass_lib
if cutlass_layout == cutlass_lib.LayoutType.RowMajor:
return cutlass_lib.LayoutType.ColumnMajor
else:
return cutlass_lib.LayoutType.RowMajor
@staticmethod
def layout_match(
torch_layout: ir.Layout,
cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821
) -> bool:
"""Helper Method: Determines whether a given torch layout matches a given Cutlass layout"""
return CUTLASSGemmTemplate.cutlass_layout(torch_layout) == cutlass_layout
@staticmethod
def set_alignment(torch_layout, op_element) -> bool:
"""
Helper method to update the alignment of a given CUTLASS GEMM op operand's element.
This method modifies the alignment of the given Cutlass GEMM op operand's element to match the
layout of the corresponding ir.Buffer node.
Args:
torch_layout: The layout of the corresponding ir.Buffer node.
op_element: The Cutlass GEMM op operand's element whose alignment is to be updated.
Returns:
bool: True if the alignment was successfully updated, False otherwise.
"""
alignment = cutlass_utils.get_max_alignment(torch_layout)
cuda_arch = cutlass_utils.get_cuda_arch()
if cuda_arch and int(cuda_arch) >= 90 and alignment < op_element.alignment:
return False
else:
op_element.alignment = alignment
return True
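# Illustrative sketch (not used by the template): the decision made by
# set_alignment() above, with plain integers in place of the layout and the op
# operand. On CUDA arch >= 90 an op is rejected when the buffer's maximum
# alignment is smaller than the op expects; in every other case the op's
# alignment is overwritten with the buffer's. The helper name
# `_sketch_set_alignment` and its (accepted, alignment) return value are
# assumptions made for this example.
from typing import Optional, Tuple


def _sketch_set_alignment(
    layout_alignment: int, op_alignment: int, cuda_arch: Optional[int]
) -> Tuple[bool, int]:
    if cuda_arch is not None and cuda_arch >= 90 and layout_alignment < op_alignment:
        # Rejected: the op keeps its original alignment.
        return False, op_alignment
    # Accepted: the op adopts the alignment derived from the buffer layout.
    return True, layout_alignment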
@staticmethod
def should_swap_XW(
bias: IRNode,
) -> bool:
"""
Helper method to determine whether we should do an explicit transpose by switching the order of the
matmul operands. This might be necessary when we can't otherwise arrive at the right memory
layout for the given Bias operand.
Note: This method is a workaround for CUDA errors that occurred, seemingly non-deterministically,
in practice in some CUTLASS GEMM kernels with linear epilogues that have a bias term.
It might make sense to check with newer CUTLASS releases whether it is still worth
returning True in certain cases or whether this has become unnecessary.
"""
# If bias is row major, swap all M and N dimensions
if (
bias is not None
and len(bias.get_stride()) >= 2
and bias.get_stride()[-1] in (0, 1)
):
log.debug("GEMM Layout swapped X and W -> explicit transpose")
return True
return False
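# Illustrative sketch (not used by the template): the predicate evaluated by
# should_swap_XW() above, applied to a plain stride tuple instead of an IRNode.
# The helper name `_sketch_should_swap` is an assumption made for this example.
from typing import Optional, Sequence


def _sketch_should_swap(bias_stride: Optional[Sequence[int]]) -> bool:
    # Swap only for a bias with at least two dimensions whose last stride is
    # 0 (broadcast along the last dimension) or 1 (contiguous rows).
    return (
        bias_stride is not None
        and len(bias_stride) >= 2
        and bias_stride[-1] in (0, 1)
    )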
@staticmethod
def swap_XW(
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821
"""
Swap operands X and W (aka operands A and B) of the GEMM operation. This
requires transposing the operands, which is done by swapping the strides.
Note that we intentionally don't change the apparent external layout, just the operand layout.
"""
new_op = copy.deepcopy(op)
new_op.A.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.A.layout)
new_op.B.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.B.layout)
new_op.A, new_op.B = new_op.B, new_op.A
new_op.C.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.C.layout)
new_op.D.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.D.layout)
return new_op
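# Illustrative sketch (not used by the template): the operand swap performed by
# swap_XW() above, on a tiny stand-in for a GEMM op. It relies on the identity
# (X @ W)^T = W^T @ X^T and on the fact that reinterpreting a row-major matrix
# as column-major yields its transpose. The classes `_SketchOperand` and
# `_SketchGemmOp` and the string layout labels are assumptions made for this
# example; they are not the cutlass_library types.
import copy
from dataclasses import dataclass


@dataclass
class _SketchOperand:
    name: str
    layout: str  # "RowMajor" or "ColumnMajor"


@dataclass
class _SketchGemmOp:
    A: _SketchOperand
    B: _SketchOperand
    C: _SketchOperand
    D: _SketchOperand


def _sketch_flip(layout: str) -> str:
    return "ColumnMajor" if layout == "RowMajor" else "RowMajor"


def _sketch_swap_xw(op: _SketchGemmOp) -> _SketchGemmOp:
    new_op = copy.deepcopy(op)  # leave the caller's op untouched
    # Transpose both inputs by flipping their layouts, then exchange them.
    new_op.A.layout = _sketch_flip(new_op.A.layout)
    new_op.B.layout = _sketch_flip(new_op.B.layout)
    new_op.A, new_op.B = new_op.B, new_op.A
    # The bias and the output are transposed as well.
    new_op.C.layout = _sketch_flip(new_op.C.layout)
    new_op.D.layout = _sketch_flip(new_op.D.layout)
    return new_op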
def fix_op_layout(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
X: Buffer,
W: Buffer,
Bias: Optional[Buffer],
Y: Union[Buffer, ReinterpretView],
) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821
# This is a workaround to deal with cases where the input layouts have changed
# between autotuning and rendering. This happens if the input layouts
# are FlexibleLayout instances. In this case, we need to update the
# op's input layouts. It is a hack, because now the op
# we benchmarked is not the same as the op we render,
# but there is no simple way to fix this in the autotuner, since that would
# potentially disable other optimizations.
a_layout = X.get_layout()
b_layout = W.get_layout()
c_layout = Bias.get_layout() if Bias is not None else None
d_layout = copy.deepcopy(Y.get_layout())
match_list = [
CUTLASSGemmTemplate.layout_match(buf.get_layout(), op_layout)
for buf, op_layout in zip(
(X, W, Bias, Y),
(op.A.layout, op.B.layout, op.C.layout, op.D.layout),
)
if buf is not None
]
all_match = all(match_list)
if all_match:
return op
log.warning(
f"Cutlass GEMM Layout change: Input and/or output layouts have changed between autotuning/retuning and call to render on {self}. Applying workaround. This can lead to suboptimal performance. Match List: {match_list}" # noqa: G004, B950
)
new_op = copy.deepcopy(op)
if a_layout is not None:
new_op.A.layout = CUTLASSGemmTemplate.cutlass_layout(a_layout)
if b_layout is not None:
new_op.B.layout = CUTLASSGemmTemplate.cutlass_layout(b_layout)
if c_layout is not None:
new_op.C.layout = CUTLASSGemmTemplate.cutlass_layout(c_layout)
new_op.C.element = cutlass_utils.torch_dtype_to_cutlass_type(c_layout.dtype)
if d_layout is not None:
new_op.D.layout = CUTLASSGemmTemplate.cutlass_layout(d_layout)
return new_op
def filter_op(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821
"""
Helper method:
Determines whether a given Cutlass GEMM op definition is suitable for the current
input / output of the operation that this template is supposed to implement.
Takes memory layout, dtype and support for EVT operations into account,
and filters potentially problematic ops.
Returns None if the op is not suitable, otherwise returns the op to be used, which might
have been mutated.
"""
assert cutlass_utils.try_import_cutlass()
import cutlass_library.library as cutlass_lib
# Skip simt kernels
if (
op.tile_description.math_instruction.opcode_class
== cutlass_lib.OpcodeClass.Simt
):
return None
if op.gemm_kind not in self._get_supported_ops():
return None
X = self.input_nodes[0]
W = self.input_nodes[1]
# Filter ops according to the shape match.
if not self._shape_match(op):
return None
# Filter ops by dtypes.
accumulator_torch_dtype = cutlass_utils.get_accumulator_dtype(
[X.get_dtype(), W.get_dtype()],
)
if not (
cutlass_utils.dtype_match(X.get_dtype(), op.A.element)
and cutlass_utils.dtype_match(W.get_dtype(), op.B.element)
and cutlass_utils.dtype_match(
self.output_node.get_layout().dtype, op.C.element
)
and cutlass_utils.dtype_match(
accumulator_torch_dtype, op.accumulator_type()
)
):
return None
# Filter ops by input layouts.
if not (
self.layout_match(X.get_layout(), op.A.layout)
and self.layout_match(W.get_layout(), op.B.layout)
):
return None
# Filter ops by alignment.
if not self._alignment_match(op):
return None
# Update op.
op = copy.deepcopy(op)
# Set output layout.
op.D.layout = CUTLASSGemmTemplate.cutlass_layout(self.output_node.get_layout())
# Filter ops by alignments and set alignments.
if not (
self.set_alignment(X.get_layout(), op.A)
and self.set_alignment(W.get_layout(), op.B)
and self.set_alignment(self.output_node.get_layout(), op.D)
):
return None
# Set epilogue.
# TODO: update epilogue functor according to epilogues.
op.element_epilogue = op.accumulator_type()
if inductor_cuda_config.cutlass_op_allowlist_regex is not None:
if not re.search(
inductor_cuda_config.cutlass_op_allowlist_regex, op.configuration_name()
):
return None
if inductor_cuda_config.cutlass_op_denylist_regex is not None:
if re.search(
inductor_cuda_config.cutlass_op_denylist_regex, op.configuration_name()
):
return None
# Set bias layout and alignment.
if not self._set_bias_layout_and_alignment(op):
return None
return op
def gen_ops(self) -> "List[Tuple[str, cutlass_gemm_op.GemmOperation]]": # type: ignore[name-defined] # noqa: F821
"""
Creates a list of Cutlass GemmOperation instances that match the operation this template is designed to represent.
The matching is carried out with respect to the input and output specifications of the operation.
No function arguments.
Returns:
List[Tuple[str, cutlass_gemm_op.GemmOperation]]: A list of (cutlass_name, GemmOperation)
tuples that are compatible with the operation requirements of this template.
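
Example (illustrative sketch, not part of this class): the de-duplication by configuration name and
the final cap on the number of configs can be reproduced with plain tuples. `dedup_and_cap` and
`max_configs` are hypothetical names; `max_configs` stands in for
`inductor_cuda_config.cutlass_max_profiling_configs`.

```python
from typing import Dict, List, Optional, Tuple


def dedup_and_cap(
    candidates: List[Tuple[str, object]], max_configs: Optional[int] = None
) -> List[Tuple[str, object]]:
    res: Dict[str, object] = {}
    for name, op in candidates:
        if name not in res:  # the first op seen under a given name wins
            res[name] = op
    # Slicing with None keeps every entry, so None means no cap.
    return list(res.items())[:max_configs]


picked = dedup_and_cap([("a", 1), ("b", 2), ("a", 3), ("c", 4)], max_configs=2)
```

Here `picked` is `[("a", 1), ("b", 2)]`: the second "a" entry is dropped as a duplicate and "c" falls
outside the cap.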
"""
assert cutlass_utils.try_import_cutlass()
import cutlass_library.gemm_operation as cutlass_gemm_op
import cutlass_library.library as cutlass_lib
ops = cutlass_utils.gen_ops()[cutlass_lib.OperationKind.Gemm]
res: Dict[str, cutlass_gemm_op.GemmOperation] = {}
for op_dict in ops.values():
for op_list in op_dict.values():
for op in op_list:
assert isinstance(op, cutlass_gemm_op.GemmOperation)
filter_res = self.filter_op(op)
if (
filter_res is not None
and res.get(filter_res.configuration_name(), None) is None
):
res[filter_res.configuration_name()] = filter_res
log.debug("Got cutlass configs: total number of ops: %d", len(res))
return list(res.items())[: inductor_cuda_config.cutlass_max_profiling_configs]
def gemm_mode(self) -> str:
"""
Returns a Cutlass GEMM mode string for the current operation, dependent on whether this op implements
a batched GEMM or a simple GEMM without batch dimension.
Returns:
str: A string indicating the Cutlass GEMM mode. If the output node has more than two dimensions,
"cutlass::gemm::GemmUniversalMode::kBatched" is returned, otherwise
"cutlass::gemm::GemmUniversalMode::kGemm" is returned.
"""
sizes = self.output_node.get_size()
if len(sizes) > 2:
return "cutlass::gemm::GemmUniversalMode::kBatched"
else:
return "cutlass::gemm::GemmUniversalMode::kGemm"
def render( # type: ignore[override]
self,
kernel: CUDATemplateKernel,
op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821
template_buffer_node: Optional[CUDATemplateBuffer] = None,
**kwargs,
) -> str:
"""
The primary entry point for the code rendering process used in this template.
Renders the Cutlass based CUDA C++ code for the GEMM Kernel that this template is designed to implement,
including potentially fused epilogues.
Args:
kernel (CUDATemplateKernel): The kernel to be rendered.
op (cutlass_gemm_op.GemmOperation, optional): A GEMM operation that is required to be compatible with the
input and output definitions as well as a possible epilogue. Defaults to None.
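template_buffer_node (CUDATemplateBuffer, optional): If given, used as the output node Y in place of
self.output_node. Defaults to None.
**kwargs: Additional keyword arguments. Currently unused.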
Returns:
str: Cutlass based CUDA C++ code fragment as a string, to be used by the current
CUDATemplateKernel or autotuning code.
Note:
All inputs and their corresponding buffer addresses and names take precedence over previously
passed inputs to the template at construction time. However, they should be layout compatible.
"""
assert cutlass_utils.try_import_cutlass()
import cutlass_library.gemm_operation as cutlass_gemm_op
import cutlass_library.library as cutlass_lib
assert isinstance(
op, cutlass_gemm_op.GemmOperation
), "op argument is required and has to be an instance of GemmOperation"
assert len(self.input_nodes) >= 2 and self.output_node is not None
X, W = self.input_nodes[0], self.input_nodes[1]
if not isinstance(X.layout, FixedLayout):
raise NotImplementedError("X.layout is not fixed")
if not isinstance(W.layout, FixedLayout):
raise NotImplementedError("W.layout is not fixed")
Y = self.output_node
if template_buffer_node is not None:
Y = template_buffer_node
Bias, extra_inputs, extra_names = self._get_extra_inputs_and_names(op)
# Define Kernel call signature
# Important: This step also populates Kernel name to node mapping data structures,
# which are required further below (for example by CutlassEVTEpilogueArgumentFormatter and
# the template renderer).
inputs = [X, W, Bias, *extra_inputs]
names = ["X", "W", "Bias", *extra_names] + ["Y"]
names_str = ",".join(names)
input_reorder = self.input_reorder
kernel_call_signature = kernel.def_kernel(
inputs=inputs, outputs=[Y], names_str=names_str, input_reorder=input_reorder # type: ignore[arg-type]
)
test_call_statement = self.test_call_statement(kernel, inputs, names_str)
# The layouts might have changed between autotuning and this call if they were FlexibleLayout.
# In that case we need to adapt the op, which might lead to suboptimal performance.
op = self.fix_op_layout(op, X, W, Bias, Y)
# to make op mutable without affecting others
op = copy.deepcopy(op)
if Bias is not None:
assert Bias.get_layout().dtype == X.get_layout().dtype
# This might have been set to void during filtering, when the assumption was still that there's no C
# operand
op.C.element = op.A.element
argument_template, epilogue_template = self._get_template_args(op)
should_swap_xw: bool = False
epilogue_args = f"{{ElementComputeEpilogue({self.alpha}), ElementComputeEpilogue({self.beta})}}"
if Bias is not None and self._has_tma_epilogue(op):
if (
op.epilogue_schedule
!= cutlass_lib.EpilogueScheduleType.EpilogueTransposed
and self.should_swap_XW(Bias)
):
# TMA epilogue requires bias vector in column major to get best perf.
op = self.swap_XW(op)
should_swap_xw = True
instance_definition, instance_type = self._define_gemm_instance(op)
options = dict(
alpha=self.alpha,
beta=self.beta,
X=X,
W=W,
Y=Y,
kernel_call_signature=kernel_call_signature,
Bias=Bias,
epilogue_template=epilogue_template,
argument_template=argument_template,
should_swap_xw=should_swap_xw,
template=self,
kernel=kernel,
instance_definition=instance_definition,
instance_type=instance_type,
input_reorder=self.input_reorder,
epilogue_args=epilogue_args,
test_call_statement=test_call_statement,
)
options.update(dict(zip(extra_names, extra_inputs)))
res = self._template_from_string(self._get_template()).render(**options)
if inductor_cuda_config.generate_test_runner and not is_dynamic(X, W, Y, Bias):
test_runner_code = self._template_from_string(
GEMM_STANDALONE_RUNNER_TEMPLATE
).render(**options)
res += "\n\n" + test_runner_code
return res
def test_call_statement(
self,
kernel,
input_nodes,
names_str: str = "",
) -> str:
"""
Helper method to render the Cutlass CUDA C++ code required for calling the GEMM operation in the standalone
test runner that might also be generated along with the rest of the code, if the corresponding config is
enabled.
Returns a C++ statement that calls the GEMM operation with the correct arguments.
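
Example (illustrative sketch, not part of this class): the statement is assembled from the C++
argument types and the comma-separated names, dropping the third name when there is no bias.
`build_call`, the kernel name and the argument types below are hypothetical; in this method the
types come from `kernel.args.cpp_argdefs()`.

```python
from typing import Sequence


def build_call(
    kernel_name: str, arg_types: Sequence[str], names_str: str, has_bias: bool
) -> str:
    arg_names = [name.strip() for name in names_str.strip().split(",")]
    if not has_bias:
        del arg_names[2]  # the third name is the bias slot
    # Each argument is a cast of the matching <name>_data buffer.
    arguments = [
        f"(({arg_type}){arg_name}_data.get())"
        for arg_type, arg_name in zip(arg_types, arg_names)
    ]
    return (
        f"{kernel_name}({', '.join(arguments)}, "
        "workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);"
    )


stmt = build_call(
    "my_gemm", ["const half*", "const half*", "half*"], "X,W,Bias,Y", has_bias=False
)
```

Without a bias, the generated call passes X, W and Y followed by the workspace arguments.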
"""
_, __, arg_types = kernel.args.cpp_argdefs()
arg_names = [name.strip() for name in names_str.strip().split(",")]
if input_nodes[2] is None:
del arg_names[2]
arguments = [
f"(({arg_type}){arg_name}_data.get())"
for arg_type, arg_name in zip(arg_types, arg_names)
]
return f"{kernel.kernel_name}({', '.join(arguments)}, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);"
class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
def __init__(
self,
input_nodes: List[Buffer],
layout: Layout,
alpha: float,
beta: float,
input_reorder: Optional[List[int]] = None,
):
super().__init__(input_nodes, layout, alpha, beta, input_reorder)
@staticmethod
def add_cutlass_gemm_choices(
choices: List[ChoiceCaller],
layout: ir.Layout,
input_nodes: List[Buffer],
alpha: Union[float, int] = 1,
beta: Union[float, int] = 0,
input_reorder: Optional[List[int]] = None,
**extra_kwargs,
) -> None:
template = CUTLASS3xGemmTemplate(
input_nodes, layout, alpha, beta, input_reorder
)
template._add_cutlass_gemm_choices(
choices, layout, input_nodes, alpha, beta, input_reorder, **extra_kwargs
)
@staticmethod
def _get_supported_ops() -> "List[cutlass_library.library.GemmKind]": # type: ignore[name-defined] # noqa: F821
import cutlass_library.library as cutlass_lib
return [cutlass_lib.GemmKind.Universal3x]
def _get_template(self) -> str:
return GEMM_TEMPLATE_CUTLASS_3X
def _get_template_args(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> Tuple[str, Optional[str]]:
return (GEMM_ARGS_CUTLASS_3X, GEMM_ARGS_CUTLASS_3X_EPILOGUE)
@staticmethod
def _has_tma_epilogue( # noqa: F821 # type: ignore[arg-type,name-defined]
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined,arg-type] # noqa: F821
) -> bool: # type: ignore[name-defined]
"""Helper method: Determine whether a given Cutlass GEMM op has a TMA Epilogue"""
assert cutlass_utils.try_import_cutlass()
import cutlass_library.library as cutlass_lib
result = False
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
epilogue_schedule_str = str(op.epilogue_schedule).split(".")[-1]
result = epilogue_schedule_str.lower().startswith("tma")
return result
def _are_inputs_layout_compatible(self, layouts: List[Layout]) -> bool:
"""
Evaluates whether input layouts are compatible for General Matrix Multiply (GEMM).
This function checks compatibility of A, B, and possibly C operand layouts for
a General Matrix Multiply (GEMM) operation, expressed as 'alpha * matmul(A, B) + beta * C'.
It verifies requirements such as matching data types, minimum rank, and suitability
for broadcasting, as defined by PyTorch operations like `torch.matmul`, `torch.ops.aten.mm`,
`addmm`, `bmm`, `baddbmm`, etc.
Args:
layouts (List[Layout]): List containing 2 or 3 Layout objects representing
the input matrices A, B, and possibly C.
Returns:
bool: True if layouts are GEMM compatible, otherwise False.
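
Example (illustrative sketch, not part of this class): a simplified version of the A / B part of
this check on plain integer shapes. `ab_shapes_compatible` is a hypothetical helper; it assumes the
contraction dimension K must match exactly, as in `torch.matmul`, and it ignores dtypes, symbolic
sizes and the optional C operand.

```python
from typing import List


def ab_shapes_compatible(a_size: List[int], b_size: List[int]) -> bool:
    a, b = list(a_size), list(b_size)
    if len(a) < 1 or len(b) < 1:
        return False
    if len(a) < 2:
        a.insert(0, 1)  # a vector A is treated as a (1, K) row
    if len(b) < 2:
        b.insert(1, 1)  # a vector B is treated as a (K, 1) column
    # Left-pad the shorter shape with 1s so the batch dims line up.
    while len(a) < len(b):
        a.insert(0, 1)
    while len(b) < len(a):
        b.insert(0, 1)
    if a[-1] != b[-2]:
        return False  # the contraction dimension K must agree
    # Batch dims must be equal, or one of them must be 1.
    return all(x == y or x == 1 or y == 1 for x, y in zip(a[:-2], b[:-2]))
```

For instance, shapes (2, 4, 8) and (8, 16) are accepted because B is broadcast over the batch
dimension, while (3, 4, 8) and (2, 8, 16) are rejected because batch sizes 3 and 2 do not broadcast.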
"""
assert len(layouts) == 2 or len(layouts) == 3
# Check if A and B are compatible
A_layout, B_layout = layouts[:2]
if len(A_layout.size) < 1:
return False
if len(B_layout.size) < 1:
return False
A_size = list(V.graph.sizevars.size_hints(A_layout.size))
B_size = list(V.graph.sizevars.size_hints(B_layout.size))
if len(A_size) < 2:
A_size.insert(0, 1)
if len(B_size) < 2:
B_size.insert(1, 1)
# Are batch dims broadcastable?
while len(A_size) < len(B_size):
A_size.insert(0, 1)
while len(B_size) < len(A_size):
B_size.insert(0, 1)
K = max(A_size[-1], B_size[-2])
M = A_size[-2]
N = B_size[-1]
if K != A_size[-1] and A_size[-1] != 1:
return False
if K != B_size[-2] and B_size[-2] != 1:
return False
# check batch dim broadcastable
for i in range(len(A_size) - 2):
if A_size[i] != B_size[i] and A_size[i] != 1 and B_size[i] != 1:
return False
if len(layouts) == 3:
C_layout = layouts[2]
C_size = [int(i) for i in C_layout.size]
while len(C_size) < len(A_size):
C_size.insert(0, 1)
# check batch dims
for i in range(len(A_size) - 2):
bd = max(A_size[i], B_size[i])
if bd != C_size[i] and C_size[i] != 1:
return False
if len(C_size) > len(A_size):
# This may happen if the last elements of C are contiguous and
# their multiplied size equals the last dim size of B
if M != C_size[len(A_size) - 2] and C_size[len(A_size) - 2] != 1:
return False
remaining_size = 1
for i in range(len(A_size) - 1, len(C_size)):
remaining_size *= C_size[i]
if N != remaining_size and remaining_size != 1:
return False
return True
assert len(C_size) == len(A_size)
if M != C_size[-2] and C_size[-2] != 1:
return False
if N != C_size[-1] and C_size[-1] != 1:
return False
return True
def _shape_match(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> bool:
X, W = self.input_nodes[0], self.input_nodes[1]
return X.get_size()[1] == W.get_size()[0]
def _alignment_match(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> bool:
return True
def _set_bias_layout_and_alignment(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> bool:
import cutlass_library.library as cutlass_lib
if len(self.input_nodes) >= 3 and self.input_nodes[2] is not None:
Bias = self.input_nodes[2]
bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout())
if op.gemm_kind != cutlass_lib.GemmKind.Universal3x:
if bias_layout != op.D.layout:
# For cutlass2, bias and output layout must match
return False
else:
op.C.layout = bias_layout
if not self.set_alignment(Bias.get_layout(), op.C):
return False
else:
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
op.C.element = cutlass_lib.DataType.void
else:
op.C.layout = op.D.layout
return True
def _define_gemm_instance(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> Tuple[str, str]:
"""Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance.
Uses the emitter from cutlass_library to generate the C++ definition of the kernel and extracts the name of
the emitted struct. For Universal3x ops, a `cutlass::gemm::device::GemmUniversalAdapter` alias for that struct
is appended and returned as the operation type.
Args:
op (cutlass_library.gemm_op.GemmOperation): This is the core GEMM operation that we are defining and rendering.
Returns:
Tuple[str, str]: A tuple where the first part is a string that constitutes the defined GEMM operation in C++
code (render) and the second part is the string that specifies the operation type.
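
Example (illustrative sketch, not part of this class): extracting the struct name from emitted
C++ text and building the device-level alias. The `emitted` text is made up and much shorter than
real emitter output, and the pattern is a simplified, space-only variant of the one used in this
method.

```python
import re

# Made-up stand-in for the emitter output.
emitted = '''
using cutlass3x_example_epilogue = int;
struct cutlass3x_example_gemm :
  public cutlass3x_example_base { };
'''

# Simplified pattern: optional leading spaces, then the struct name up to the colon.
pattern = re.compile("[ ]*struct[ ](.*?)[ ]:")
# Take the last line that declares a struct, as this method does.
decl = [line for line in emitted.splitlines() if "struct " in line][-1]
match = pattern.match(decl)
if match is None:
    raise RuntimeError("no struct declaration found")
op_type = match.group(1)
alias = (
    f"using {op_type}_device_type = "
    f"cutlass::gemm::device::GemmUniversalAdapter<{op_type}>;"
)
```

Here `op_type` is "cutlass3x_example_gemm", and `alias` declares
`cutlass3x_example_gemm_device_type` as the adapter over that struct.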
"""
assert cutlass_utils.try_import_cutlass()
import cutlass_library.gemm_operation as cutlass_gemm_op
import cutlass_library.library as cutlass_lib
emitter = cutlass_gemm_op.EmitGemmUniversal3xInstance()
if not hasattr(op, "epilogue_functor") or not isinstance(
op.epilogue_functor, enum.Enum
):
op = copy.deepcopy(op)
op.epilogue_functor = cutlass_lib.EpilogueFunctor.LinearCombination
op_def = emitter.emit(op)
pattern = re.compile(r"\s*struct\s(.*?)\s:")
decl = [line for line in op_def.split("\n") if "struct " in line][-1]
match = pattern.match(decl)
if match is None:
raise RuntimeError("Invalid Gemm config: \n" + op_def)
op_type = match.groups()[0]
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
op_def += f"\n using {op_type}_device_type = cutlass::gemm::device::GemmUniversalAdapter<{op_type}>;\n"
op_type = f"{op_type}_device_type"
return op_def, op_type
def _get_extra_inputs_and_names(
self,
op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821
) -> Tuple[Optional[Buffer], List[Optional[Buffer]], List[str]]:
Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2]
inputs: List[Optional[Buffer]] = []
names: List[str] = []
return (Bias, inputs, names)
def render_gemm_arguments(
self,
argument_template: str,
epilogue_template: str,
should_swap_xw: bool,
X: IRNode,
W: IRNode,
Bias: IRNode,
Y: IRNode,
alpha: float,
beta: float,
kernel: CUDATemplateKernel,
epilogue_args,
) -> str:
"""
Render the Cutlass CUDA C++ code required for passing arguments to the GEMM operation.
Args:
argument_template (str): Template for the GEMM operation arguments.
epilogue_template (str): Template for the epilogue arguments.
should_swap_xw (bool): Determines whether X, W operands should be swapped. If True, applies an explicit
transpose operation to X and W.
X (IRNode): The X input tensor.
W (IRNode): The W input tensor.
Bias (IRNode): The bias tensor.
Y (IRNode): The output tensor.
alpha (float): Scaling factor for the product of the inputs.
beta (float): Scaling factor for the Bias (C) operand.
kernel (CUDATemplateKernel): CUDA Template kernel for the operation.
epilogue_args (any): Additional arguments for the epilogue state.
Returns:
str: A block of CUDA C++ code as a string, ready to be used as arguments for the GEMM operation.
Note: If `should_swap_xw` is True, a transpose operation will be applied to the X, W, Bias, and Y
tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped
before the function call.
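
Example (illustrative sketch, not part of this class): the transpose is expressed purely through
strides. Sizes are kept and only the last two strides are exchanged, which is what the nested
`clone_with_transposed_stride` helper does with a `FixedLayout`. `transposed_view` is a hypothetical
stand-in that works on plain lists.

```python
from typing import List, Tuple


def transposed_view(size: List[int], stride: List[int]) -> Tuple[List[int], List[int]]:
    new_stride = list(stride)
    # Exchange the strides of the last two dims; sizes are left untouched.
    new_stride[-2], new_stride[-1] = new_stride[-1], new_stride[-2]
    return list(size), new_stride


size, stride = transposed_view([4, 8], [8, 1])
```

A row-major (4, 8) buffer with strides [8, 1] comes back with strides [1, 8], so the same memory is
described with the last two strides exchanged and no data is moved.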
"""
options = dict(
alpha=alpha,
beta=beta,
X=X,
W=W,
Y=Y,
Bias=Bias,
template=self,
kernel=kernel,
M="M",
N="N",
epilogue_args=epilogue_args,
)
assert epilogue_template is not None
if should_swap_xw:
# Swap
def clone_with_transposed_stride(node: IRNode) -> IRNode:
old_layout = node.get_layout()
new_stride = list(old_layout.stride) # type: ignore[union-attr]
new_stride[-2], new_stride[-1] = new_stride[-1], new_stride[-2]
assert old_layout.device is not None
new_layout = FixedLayout(
old_layout.device,
old_layout.dtype,
list(old_layout.size), # type: ignore[union-attr]
new_stride,
old_layout.offset, # type: ignore[union-attr]
)
return Buffer(name=node.get_name(), layout=new_layout)
new_X = clone_with_transposed_stride(X)
new_W = clone_with_transposed_stride(W)
new_Bias = clone_with_transposed_stride(Bias)
new_Y = clone_with_transposed_stride(Y)
options["X"], options["W"], options["Bias"], options["Y"] = (
new_W,
new_X,
new_Bias,
new_Y,
)
options["M"], options["N"] = "N", "M"
epilogue_arguments = self._template_from_string(epilogue_template).render(
**options
)
arguments = self._template_from_string(argument_template).render(
epilogue_arguments=epilogue_arguments, **options
)
return arguments
class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate):
def __init__(
self,
input_nodes: List[Buffer],
layout: Layout,
alpha: float,
beta: float,
input_reorder: Optional[List[int]] = None,
):
super().__init__(input_nodes, layout, alpha, beta, input_reorder)
@staticmethod
def add_cutlass_gemm_choices(
choices: List[ChoiceCaller],
layout: ir.Layout,
input_nodes: List[Buffer],
alpha: Union[float, int] = 1,
beta: Union[float, int] = 0,
input_reorder: Optional[List[int]] = None,
**extra_kwargs,
) -> None:
template = CUTLASS2xGemmTemplate(
input_nodes, layout, alpha, beta, input_reorder
)
template._add_cutlass_gemm_choices(
choices, layout, input_nodes, alpha, beta, input_reorder, **extra_kwargs
)
@staticmethod
def _get_supported_ops() -> "List[cutlass_library.library.GemmKind]": # type: ignore[name-defined] # noqa: F821
import cutlass_library.library as cutlass_lib
return [cutlass_lib.GemmKind.Universal, cutlass_lib.GemmKind.Sparse]
@staticmethod
def _has_tma_epilogue(op: "cutlass_library.gemm_op.GemmOperation") -> bool: # type: ignore[name-defined] # noqa: F821
return False
def _get_template(self) -> str:
return GEMM_TEMPLATE_CUTLASS_2X
def _get_template_args(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> Tuple[str, Optional[str]]:
import cutlass_library.library as cutlass_lib
if op.gemm_kind == cutlass_lib.GemmKind.Sparse:
return (GEMM_ARGS_SPARSE_CUTLASS_2X, None)
return (GEMM_ARGS_CUTLASS_2X, None)
def _are_inputs_layout_compatible(self, layouts: List[Layout]) -> bool:
"""
Evaluates whether input layouts are compatible for the set of operations supported by this class.
Args:
layouts (List[Layout]): List containing Layout objects representing
the input matrices.
Returns:
bool: True if layouts are GEMM compatible, otherwise False.
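
Example (illustrative sketch, not part of this class): the K-dimension check on plain integer
shapes. `k_dims_compatible` is a hypothetical helper; it assumes, in line with `_shape_match`, that
for sparse GEMM the A operand holds only half of the K columns.

```python
from typing import Sequence


def k_dims_compatible(a_size: Sequence[int], b_size: Sequence[int]) -> bool:
    if len(a_size) != 2 or len(b_size) != 2:
        return False  # only plain 2D operands are handled here
    k = max(a_size[-1], b_size[-2])
    # Dense: A has K columns. Sparse: A has K / 2 columns.
    if k != a_size[-1] and k != 2 * a_size[-1]:
        return False
    return k == b_size[-2]
```

With A of shape (16, 32), a B of shape (32, 8) passes as the dense case and (64, 8) passes as the
sparse case, while (48, 8) is rejected.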
"""
assert len(layouts) == 2 or len(layouts) == 3
# Check if A and B are compatible
A_layout, B_layout = layouts[:2]
if len(A_layout.size) != 2:
return False
if len(B_layout.size) != 2:
return False
A_size = [int(i) for i in A_layout.size]
B_size = [int(i) for i in B_layout.size]
K = max(A_size[-1], B_size[-2])
M = A_size[-2]
N = B_size[-1]
# For sparse GEMM, A holds only half of the K columns (see _shape_match).
if K != A_size[-1] and K != 2 * A_size[-1]:
return False
if K != B_size[-2]:
return False
return True
def _shape_match(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> bool:
import cutlass_library.library as cutlass_lib
X, W = self.input_nodes[0], self.input_nodes[1]
if op.gemm_kind == cutlass_lib.GemmKind.Sparse:
return X.get_size()[1] * 2 == W.get_size()[0]
return X.get_size()[1] == W.get_size()[0]
def _alignment_match(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> bool:
import cutlass_library.library as cutlass_lib
if op.gemm_kind != cutlass_lib.GemmKind.Sparse:
return True
# SparseGemm in CUTLASS has a specific alignment check that, for
# small k, could make some of the choices throw a kMisalignedOperand
# CUTLASS error when run, see:
# https://github.com/NVIDIA/cutlass/blob/e01b9b5029b7caca5a43c29f7d2714d7cf1dcae8/include/cutlass/gemm/kernel/sparse_gemm.h#L198-L200 # noqa: B950
# So, let's skip these choices if that would be the case.
X = self.input_nodes[0]
return (X.get_size()[1] * 2) % op.tile_description.tile_shape[2] == 0
def _set_bias_layout_and_alignment(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> bool:
import cutlass_library.library as cutlass_lib
if op.gemm_kind == cutlass_lib.GemmKind.Sparse:
op.C.layout = op.D.layout
return True
if len(self.input_nodes) >= 3 and self.input_nodes[2] is not None:
Bias = self.input_nodes[2]
bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout())
if bias_layout != op.D.layout:
# For cutlass2, bias and output layout must match
return False
if not self.set_alignment(Bias.get_layout(), op.C):
return False
else:
op.C.layout = op.D.layout
return True
def _define_gemm_instance(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> Tuple[str, str]:
"""Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance.
Uses the emitter from cutlass_library (EmitSparseGemmInstance for sparse ops, EmitGemmInstance otherwise) to
generate the C++ definition, rewrites it to use `cutlass::gemm::device::GemmUniversal`, and extracts the name
of the declared type from its `using` declaration.
Args:
op (cutlass_library.gemm_op.GemmOperation): This is the core GEMM operation that we are defining and rendering.
Returns:
Tuple[str, str]: A tuple where the first part is a string that constitutes the defined GEMM operation in C++
code (render) and the second part is the string that specifies the operation type.
"""
assert cutlass_utils.try_import_cutlass()
import cutlass_library.gemm_operation as cutlass_gemm_op
import cutlass_library.library as cutlass_lib
if op.gemm_kind == cutlass_lib.GemmKind.Sparse:
emitter = cutlass_gemm_op.EmitSparseGemmInstance()
else:
emitter = cutlass_gemm_op.EmitGemmInstance()
op_def = emitter.emit(op)
op_def = op_def.replace(
"cutlass::gemm::device::Gemm", "cutlass::gemm::device::GemmUniversal"
)
if op.gemm_kind != cutlass_lib.GemmKind.Sparse:
op_def = op_def.replace("false,", "")
pattern = re.compile(r"\s*using\s(.*?)\s=")
decl = op_def.split("\n")[2]
match = pattern.match(decl)
if match is None:
raise RuntimeError("Invalid Gemm config: \n" + op_def)
op_type = match.groups()[0]
return op_def, op_type
def _get_extra_inputs_and_names(
self,
op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821
) -> Tuple[Optional[Buffer], List[Optional[Buffer]], List[str]]:
import cutlass_library.library as cutlass_lib
if op.gemm_kind == cutlass_lib.GemmKind.Sparse:
Bias = None
Meta = self.input_nodes[2]
else:
Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2]
Meta = None
inputs = [Meta]
names = ["Meta"]
return (Bias, inputs, names)
def render_gemm_arguments(
self,
instance_type: str,
argument_template: str,
epilogue_template: str,
should_swap_xw: bool,
X: IRNode,
W: IRNode,
Bias: IRNode,
Meta: IRNode,
Y: IRNode,
alpha: float,
beta: float,
kernel: CUDATemplateKernel,
epilogue_args,
) -> str:
"""
Render the Cutlass CUDA C++ code required for passing arguments to the GEMM operation.
Args:
instance_type (str): GEMM instance type.
argument_template (str): Template for the GEMM operation arguments.
epilogue_template (str): Template for the epilogue arguments.
should_swap_xw (bool): Whether the X, W operands should be swapped. Not used by this method.
X (IRNode): The X input tensor.
W (IRNode): The W input tensor.
Bias (IRNode): The bias tensor.
Meta (IRNode): The meta tensor.
Y (IRNode): The output tensor.
alpha (float): Scaling factor for the product of the inputs.
beta (float): Scaling factor for the Bias (C) operand.
kernel (CUDATemplateKernel): CUDA Template kernel for the operation.
epilogue_args (any): Additional arguments for the epilogue state.
Returns:
str: A block of CUDA C++ code as a string, ready to be used as arguments for the GEMM operation.
Note: If `epilogue_template` is None, only the argument template is rendered, with `split_k=1`.
Otherwise the epilogue arguments are rendered first and passed on to the argument template.
"""
options = dict(
instance_type=instance_type,
alpha=alpha,
beta=beta,
X=X,
W=W,
Y=Y,
Bias=Bias,
Meta=Meta,
template=self,
kernel=kernel,
M="M",
N="N",
epilogue_args=epilogue_args,
)
if epilogue_template is None:
arguments = self._template_from_string(argument_template).render(
split_k=1, **options
)
return arguments
epilogue_arguments = self._template_from_string(epilogue_template).render(
**options
)
arguments = self._template_from_string(argument_template).render(
epilogue_arguments=epilogue_arguments, **options
)
return arguments