# mypy: allow-untyped-defs
from __future__ import annotations
import builtins
import copy
import functools
import hashlib
import inspect
import logging
import math
import operator
import os
import os.path
import re
import sys
import threading
import time
from typing import Any, Container, Dict, List, Optional, Set, Tuple
import torch
from ..triton_bundler import TritonBundler
from .autotune_cache import AutotuneCache
from .benchmarking import benchmarker
from .coordinate_descent_tuner import CoordescTuner
from .hints import (
    _NUM_THREADS_PER_WARP,
    AutotuneHint,
    DeviceProperties,
    HeuristicType,
    ReductionHint,
    TileHint,
    TRITON_MAX_BLOCK,
    TRITON_MAX_RSPLIT,
)
from .runtime_utils import (
    ceildiv,
    conditional_product,
    create_bandwidth_info_str,
    dynamo_timed,
    get_first_attr,
    get_max_y_grid,
    get_num_bytes,
    next_power_of_2,
    triton_cache_dir,
    triton_config_to_hashable,
    triton_hash_to_path_key,
    validate_triton_config,
)

try:
    import triton
except ImportError:
    triton = None

if triton is not None:
    from triton import Config
    from triton.compiler import CompiledKernel
    from triton.runtime.autotuner import OutOfResources
    from triton.runtime.jit import KernelInterface

    from . import triton_helpers

    try:
        from triton.runtime.autotuner import PTXASError
    except ImportError:

        class PTXASError(Exception):  # type: ignore[no-redef]
            pass

    try:
        from triton.compiler.compiler import ASTSource
    except ImportError:
        ASTSource = None

    try:
        from triton.backends.compiler import GPUTarget
    except ImportError:
        GPUTarget = None
else:
    from types import ModuleType

    class OutOfResources(Exception):  # type: ignore[no-redef]
        pass

    class PTXASError(Exception):  # type: ignore[no-redef]
        pass

    Config = object
    KernelInterface = object
    ASTSource = None
    GPUTarget = None
    triton_helpers = ModuleType("triton_helpers")

try:
    autograd_profiler = torch.autograd.profiler
except AttributeError:  # Compile workers only have a mock version of torch

    class autograd_profiler:  # type: ignore[no-redef]
        _is_profiler_enabled = False


log = logging.getLogger(__name__)


def autotune_hints_to_configs(
    hints: Set[AutotuneHint],
    size_hints,
    block_size: int,
    device_props: DeviceProperties,
) -> List[Config]:
    """
    AutotuneHints can be attached to the metadata of triton kernels for providing
    suggestions about what to try for autotuning. One reason to do this is if there are
    some configs that are only useful in specific scenarios, in which case we can avoid
    wasting compile time on autotuning unless we know we are in one of those scenarios.

    Based on those hints, this function will generate a list of additional autotuning
    configs to try.
    """
    xyz_options: Tuple[Tuple[int, Optional[int], Optional[int]], ...]
    configs: List[Config] = []
    warp_size = device_props.warp_size
    # CPU target has no concept of "warp"
    if warp_size is None:
        warp_size = 32

    for hint in hints:
        if hint == AutotuneHint.ONE_ELEMENT_PER_THREAD:
            if len(size_hints) == 1:
                xyz_options = ((block_size // 4, None, None),)
            elif len(size_hints) == 2:
                xyz_options = ((block_size // 4, 1, None), (1, block_size // 4, None))
            elif len(size_hints) == 3:
                xyz_options = (
                    (block_size // 4, 1, 1),
                    (1, block_size // 4, 1),
                    (1, 1, block_size // 4),
                )
            configs.extend(
                triton_config(
                    size_hints,
                    *xyz,
                    num_elements_per_warp=(
                        device_props.warp_size if device_props.warp_size else 32
                    ),
                )
                for xyz in xyz_options
            )

    return configs


def disable_pointwise_autotuning(inductor_meta):
    # Autotuning can give different benchmarking results from run to run, and
    # therefore we disable autotuning when use_deterministic flag is on.
    if inductor_meta.get("are_deterministic_algorithms_enabled"):
        return True
    return not inductor_meta.get("autotune_pointwise", True)


def _dump_launch_params(args, kwargs, launcher, kernel_name):
    call_args = []
    call_kwargs = {}
    # Positional ints and bools are recorded by value; anything else
    # (typically a tensor) is recorded as the placeholder "T".
    for arg in args:
        if isinstance(arg, (int, bool)):
            call_args.append(str(arg))
        else:
            call_args.append("T")
    # Keyword arguments are recorded as-is.
    for k, v in kwargs.items():
        call_kwargs[k] = v
    for k, v in launcher.config.kwargs.items():
        call_kwargs[k] = v
    call_kwargs["num_warps"] = launcher.config.num_warps
    call_kwargs["num_stages"] = launcher.config.num_stages
    args_str = ""
    args_str += ", ".join(call_args)
    for k, v in call_kwargs.items():
        args_str += f", {k}={v}"

    # Append one line per launch next to the entry-point script.
    abs_path = os.path.abspath(sys.argv[0])
    with open(f"{abs_path}.launch_params", "a") as f:
        f.write(f"{kernel_name} | {args_str}\n")


class CachingAutotuner(KernelInterface):
    """
    Simplified version of Triton autotuner that has no invalidation
    key and caches the best config to disk to improve cold start times.
    Unlike the main triton Autotuner, this version can precompile all
    configs, and does not rely on the Triton JIT.
    """

    def __init__(
        self,
        fn,
        triton_meta,  # passed directly to triton
        configs,
        save_cache_hook,
        mutated_arg_names: List[str],  # see [Note: clone mutated buffers]
        optimize_mem,
        heuristic_type,
        size_hints=None,
        inductor_meta=None,  # metadata not relevant to triton
        custom_kernel=False,  # whether the kernel is inductor-generated or custom
        filename: Optional[str] = None,
        reset_to_zero_arg_names: Optional[List[str]] = None,
    ):
        super().__init__()

        assert len(configs) > 0, "Non-empty TritonConfig list required for compiling"
        # makes sure there are no pre-hooks on any of the triton configs
        for cfg in configs:
            validate_triton_config(cfg)

        self.fn = fn
        self.device_props: DeviceProperties = triton_meta["device"]
        self.triton_meta = {
            **triton_meta,
            "device": self.device_props.index,
            "device_type": self.device_props.type,
        }
        self.inductor_meta = {} if inductor_meta is None else inductor_meta
        self.save_cache_hook = save_cache_hook
        self.mutated_arg_names = mutated_arg_names
        self.reset_to_zero_arg_names = (
            [] if reset_to_zero_arg_names is None else reset_to_zero_arg_names
        )
        self.optimize_mem = optimize_mem
        self.configs = configs
        self.heuristic_type = heuristic_type
        self.custom_kernel = custom_kernel
        self.cuda_kernel_saved = False
        if log.isEnabledFor(logging.DEBUG):
            log.debug(
                "CachingAutotuner gets %d configs for %s",
                len(self.configs),
                self.fn.__name__,
            )
            for c in self.configs:
                log.debug(c)

        self.launchers = []  # type: ignore[var-annotated]
        self.lock = threading.Lock()
        if os.getenv("TRITON_CACHE_DIR") is None:
            os.environ["TRITON_CACHE_DIR"] = triton_cache_dir(
                self.triton_meta.get("device", 0)
            )
        log.debug("Triton cache dir: %s", os.environ["TRITON_CACHE_DIR"])

        self.size_hints = size_hints
        self.coordesc_tuner = CoordescTuner(
            is_mm=False,
            name=self.fn.__name__,
            size_hints=size_hints,
            inductor_meta=self.inductor_meta,
        )
        self.filename = filename

        # used for profiling
        self.kernel_hash: str = ""

        # Kernels are stored in the codecache with the filename as a hash of the code.
        # We rely on this to obtain the kernel hash
        if self.filename is not None:
            base_name = os.path.basename(self.filename)
            if ".py" in base_name:
                self.kernel_hash = os.path.splitext(base_name)[0]

        self.precompile_time_taken_ns = 0
        self.autotune_time_taken_ns = 0
        # Dumps the launch configs after autotuning.
        self.dump_launch_params = (
            os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1"
        )

        self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1"

    def precompile(self, warm_cache_only=False):
        with self.lock:
            if self.launchers:
                return
            self.launchers = []
            compiled_binaries = []
            if not self.configs:
                raise RuntimeError("No triton configs are available")
            for c in self.configs:
                try:
                    compiled_binary, launcher = self._precompile_config(
                        c, warm_cache_only
                    )
                except (OutOfResources, PTXASError) as e:
                    if len(self.configs) == 1:
                        # There are no valid Triton configs
                        raise e
                    # Skip the config if we run out of
                    # resources or into a ptxas error
                    continue
                self.launchers.append(launcher)
                compiled_binaries.append(compiled_binary)

            if len(self.launchers) == 0:
                raise RuntimeError(
                    "No valid triton configs. Report a fatal compilation error"
                )

            seen_configs = set(self.configs)

            device_prop = self.device_props
            warp_size = device_prop.warp_size
            # CPU target has no concept of "warp"
            if warp_size is None:
                warp_size = 32

            if (
                self.inductor_meta.get("dynamic_scale_rblock", True)
                and not self.inductor_meta.get("persistent_reduction")
                and self.heuristic_type == HeuristicType.REDUCTION
                and self.size_hints is not None
                # Disable for Intel as Triton is not ready to return n_regs for a compiled_binary.
                and device_prop.type in ["cuda", "hip"]
                and device_prop.major
                and (device_prop.major >= 8 or torch.version.hip)
                and device_prop.regs_per_multiprocessor is not None
            ):
                assert device_prop.regs_per_multiprocessor
                assert device_prop.max_threads_per_multi_processor
                assert device_prop.multi_processor_count
                for triton_config, compiled_binary in zip(
                    self.configs, compiled_binaries
                ):
                    assert len(self.size_hints) == 2
                    xblock = triton_config.kwargs.get("XBLOCK", 1)
                    rblock = triton_config.kwargs["RBLOCK"]
                    total_block = (self.size_hints["x"] + xblock - 1) // xblock
                    nreg = getattr(compiled_binary, "n_regs", None)
                    if nreg is None:
                        continue

                    # make sure rblock is not too small
                    if rblock <= 64:
                        continue

                    # Each SM of an A100 has 65536 32-bit registers. To maximize
                    # the theoretical occupancy, we need to run 2048 threads on each
                    # SM. So each thread should use no more than 65536 / 2048
                    # = 32 registers. In cases where occupancy matters, and each
                    # thread uses too many registers, reduce RBLOCK to reduce
                    # the register usage.
                    # For kernel https://gist.github.com/shunting314/e4cccc031fe30d378b9b23c08c238cbd
                    # from PLBartForCausalLM, latency improves from
                    # 7.795ms to 4.883ms.
                    #
                    if (
                        nreg
                        <= device_prop.regs_per_multiprocessor
                        // device_prop.max_threads_per_multi_processor
                    ):
                        continue

                    nreg_per_warp = nreg * warp_size
                    nreg_per_block = nreg_per_warp * triton_config.num_warps

                    # Previously we set max_blocks_per_sm to 'max_threads_per_multi_processor / (32 * num_warps)'
                    # The formula below is a tighter upper bound since we have the assumption that
                    #   nreg > device_prop.regs_per_multiprocessor // device_prop.max_threads_per_multi_processor
                    # due to the if condition above and:
                    #   regs_per_multiprocessor / nreg_per_block
                    #   = regs_per_multiprocessor / (nreg * 32 * num_warps)
                    #   < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps)
                    #   = max_threads_per_multi_processor / (32 * num_warps)
                    # Using a tighter upper bound can reveal more optimization opportunities.
                    max_blocks_per_sm = max(
                        device_prop.regs_per_multiprocessor // nreg_per_block, 1
                    )

                    if (
                        total_block
                        <= max_blocks_per_sm * device_prop.multi_processor_count
                    ):
                        # no need to improve occupancy
                        continue
                    new_config = copy.deepcopy(triton_config)
                    new_config.kwargs["RBLOCK"] = rblock // 2
                    if new_config in seen_configs:
                        continue
                    seen_configs.add(new_config)
                    log.debug(
                        "Dynamically scale down RBLOCK from TritonConfig(%s) and get a new TritonConfig(%s)",
                        triton_config,
                        new_config,
                    )
                    self.launchers.append(
                        self._precompile_config(new_config, warm_cache_only)[1]
                    )
            self.configs = None

    def get_device_interface(self):
        # this code cannot run in compile workers, because it imports from torch
        from torch._dynamo.device_interface import get_interface_for_device

        return get_interface_for_device(self.device_props.type.replace("hip", "cuda"))

    def _precompile_config(self, cfg: Config, warm_cache_only: bool):
        """Ahead of time compile a given autotuner config."""
        compile_meta = copy.deepcopy(self.triton_meta)
        for k, v in cfg.kwargs.items():
            if self.device_props.type == "hip":
                if k == "matrix_instr_nonkdim":
                    compile_meta["matrix_instr_nonkdim"] = v
                    continue
                if k == "waves_per_eu":
                    compile_meta["waves_per_eu"] = v
                    continue
                if k == "kpack":
                    compile_meta["kpack"] = v
                    continue
            compile_meta["constants"][k] = v
        compile_meta["num_warps"] = cfg.num_warps
        compile_meta["num_stages"] = cfg.num_stages
        compile_meta["debug"] = self.inductor_meta.get(
            "assert_indirect_indexing", True
        ) and not self.inductor_meta.get("is_hip", False)

        # device type will be "hip" rather than "cuda" here
        compile_meta["device_type"] = self.device_props.type
        compile_meta["cc"] = self.device_props.cc

        if self.device_props.type == "cpu":
            triton_helpers.set_driver_to_cpu()
        else:
            triton_helpers.set_driver_to_gpu()

        if ASTSource:
            compile_args = (
                ASTSource(
                    self.fn,
                    compile_meta["signature"],
                    compile_meta["constants"],
                    compile_meta["configs"][0],
                ),
            )

            cc_str = str(compile_meta["cc"])
            if "gfx10" in cc_str or "gfx11" in cc_str:
                rocm_warp_size = 32
            else:
                rocm_warp_size = 64

            if GPUTarget:
                target = GPUTarget(
                    compile_meta["device_type"],
                    compile_meta["cc"],
                    rocm_warp_size if torch.version.hip else 32,
                )
            else:
                target = (
                    (compile_meta["device_type"], compile_meta["cc"])
                    if not torch.version.hip
                    else [
                        compile_meta["device_type"],
                        compile_meta["cc"],
                        rocm_warp_size,
                    ]
                )

            options = {
                "num_warps": compile_meta["num_warps"],
                "num_stages": compile_meta["num_stages"],
                "debug": compile_meta["debug"],
                "sanitize_overflow": False,  # turn off additional asserts added for overflow checks
            }
            if self.device_props.type == "hip":
                if "waves_per_eu" in compile_meta:
                    options["waves_per_eu"] = compile_meta["waves_per_eu"]
                if "matrix_instr_nonkdim" in compile_meta:
                    options["matrix_instr_nonkdim"] = compile_meta[
                        "matrix_instr_nonkdim"
                    ]
            compile_kwargs = {
                "target": target,
                "options": options,
            }
        else:
            compile_args = (self.fn,)
            compile_kwargs = compile_meta

        if warm_cache_only:
            binary = triton.compile(*compile_args, **compile_kwargs)
            launcher = None
            TritonBundler.put(
                triton_hash_to_path_key(binary.hash), self.triton_meta.get("device", 0)
            )
            return binary, launcher

        # importing from torch is safe now that precompile has returned
        from torch._dynamo.device_interface import DeviceGuard

        device_interface = self.get_device_interface()

        # load binary to the correct device
        with DeviceGuard(device_interface, compile_meta["device"]):  # type: ignore[attr-defined]
            # need to initialize context
            device_interface.synchronize(device_interface.current_device())

            try:
                binary = triton.compile(*compile_args, **compile_kwargs)
            except Exception:
                log.exception(
                    "Triton compilation failed: %s\n%s\nmetadata: %s",
                    self.inductor_meta.get("kernel_name", "triton_"),
                    self.fn.src,
                    compile_meta,
                )
                raise
            binary._init_handles()

        """
        https://github.com/pytorch/pytorch/issues/115344

        self.fn.constexprs doesn't properly deal with None args, so when we filter out
        an arg in UserDefinedTritonKernel.codegen, we need to filter it here as well.
        We also don't want to modify self.fn.

        We know that we removed something from the signature if:
            1. It's in compile_meta["constants"]
            2. It isn't a constant we already know about
                Note: The value of interest has already been added to compile_meta['constants'],
                    so we use self.fn.constexprs instead.
            3. It isn't in the compile_meta signature
        """
        known_constants = {
            arg for i, arg in enumerate(self.fn.arg_names) if i in self.fn.constexprs
        }
        none_args = {
            k
            for k, v in compile_meta["constants"].items()
            if v is None and k not in known_constants
        }
        none_args = none_args.difference(set(compile_meta["signature"].keys()))

        call_args = [
            arg
            for i, arg in enumerate(self.fn.arg_names)
            if i not in self.fn.constexprs and arg not in none_args
        ]

        def_args = [
            name
            for name in self.fn.arg_names
            if name not in cfg.kwargs and name not in none_args
        ]
        binary_shared = (
            binary.shared if hasattr(binary, "shared") else binary.metadata.shared
        )

        scope = {
            "grid_meta": cfg.kwargs,
            "bin": binary,
            "launch_enter_hook": CompiledKernel.launch_enter_hook,
            "launch_exit_hook": CompiledKernel.launch_exit_hook,
            "metadata": (
                binary.packed_metadata
                if hasattr(binary, "packed_metadata")
                else binary.metadata
            ),
            "shared": binary_shared,
        }

        scope["num_warps"] = (
            binary.num_warps
            if hasattr(binary, "num_warps")
            else binary.metadata.num_warps
        )

        scope["cta_args"] = (
            (binary.num_ctas, *get_first_attr(binary, "cluster_dims", "clusterDims"))
            if hasattr(binary, "num_ctas")
            else (
                (binary.metadata.num_ctas, *binary.metadata.cluster_dims)
                if hasattr(binary, "metadata")
                else ()
            )
        )

        scope["function"] = get_first_attr(binary, "function", "cu_function")

        def get_launch_args_without_kernel_launch_metadata(
            grid,
            grid_0,
            grid_1,
            grid_2,
            stream,
            function,
            metadata,
            bin,
            launch_enter_hook,
            launch_exit_hook,
            num_warps,
            shared,
            cta_args,
            args,
        ):
            """
            Construct launch args before CompiledKernel.launch_metadata is added.
            """
            return (
                grid_0,
                grid_1,
                grid_2,
                num_warps,
                *cta_args,
                shared,
                stream,
                function,
                launch_enter_hook,
                launch_exit_hook,
                metadata,
            )

        # Getting the kernel launch args is extremely perf-sensitive.  Evaluating
        # `bin.launch_metadata` is relatively expensive, and returns None unless a
        # `launch_enter_hook` is installed.  So if we don't have that hook installed,
        # we want to burn None in to the launch args with zero overhead.
        # See https://github.com/pytorch/pytorch/issues/123597
        if binary.launch_enter_hook:

            def get_launch_args_with_kernel_launch_metadata(
                grid,
                grid_0,
                grid_1,
                grid_2,
                stream,
                function,
                metadata,
                bin,
                launch_enter_hook,
                launch_exit_hook,
                num_warps,
                shared,
                cta_args,
                args,
            ):
                """
                Construct launch args after CompiledKernel.launch_metadata is added
                by https://github.com/openai/triton/pull/3492 .
                """
                return (
                    grid_0,
                    grid_1,
                    grid_2,
                    stream,
                    function,
                    metadata,
                    bin.launch_metadata(grid, stream, *args),
                    launch_enter_hook,
                    launch_exit_hook,
                )

        else:

            def get_launch_args_with_kernel_launch_metadata(
                grid,
                grid_0,
                grid_1,
                grid_2,
                stream,
                function,
                metadata,
                bin,
                launch_enter_hook,
                launch_exit_hook,
                num_warps,
                shared,
                cta_args,
                args,
            ):
                """
                Construct launch args after CompiledKernel.launch_metadata is added
                by https://github.com/openai/triton/pull/3492 .
                """
                return (
                    grid_0,
                    grid_1,
                    grid_2,
                    stream,
                    function,
                    metadata,
                    None,
                    launch_enter_hook,
                    launch_exit_hook,
                )

        scope["get_launch_args"] = (
            get_launch_args_with_kernel_launch_metadata
            if hasattr(binary, "launch_metadata")
            else get_launch_args_without_kernel_launch_metadata
        )

        scope["runner"] = get_first_attr(binary, "run", "c_wrapper")

        exec(
            f"""
            def launcher({', '.join(def_args)}, grid, stream):
                if callable(grid):
                    grid_0, grid_1, grid_2 = grid(grid_meta)
                else:
                    grid_0, grid_1, grid_2 = grid

                args = {', '.join(call_args)},
                launch_args = get_launch_args(
                    grid, grid_0, grid_1, grid_2, stream, function,
                    metadata, bin, launch_enter_hook, launch_exit_hook,
                    num_warps, shared, cta_args, args
                )
                runner(*launch_args, *args)
                return bin
            """.lstrip(),
            scope,
        )

        launcher = scope["launcher"]
        launcher.config = cfg
        launcher.n_regs = getattr(binary, "n_regs", None)
        launcher.n_spills = getattr(binary, "n_spills", None)
        launcher.shared = binary_shared
        launcher.store_cubin = self.inductor_meta.get("store_cubin", False)
        # store this global variable to avoid the high overhead of reading it when calling run
        if launcher.store_cubin:
            launcher.fn = self.fn
            launcher.bin = binary

        TritonBundler.put(
            triton_hash_to_path_key(binary.hash), self.triton_meta.get("device", 0)
        )

        return binary, launcher

    def bench(self, launcher, *args, grid, with_profiler=False, **kwargs):
        """Measure the performance of a given launcher"""
        # we don't skip configs with spilled registers when auto-tuning custom
        # (user-written) Triton kernels, as (i) we don't have any knowledge or
        # control over the kernel code; (ii) there is empirical evidence that
        # for some (complicated) custom Triton kernels, a register-spilling
        # config may yield the best latency.
        if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
            "spill_threshold", 16
        ):
            log.debug(
                "Skip config %s because of register spilling: %d",
                launcher.config,
                launcher.n_spills,
            )
            return float("inf")

        device_interface = self.get_device_interface()
        stream = device_interface.get_raw_stream(device_interface.current_device())

        cpu_copies = self.copy_args_to_cpu_if_needed(*args, **kwargs)

        def kernel_call():
            cloned_args, cloned_kwargs = self.maybe_clone_args(
                cpu_copies, *args, **kwargs
            )
            # reset to zero before evaluating any config
            self.reset_to_zero_args(*args, **kwargs)
            launcher(
                *cloned_args,
                **cloned_kwargs,
                grid=grid,
                stream=stream,
            )
            self.restore_args_from_cpu(cpu_copies)

        if with_profiler:
            from torch._inductor.utils import do_bench_using_profiling

            return do_bench_using_profiling(kernel_call, warmup=10, rep=40)

        if self.device_props.type == "cpu":
            return benchmarker.benchmark_cpu(kernel_call)

        return benchmarker.benchmark_gpu(kernel_call, rep=40)

    def copy_args_to_cpu_if_needed(self, *args, **kwargs):
        """
        To support benchmarking in the presence of mutated args, we need to avoid
        autotuning contaminating them. We try to pass cloned args to the kernel.
        If those clones would increase the peak memory usage, however, we instead
        copy to cpu and restore them after each iteration. Figure out the args
        to be copied and do the copying.
        """
        if not self.optimize_mem:
            return {}

        copies = {}
        budget = torch.cuda.max_memory_allocated() - torch.cuda.memory_allocated()

        def maybe_copy(name, arg):
            if name in self.mutated_arg_names and arg.is_cuda:
                nonlocal budget
                assert isinstance(arg, torch.Tensor)
                size = arg.numel() * arg.element_size()
                if size > budget:
                    cpu_arg = torch.empty_strided(
                        arg.size(),
                        arg.stride(),
                        dtype=arg.dtype,
                        device="cpu",
                        pin_memory=True,
                    )
                    cpu_arg.copy_(arg, non_blocking=True)
                    copies[name] = (arg, cpu_arg)
                else:
                    budget -= size

        for i, arg in enumerate(args):
            maybe_copy(self.fn.arg_names[i], arg)

        for name, arg in kwargs.items():
            maybe_copy(name, arg)

        return copies

    def restore_args_from_cpu(self, cpu_copies):
        for pair in cpu_copies.values():
            arg, cpu_arg = pair
            arg.copy_(cpu_arg, non_blocking=True)

    def reset_to_zero_args(self, *args, **kwargs):
        if not self.reset_to_zero_arg_names:
            return
        for i, arg in enumerate(args):
            if self.fn.arg_names[i] in self.reset_to_zero_arg_names:
                assert isinstance(
                    arg,
                    torch.Tensor,
                ), "self.reset_to_zero_arg_names should only contain valid argument names"
                arg.zero_()

        for name, arg in kwargs.items():
            if name in self.reset_to_zero_arg_names:
                assert isinstance(
                    arg,
                    torch.Tensor,
                ), "self.reset_to_zero_arg_names should only contain valid argument names"
                arg.zero_()

    def maybe_clone_args(
        self, exclude: Container[str], *args, **kwargs
    ) -> Tuple[List[Any], Dict[str, Any]]:
        """
        Prepare new args and kwargs by cloning any in-place buffers
        (that are not in the provided exclusion list), to avoid autotune
        contaminating them. Avoid cloning the other buffers because it
        leads to increased memory usage.
        """
        from ..compile_fx import clone_preserve_strides

        def prepare_arg(name, arg):
            if name in self.mutated_arg_names and name not in exclude:
                assert isinstance(arg, torch.Tensor)
                return clone_preserve_strides(arg)
            else:
                return arg

        cloned_args = [
            prepare_arg(self.fn.arg_names[i], arg) for i, arg in enumerate(args)
        ]
        cloned_kwargs = {name: prepare_arg(name, arg) for name, arg in kwargs.items()}

        return cloned_args, cloned_kwargs

    def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]:
        return self.maybe_clone_args(set(), *args, **kwargs)
def benchmark_all_configs(self, *args, **kwargs):
with dynamo_timed(
"CachingAutotuner.benchmark_all_configs", log_pt2_compile_event=True
):
timings = {
launcher: self.bench(launcher, *args, **kwargs)
for launcher in self.launchers
}
for k, v in timings.items():
self.coordesc_tuner.cache_benchmark_result(k.config, v)
if log.isEnabledFor(logging.DEBUG):
log.debug("Benchmark all input configs for %s, get:", self.fn.__name__)
for k, v in timings.items():
log.debug(
"%s: %f, nreg %d, nspill %d, #shared-mem %s",
k.config,
v,
k.n_regs,
k.n_spills,
k.shared,
)
self.reset_to_zero_args(*args, **kwargs)
return timings
def autotune_to_one_config(self, *args, **kwargs):
"""Do the actual autotuning"""
start_time = time.time_ns()
timings = self.benchmark_all_configs(*args, **kwargs)
benchmark_time_taken_ns = time.time_ns() - start_time
self.launchers = [builtins.min(timings, key=timings.get)]
self.autotune_time_taken_ns = (
self.precompile_time_taken_ns + benchmark_time_taken_ns
)
if self.save_cache_hook:
self.save_cache_hook(self.launchers[0].config, self.autotune_time_taken_ns)
def save_gpu_kernel(self, grid, stream, launcher):
if callable(grid):
grid_x, grid_y, grid_z = grid(launcher.config.kwargs)
else:
grid_x, grid_y, grid_z = grid
key = self.inductor_meta.get("kernel_name", None) # unique kernel name
assert key is not None, "kernel_name can not be None"
params = {
"mangled_name": (
launcher.bin.metadata.name
if hasattr(launcher.bin.metadata, "name")
else launcher.bin.metadata["name"]
),
"grid_x": grid_x,
"grid_y": grid_y,
"grid_z": grid_z,
"x_block": launcher.config.kwargs.get("XBLOCK", 1),
"y_block": launcher.config.kwargs.get("YBLOCK", None),
"z_block": launcher.config.kwargs.get("ZBLOCK", None),
"r_block": launcher.config.kwargs.get("RBLOCK", None),
"num_warps": (
launcher.bin.num_warps
if hasattr(launcher.bin, "num_warps")
else launcher.bin.metadata.num_warps
),
"shared_mem": (
launcher.bin.shared
if hasattr(launcher.bin, "shared")
else launcher.bin.metadata.shared
),
"stream": stream,
# User defined triton kernels will have arbitrary kwarg names
"meta": launcher.config.kwargs,
}
from torch._inductor.codecache import CudaKernelParamCache
bin_type = {"hip": "hsaco", "xpu": "spv"}.get(self.device_props.type, "cubin")
binary = launcher.bin.asm[bin_type]
CudaKernelParamCache.set(key, params, binary, bin_type)
self.cuda_kernel_saved = True
def coordinate_descent_tuning(self, launcher, *args, **kwargs):
"""
Coordinate descent tuning can be run with or without max-autotune.
The only difference between the two is the starting config for coordinate descent tuning.
E.g., assume regular autotune only gets one config, C1, while max-autotune gets 4 configs C1, C2, C3, C4
and finds that C3 is the best.
Then if coordinate descent tuning is run with max-autotune disabled, it will start from C1;
while if coordinate descent tuning is run with max-autotune enabled, it will start from C3.
"""
if (
self.heuristic_type == HeuristicType.TEMPLATE
or self.heuristic_type == HeuristicType.USER_AUTOTUNE
):
# skip triton template
return launcher
config2launcher = {launcher.config: launcher}
def benchmark_one_config(config):
with self.lock:
_, launcher = self._precompile_config(config, False)
config2launcher[config] = launcher
out = self.bench(launcher, *args, **kwargs)
log.debug(
"COORDESC: %s: %f, nreg %d, nspill %d, #shared-mem %d",
launcher.config,
out,
launcher.n_regs,
launcher.n_spills,
launcher.shared,
)
return out
assert not (
self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
and "RBLOCK" in launcher.config.kwargs
), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have RBLOCK"
start_time = time.time_ns()
best_config = self.coordesc_tuner.autotune(
benchmark_one_config, launcher.config, None
)
coordesc_time_taken_ns = time.time_ns() - start_time
best_config.found_by_coordesc = True
if self.save_cache_hook:
self.save_cache_hook(
best_config,
self.autotune_time_taken_ns + coordesc_time_taken_ns,
found_by_coordesc=True,
)
return config2launcher.get(best_config)
def run(
self, *args, grid, stream, benchmark_run=False, **kwargs
): # type:ignore[override]
if self.triton_interpret:
return self.fn[grid](
*args,
**kwargs,
**self.configs[0].kwargs,
)
if len(self.launchers) != 1:
if len(self.launchers) == 0:
start_time = time.time_ns()
self.precompile()
self.precompile_time_taken_ns = time.time_ns() - start_time
if len(self.launchers) > 1:
self.autotune_to_one_config(*args, grid=grid, **kwargs)
if not getattr(
self.launchers[0].config, "found_by_coordesc", False
) and self.inductor_meta.get("coordinate_descent_tuning", False):
self.launchers = [
self.coordinate_descent_tuning(
self.launchers[0], *args, grid=grid, **kwargs
)
]
(launcher,) = self.launchers
if launcher.store_cubin and (not benchmark_run or not self.cuda_kernel_saved):
self.save_gpu_kernel(grid, stream, launcher)
if self.dump_launch_params:
_dump_launch_params(args, kwargs, launcher, self.fn.__name__)
# Checking the profiler flag directly is faster than entering and exiting a
# context manager, even if the context manager is a nullcontext.
if autograd_profiler._is_profiler_enabled:
# grid can be a tuple of ints or a string.
if isinstance(grid, tuple):
grid_info = str(grid)
else:
grid_info = getattr(grid, "grid_fn_str", "")
with torch._C._profiler._RecordFunctionFast(
self.inductor_meta.get("kernel_name", "triton kernel"),
args,
{
"kernel_file": (self.filename or ""),
"kernel_hash": self.kernel_hash,
"kernel_backend": "triton",
"grid": grid_info,
"stream": stream,
},
):
return launcher(
*args,
**kwargs,
grid=grid,
stream=stream,
)
else:
return launcher(
*args,
**kwargs,
grid=grid,
stream=stream,
)
def _find_names(obj):
import gc
import inspect
frame = inspect.currentframe()
while frame is not None:
frame.f_locals
frame = frame.f_back
obj_names = []
for referrer in gc.get_referrers(obj):
if isinstance(referrer, dict):
for k, v in referrer.items():
if v is obj:
obj_names.append(k)
return obj_names
collected_calls: List[Any] = []
def start_graph():
collected_calls.clear()
def end_graph(output_file):
if len(collected_calls) == 0:
return
overall_time = sum(call[0] for call in collected_calls)
overall_gb = sum(call[1] for call in collected_calls)
cur_file = inspect.stack()[1].filename
summary_str = (
f"SUMMARY ({cur_file})\n"
f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb / (overall_time / 1e3):.2f}GB/s"
)
log.info(
"%s",
summary_str,
)
if output_file is not None:
# sort perf numbers in descending order, i.e. placing the
# most runtime-heavy kernels at the top of the list
sorted_calls = sorted(collected_calls, key=lambda c: float(c[0]), reverse=True)
try:
with open(output_file, "a") as file:
log.info(
"Save profile bandwidth results to %s",
output_file,
)
file.write("====================\n")
file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n")
for ms, num_gb, gb_per_s, kernel_name in sorted_calls:
# also display the runtime percentage for each kernel
percentage = f"{ms / overall_time * 100:.2f}%"
suffix = f" \t {percentage} \t {kernel_name}"
bw_info_str = create_bandwidth_info_str(
ms,
num_gb,
gb_per_s,
suffix=suffix,
color=False,
)
file.write(bw_info_str + "\n")
file.write(f"{summary_str}\n\n")
except Exception as e:
log.warning(
"failed to write profile bandwidth result into %s: %s",
output_file,
e,
)
class DebugAutotuner(CachingAutotuner):
def __init__(
self,
*args,
regex_filter="",
with_profiler=False,
with_bandwidth_info=True,
**kwargs,
):
self.regex_filter = regex_filter
self.with_profiler = with_profiler
self.with_bandwidth_info = with_bandwidth_info
super().__init__(*args, **kwargs)
self.cached = None
def run(self, *args, grid, stream, **kwargs):
if not self.with_bandwidth_info:
super().run(*args, grid=grid, stream=stream, **kwargs, benchmark_run=True)
return
else:
possible_names = _find_names(self)
kernel_name = f"{max(possible_names, key=len)}"
if not re.match(self.regex_filter, kernel_name):
return
if len(self.launchers) != 1:
if len(self.launchers) == 0:
start_time = time.time_ns()
self.precompile()
self.precompile_time_taken_ns = time.time_ns() - start_time
if len(self.launchers) > 1:
self.autotune_to_one_config(*args, grid=grid, **kwargs)
(launcher,) = self.launchers
if launcher.store_cubin:
self.save_gpu_kernel(grid, stream, launcher)
if self.cached is None:
ms = self.bench(
launcher, *args, grid=grid, with_profiler=self.with_profiler
)
num_in_out_ptrs = len(
[
arg_name
for arg_name in self.fn.arg_names
if arg_name.startswith("in_out_ptr")
]
)
num_gb = self.inductor_meta.get("kernel_num_gb", None)
if num_gb is None:
num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9
gb_per_s = num_gb / (ms / 1e3)
self.cached = ms, num_gb, gb_per_s, kernel_name
collected_calls.append((ms, num_gb, gb_per_s, kernel_name))
log.info(
"%s",
create_bandwidth_info_str(
ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}"
),
)
else:
# in AOTI, we will call the kernel and its timing info has been cached already
collected_calls.append(self.cached)
def hash_configs(configs: List[Config]):
    """
    Hash used to check for changes in configurations
    """
    hasher = hashlib.sha256()
    for cfg in configs:
        hasher.update(
            f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode()
        )
    return hasher.hexdigest()
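# Illustrative sketch, not part of this module: the hashing idea used by
# hash_configs, written against a hypothetical FakeConfig stand-in for
# triton.Config so that it runs without triton installed. Sorting the kwargs
# should make the digest independent of dict insertion order, while a change
# to num_warps or num_stages changes the digest.

```python
import hashlib
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class FakeConfig:
    """Hypothetical stand-in for triton.Config (assumption for illustration)."""

    kwargs: Dict[str, Any] = field(default_factory=dict)
    num_warps: int = 4
    num_stages: int = 1


def hash_configs_sketch(configs: List[FakeConfig]) -> str:
    hasher = hashlib.sha256()
    for cfg in configs:
        # sorted() gives a canonical order for the kwargs before hashing
        hasher.update(
            f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode()
        )
    return hasher.hexdigest()


a = hash_configs_sketch([FakeConfig({"XBLOCK": 64, "RBLOCK": 8})])
b = hash_configs_sketch([FakeConfig({"RBLOCK": 8, "XBLOCK": 64})])
c = hash_configs_sketch([FakeConfig({"XBLOCK": 64, "RBLOCK": 8}, num_warps=8)])
```

# Here a and b differ only in kwarg order, while c differs in num_warps.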
def cached_autotune(
size_hints: Optional[List[int]],
configs: List[Config],
triton_meta,
heuristic_type,
filename=None,
inductor_meta=None,
custom_kernel=False,
):
"""
A copy of triton.autotune that calls our subclass. Our subclass
has additional debugging, error handling, and on-disk caching.
"""
configs = unique_configs(configs)
assert len(configs) == 1 or filename
inductor_meta = {} if inductor_meta is None else inductor_meta
disabled = inductor_meta.get("force_disable_caches", False)
# on disk caching logic and/or remote caching
autotune_cache = None
if (
not disabled
and filename is not None
and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
and not os.environ.get("TRITON_INTERPRET", "0") == "1"
):
configs_hash = hash_configs(configs)
autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
if autotune_cache:
if best_config := autotune_cache.read_best(inductor_meta, configs):
configs = [best_config]
else:
if disabled:
log.debug("autotune caching is disabled by config.force_disable_caches")
mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())
optimize_mem = inductor_meta.pop("optimize_mem", True)
if "restore_value" in triton_meta:
mutated_arg_names += triton_meta.pop("restore_value")
reset_to_zero_arg_names: List[str] = []
if "reset_to_zero" in triton_meta:
reset_to_zero_arg_names.extend(triton_meta.pop("reset_to_zero"))
def decorator(fn):
# Remove XBLOCK from config if it's not a function argument.
# This way, coordinate descent tuning will not try to tune it.
#
# Context: When TritonKernel.no_x_dim is True, we hardcode XBLOCK to 1.
import inspect
if "XBLOCK" not in inspect.signature(fn.fn).parameters:
for tconfig in configs:
if "XBLOCK" in tconfig.kwargs:
assert tconfig.kwargs["XBLOCK"] == 1
tconfig.kwargs.pop("XBLOCK")
if inductor_meta.get("profile_bandwidth"):
return DebugAutotuner(
fn,
triton_meta=triton_meta,
inductor_meta=inductor_meta,
regex_filter=inductor_meta["profile_bandwidth_regex"],
with_profiler=inductor_meta[
"profile_bandwidth_with_do_bench_using_profiling"
],
configs=configs,
save_cache_hook=autotune_cache and autotune_cache.save,
mutated_arg_names=mutated_arg_names,
reset_to_zero_arg_names=reset_to_zero_arg_names,
optimize_mem=optimize_mem,
heuristic_type=heuristic_type,
size_hints=size_hints,
custom_kernel=custom_kernel,
filename=filename,
with_bandwidth_info=True,
)
return CachingAutotuner(
fn,
triton_meta=triton_meta,
inductor_meta=inductor_meta,
configs=configs,
save_cache_hook=autotune_cache and autotune_cache.save,
mutated_arg_names=mutated_arg_names,
reset_to_zero_arg_names=reset_to_zero_arg_names,
optimize_mem=optimize_mem,
heuristic_type=heuristic_type,
size_hints=size_hints,
custom_kernel=custom_kernel,
filename=filename,
)
return decorator
def unique_configs(configs: List[Config]):
"""Remove duplicate configurations"""
seen = set()
pruned_configs = []
for cfg in configs:
key = triton_config_to_hashable(cfg)
if key not in seen:
seen.add(key)
pruned_configs.append(cfg)
return pruned_configs
def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None):
for numel, label in zip((xnumel, ynumel, znumel), "XYZ"):
if numel is None:
continue
block = cfg[f"{label}BLOCK"]
if numel == 1:
assert block == 1, (
f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
)
max_block = TRITON_MAX_BLOCK[label]
max_block_str = f'config.triton.max_block["{label}"]'
assert max_block % block == 0, (
f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
)
def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False):
    # On AMD GPUs each warp has 64 lanes, double the 32 lanes on NVIDIA GPUs,
    # so we use half as many warps (rounded up) here.
    if torch.version.hip:
        max_num_warps = (max_num_warps + 1) // 2
        min_num_warps = (min_num_warps + 1) // 2
    # persistent reduction is register intensive
    if register_intensive:
        max_num_warps = max_num_warps // 2
    return next_power_of_2(min(max(num_warps, min_num_warps), max_num_warps))
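# Illustrative sketch, not part of this module: the clamping performed by
# _num_warps, with the torch.version.hip check replaced by a hypothetical
# is_hip flag and a local next_power_of_2_sketch helper (both are assumptions
# made so the example runs without torch).

```python
def next_power_of_2_sketch(n: int) -> int:
    """Smallest power of two >= n (returns 1 for n <= 1)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def num_warps_sketch(
    num_warps: int,
    max_num_warps: int = 8,
    min_num_warps: int = 2,
    register_intensive: bool = False,
    is_hip: bool = False,  # hypothetical flag standing in for torch.version.hip
) -> int:
    if is_hip:
        # 64-lane warps: halve both bounds, rounding up
        max_num_warps = (max_num_warps + 1) // 2
        min_num_warps = (min_num_warps + 1) // 2
    if register_intensive:
        max_num_warps //= 2
    # clamp into [min_num_warps, max_num_warps], then round up to a power of 2
    return next_power_of_2_sketch(min(max(num_warps, min_num_warps), max_num_warps))


nvidia = [num_warps_sketch(n) for n in (0, 3, 100)]
amd = [num_warps_sketch(n, is_hip=True) for n in (0, 3, 100)]
```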
def _check_max_grid_x(size_hints, x, num_warps):
    # Check if maxGridSize is exceeded - if so then must scale XBLOCK further
    max_grid_x = 2147483647
    warp_size = (
        64 if torch.version.hip else 32
    )  # TODO: query warp size once #129663 is merged
    num_blocks = (size_hints["x"] + x - 1) // x
    while (num_blocks * num_warps * warp_size) > max_grid_x and x < size_hints["x"]:
        x *= 2  # Scale up XBLOCK if grid exceeds limits
        num_blocks = num_blocks // 2
    if (num_blocks * num_warps * warp_size) > max_grid_x:
        raise AssertionError(
            "Reduction config exceeds cudaDeviceProp maxGridSize. Please raise a pytorch issue"
        )
    return x, num_blocks
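# Illustrative sketch, not part of this module: the XBLOCK doubling loop from
# _check_max_grid_x with the warp size passed in as a plain parameter (the
# name fit_xblock_sketch and the warp_size default of 32 are assumptions for
# illustration).

```python
MAX_GRID_X = 2**31 - 1  # same value as max_grid_x above


def fit_xblock_sketch(xnumel: int, xblock: int, num_warps: int, warp_size: int = 32):
    """Double xblock until num_blocks * num_warps * warp_size fits MAX_GRID_X."""
    num_blocks = (xnumel + xblock - 1) // xblock  # ceiling division
    while num_blocks * num_warps * warp_size > MAX_GRID_X and xblock < xnumel:
        xblock *= 2
        num_blocks //= 2
    if num_blocks * num_warps * warp_size > MAX_GRID_X:
        raise AssertionError("grid limit exceeded even at the largest XBLOCK")
    return xblock, num_blocks


small = fit_xblock_sketch(xnumel=1 << 20, xblock=1, num_warps=8)
large = fit_xblock_sketch(xnumel=1 << 31, xblock=1, num_warps=8)
```

# With 8 warps of 32 threads, the small case already fits and keeps XBLOCK=1,
# while the large case doubles XBLOCK until the thread count fits the limit.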
def triton_config(
    size_hints,
    x,
    y=None,
    z=None,
    num_stages=1,
    num_elements_per_warp=256,
    min_elem_per_thread=0,
) -> Config:
    """
    Construct a pointwise triton config with some adjustment heuristics
    based on size_hints. size_hints is a dict of numels keyed by tile
    dimension ("x", "y", "z") and will be rounded up to the nearest power of 2.

    num_elements_per_warp is a suggestion for controlling how many warps
    the triton config should contain. e.g.: if x=16, y=8, z=4 then
    num_elements = 16*8*4 = 512. Then if we set num_elements_per_warp=128,
    we'll launch 512 (elem) / 128 (elem/warp) = 4 warps. Note that it's
    just a suggestion, and sometimes other adjustment heuristics will
    override the num_elements_per_warp.

    min_elem_per_thread controls the minimum number of elements
    processed by each thread. It's always enforced.
    """
    # Ideally we want to read this from some device config
    maxGridSize = [2147483647, 65535, 65535]
    target = conditional_product(x, y, z)
    if conditional_product(*size_hints.values()) < target:
        target //= 8
    # shrink sizes to size hints
    x = min(x, size_hints["x"])
    if y:
        y = min(y, size_hints["y"])
    if z:
        z = min(z, size_hints["z"])
    # if we are below original block size, scale up where we can;
    # or if the calculated grid size is larger than the limit, we bump up the corresponding dimension
    while x < min(size_hints["x"], TRITON_MAX_BLOCK["X"]) and (
        x * maxGridSize[0] < size_hints["x"] or conditional_product(x, y, z) < target
    ):
        x *= 2
    while (
        y
        and y < min(size_hints["y"], TRITON_MAX_BLOCK["Y"])
        and (
            y * maxGridSize[1] < size_hints["y"]
            or conditional_product(x, y, z) < target
        )
    ):
        y *= 2
    while (
        z
        and z < min(size_hints["z"], TRITON_MAX_BLOCK["Z"])
        and (
            z * maxGridSize[2] < size_hints["z"]
            or conditional_product(x, y, z) < target
        )
    ):
        z *= 2
    num_warps = _num_warps(
        conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
    )
    # We only arrive at 2 warps if bs was too small because numel was too
    # small. However, to work around some ptx bugs we still want at least
    # 4 warps if there are enough elements per thread. Since this is a rare
    # situation, we don't expect it to affect perf in general.
    # See https://github.com/pytorch/pytorch/pull/97950
    if conditional_product(x, y, z) >= 128 and not torch.version.hip:
        num_warps = max(num_warps, 4)
    xnumel = size_hints["x"]
    ynumel = size_hints.get("y")
    znumel = size_hints.get("z")
    # Increase x to satisfy min_elem_per_thread requirements.
    block_size = max(
        conditional_product(x, y, z),
        min_elem_per_thread * _NUM_THREADS_PER_WARP * num_warps,
    )
    x *= math.ceil(block_size / conditional_product(x, y, z))
    x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps)
    x = min(x, size_hints["x"])
    cfg = {"XBLOCK": x}
    if y:
        cfg["YBLOCK"] = y
    if z:
        cfg["ZBLOCK"] = z
    assert x <= TRITON_MAX_BLOCK["X"], f"increase TRITON_MAX_BLOCK['X'] to {x}"
    check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
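# Illustrative sketch, not part of this module: the two pieces of arithmetic
# described in the triton_config docstring, in isolation. The helper names are
# hypothetical, and threads_per_warp=32 is an assumed stand-in for
# _NUM_THREADS_PER_WARP, whose value is not shown in this section.

```python
import math


def suggested_num_warps_sketch(x, y=None, z=None, num_elements_per_warp=256):
    """Block elements divided by the per-warp suggestion, before any clamping."""
    num_elements = x * (y or 1) * (z or 1)
    return num_elements // num_elements_per_warp


def enforce_min_elem_per_thread_sketch(
    x, num_warps, min_elem_per_thread, threads_per_warp=32
):
    """Grow a 1D block so each thread handles at least min_elem_per_thread elements."""
    block_size = max(x, min_elem_per_thread * threads_per_warp * num_warps)
    return x * math.ceil(block_size / x)


# The docstring example: 16 * 8 * 4 = 512 elements at 128 elements per warp.
warps = suggested_num_warps_sketch(16, 8, 4, num_elements_per_warp=128)
# 4 warps * 32 threads * 4 elements per thread = 512 elements, so 256 doubles.
grown = enforce_min_elem_per_thread_sketch(256, num_warps=4, min_elem_per_thread=4)
unchanged = enforce_min_elem_per_thread_sketch(256, num_warps=4, min_elem_per_thread=0)
```

# In triton_config itself the suggested warp count is further clamped by
# _num_warps and the resulting XBLOCK is capped by the grid and size limits.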
def triton_config_reduction(
    size_hints, x, r, num_stages=1, num_warps=None, register_intensive=False
) -> Config:
    """
    Construct a reduction triton config with some adjustment heuristics
    based on size_hints. size_hints is a dict of numels keyed by tile
    dimension ("x", "r") and will be rounded up to the nearest power of 2.
    """
    target = conditional_product(x, r)
    if conditional_product(*size_hints.values()) < target:
        target //= 8
    # shrink sizes to size hints
    x = min(x, size_hints["x"])
    r = min(r, size_hints["r"])
    # if we are below original block size, scale up where we can
    while x < size_hints["x"] and conditional_product(x, r) < target:
        x *= 2
    while r < size_hints["r"] and conditional_product(x, r) < target:
        r *= 2
    if num_warps is None:
        num_warps = conditional_product(x, r) // 128
    num_warps = _num_warps(
        num_warps, max_num_warps=16, register_intensive=register_intensive
    )
    x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps)
    while conditional_product(x, r) > target:
        if r == 1:
            break
        r = r // 2
    cfg = {"XBLOCK": x, "RBLOCK": r}
    check_config(cfg, xnumel=size_hints["x"])
    assert x <= TRITON_MAX_BLOCK["X"], f"increase TRITON_MAX_BLOCK['X'] to {x}"
    assert r <= TRITON_MAX_BLOCK["R"], f"increase TRITON_MAX_BLOCK['R'] to {r}"
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=1):
    """
    Construct a tile reduction triton config with some adjustment
    heuristics based on size_hints. size_hints is a dict of numels keyed
    by tile dimension ("x", "y", "r") and will be rounded up to the
    nearest power of 2.
    """
    target = conditional_product(x, y, r)
    # size_hints is a dict, so multiply its values (unpacking the dict
    # itself would pass the string keys instead of the numels)
    if conditional_product(*size_hints.values()) < target:
        target //= 8
    # shrink sizes to size hints
    x = min(x, size_hints["x"])
    y = min(y, size_hints["y"])
    r = min(r, size_hints["r"])
    # if we are below original block size, scale up where we can
    while x < size_hints["x"] and conditional_product(x, y, r) < target:
        x *= 2
    while r < size_hints["r"] and conditional_product(x, y, r) < target:
        r *= 2
    while y < size_hints["y"] and conditional_product(x, y, r) < target:
        y *= 2
    cfg = {"XBLOCK": x, "YBLOCK": y, "RBLOCK": r}
    num_warps = _num_warps(conditional_product(x, y, r) // 256, min_num_warps=1)
    check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
    assert r <= TRITON_MAX_BLOCK["R"], f"increase TRITON_MAX_BLOCK['R'] to {r}"
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
def pointwise(
size_hints,
triton_meta,
tile_hint=None,
filename=None,
min_elem_per_thread=0,
inductor_meta=None,
):
"""
Construct @triton.heuristics() based on size_hints.
"""
inductor_meta = {} if inductor_meta is None else inductor_meta
assert not inductor_meta.get("no_x_dim")
numel = functools.reduce(operator.mul, size_hints.values())
bs = max(256, min(numel // 128, 1024))
hinted_configs = autotune_hints_to_configs(
inductor_meta.get("autotune_hints", set()),
size_hints,
bs,
triton_meta["device"],
)
triton_config_with_settings = functools.partial(
triton_config, min_elem_per_thread=min_elem_per_thread
)
configs = None
if len(size_hints) == 1:
if disable_pointwise_autotuning(inductor_meta) and not (
inductor_meta.get("max_autotune")
or inductor_meta.get("max_autotune_pointwise")
):
configs = [triton_config_with_settings(size_hints, bs)]
else:
configs = [
triton_config_with_settings(size_hints, bs, num_elements_per_warp=256),
triton_config_with_settings(
size_hints, bs // 2, num_elements_per_warp=64
),
*hinted_configs,
]
if len(size_hints) == 2:
if (
disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE
) and not (
inductor_meta.get("max_autotune")
or inductor_meta.get("max_autotune_pointwise")
):
configs = [triton_config_with_settings(size_hints, 32, 32)]
else:
configs = [
triton_config_with_settings(size_hints, 32, 32),
triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16
triton_config_with_settings(size_hints, 256, 16),
triton_config_with_settings(size_hints, 16, 256),
triton_config_with_settings(size_hints, bs, 1),
triton_config_with_settings(size_hints, 1, bs),
*hinted_configs,
]
if len(size_hints) == 3:
if disable_pointwise_autotuning(inductor_meta):
configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
else:
configs = [
triton_config_with_settings(size_hints, 16, 16, 16),
triton_config_with_settings(size_hints, 64, 8, 8),
triton_config_with_settings(size_hints, 8, 64, 8),
triton_config_with_settings(size_hints, 8, 8, 64),
triton_config_with_settings(size_hints, bs, 1, 1),
triton_config_with_settings(size_hints, 1, bs, 1),
triton_config_with_settings(size_hints, 1, 1, bs),
*hinted_configs,
]
if not configs:
raise NotImplementedError(f"size_hints: {size_hints}")
return cached_autotune(
size_hints,
configs,
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.POINTWISE,
filename=filename,
)
def _reduction_configs(
*, size_hints: Dict[str, int], inductor_meta: Dict[str, Any]
) -> List[Config]:
reduction_hint = inductor_meta.get("reduction_hint", None)
assert len(size_hints) == 2
rnumel = size_hints["r"]
register_intensive = False
MAX_RBLOCK = 2048
if (
size_hints["x"] >= 1024
and inductor_meta.get("num_load", 0) + inductor_meta.get("num_reduction", 0)
>= 10
):
# A heuristic to reduce RBLOCK if a kernel potentially needs many registers.
# Consider loads and reductions, since a load needs to move data into registers
# and a reduction needs an accumulator.
#
# The magic numbers are a bit arbitrary.
#
# We cannot rely on dynamically scaling down RBLOCK later, since sometimes
# triton ends up using fewer registers but with worse perf. See:
# https://github.com/pytorch/pytorch/issues/126463
#
# The heuristic is a very simple one since registers can be reused. But
# hopefully it can be a good enough indicator.
MAX_RBLOCK = 1024
register_intensive = True
contiguous_config = triton_config_reduction(
size_hints,
1,
(rnumel if 256 <= rnumel < MAX_RBLOCK else MAX_RBLOCK),
register_intensive=register_intensive,
)
outer_config = triton_config_reduction(
size_hints, 64, 8, register_intensive=register_intensive
)
tiny_config = triton_config_reduction(
size_hints,
2 * (256 // rnumel) if rnumel <= 256 else 1,
min(rnumel, MAX_RBLOCK),
register_intensive=register_intensive,
)
if inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise"):
pass # skip all these cases
elif reduction_hint == ReductionHint.INNER:
return [contiguous_config]
elif reduction_hint == ReductionHint.OUTER:
return [outer_config]
elif reduction_hint == ReductionHint.OUTER_TINY:
return [tiny_config]
if disable_pointwise_autotuning(inductor_meta):
return [triton_config_reduction(size_hints, 32, 128)]
return [
contiguous_config,
outer_config,
tiny_config,
triton_config_reduction(size_hints, 64, 64),
triton_config_reduction(size_hints, 8, 512),
# halve the XBLOCK/RBLOCK compared to outer_config
# TODO: this may only be beneficial when each iteration of the reduction
# is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
triton_config_reduction(size_hints, 64, 4, num_warps=8),
]
def reduction(
size_hints,
reduction_hint=False,
triton_meta=None,
filename=None,
inductor_meta=None,
):
"""args to @triton.heuristics()"""
inductor_meta = {} if inductor_meta is None else inductor_meta
inductor_meta["reduction_hint"] = reduction_hint
if inductor_meta.get("no_x_dim"):
size_hints["x"] = 1
assert triton_meta is not None
if len(size_hints) != 2:
raise NotImplementedError(f"size_hints: {size_hints}")
configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
return cached_autotune(
size_hints,
configs=configs,
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.REDUCTION,
filename=filename,
)
def cooperative_reduction(
size_hints,
reduction_hint,
triton_meta,
filename,
inductor_meta,
):
inductor_meta = {} if inductor_meta is None else inductor_meta
inductor_meta["reduction_hint"] = reduction_hint
if inductor_meta.get("no_x_dim"):
size_hints["x"] = 1
xnumel, rnumel = size_hints["x"], size_hints["r"]
# TODO(jansel): we should base target on the SM count of the local GPU
target = 64
split = max(1, min(target // xnumel, TRITON_MAX_RSPLIT))
assert rnumel >= split
assert split <= TRITON_MAX_RSPLIT
if inductor_meta["persistent_reduction"]:
configs = _persistent_reduction_configs(
{"x": xnumel, "r": rnumel // split}, reduction_hint, inductor_meta
)
else:
configs = _reduction_configs(
size_hints={"x": xnumel, "r": rnumel // split}, inductor_meta=inductor_meta
)
for config in configs:
config.kwargs["RSPLIT"] = split
# TODO(jansel): add more configs in max_autotune
return cached_autotune(
size_hints,
configs=configs,
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.REDUCTION,
filename=filename,
)
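# Illustrative sketch, not part of this module: how cooperative_reduction picks
# the RSPLIT factor from xnumel. target=64 is the value used above; max_rsplit=64
# is an assumed stand-in for TRITON_MAX_RSPLIT, whose value is not shown in this
# section, and rsplit_sketch is a hypothetical name.

```python
def rsplit_sketch(xnumel: int, target: int = 64, max_rsplit: int = 64) -> int:
    """Split the reduction so that roughly `target` blocks are launched in total."""
    return max(1, min(target // xnumel, max_rsplit))


splits = [rsplit_sketch(x) for x in (1, 4, 64, 128)]
```

# Small xnumel values get a large split; once xnumel reaches the target the
# split falls back to 1 and the reduction is no longer divided across blocks.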
def _persistent_reduction_configs(
size_hints,
reduction_hint=False,
inductor_meta=None,
):
xnumel, rnumel = size_hints["x"], size_hints["r"]
configs = [
triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
for xblock in (1, 8, 32, 128)
if xblock == 1 or (rnumel * xblock <= 4096 and xblock <= xnumel)
]
# TODO(jansel): we should be able to improve these heuristics
if reduction_hint == ReductionHint.INNER and rnumel >= 256:
configs = configs[:1]
elif reduction_hint == ReductionHint.OUTER:
configs = configs[-1:]
elif reduction_hint == ReductionHint.OUTER_TINY:
configs = [
triton_config_reduction(
size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, rnumel
)
]
for c in configs:
# we don't need RBLOCK for persistent reduction
c.kwargs.pop("RBLOCK")
if disable_pointwise_autotuning(inductor_meta):
configs = configs[:1]
return configs
def persistent_reduction(
size_hints,
reduction_hint=False,
triton_meta=None,
filename=None,
inductor_meta=None,
):
inductor_meta = {} if inductor_meta is None else inductor_meta
inductor_meta["reduction_hint"] = reduction_hint
if inductor_meta.get("no_x_dim"):
size_hints["x"] = 1
configs = _persistent_reduction_configs(size_hints, reduction_hint, inductor_meta)
return cached_autotune(
size_hints,
configs,
triton_meta=triton_meta,
inductor_meta=inductor_meta,
filename=filename,
heuristic_type=HeuristicType.PERSISTENT_REDUCTION,
)
def split_scan(
size_hints,
reduction_hint=False,
triton_meta=None,
filename=None,
inductor_meta=None,
):
"""Heuristic for TritonSplitScanKernel"""
inductor_meta = {} if inductor_meta is None else inductor_meta
inductor_meta["reduction_hint"] = reduction_hint
if inductor_meta.get("no_x_dim"):
size_hints["x"] = 1
assert triton_meta is not None
if len(size_hints) != 2:
raise NotImplementedError(f"size_hints: {size_hints}")
configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
# Fixup configs to enforce the minimum RBLOCK size
min_rblock = inductor_meta.get("min_split_scan_rblock", 256)
for cfg in configs:
if cfg.kwargs["RBLOCK"] < min_rblock:
cfg.kwargs["RBLOCK"] = min_rblock
return cached_autotune(
size_hints,
configs=configs,
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.SPLIT_SCAN,
filename=filename,
)
def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=None):
"""
Compile a triton template
"""
return cached_autotune(
None,
[triton.Config({}, num_stages=num_stages, num_warps=num_warps)],
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,
filename=filename,
)
def _pop_config_kwargs(config: Dict[str, Any]) -> Dict[str, Any]:
"""Extract triton.Config options that should become kwargs"""
popped = {}
for key in ("num_warps", "num_stages", "num_ctas", "maxnreg"):
val = config.pop(key, None)
if val is not None:
popped[key] = val
return popped
def fixed_config(config, filename, triton_meta, inductor_meta):
"""
Used when the configuration is already decided at compile time
"""
config = {**config}
return cached_autotune(
None,
[triton.Config(config, **_pop_config_kwargs(config))],
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.FIXED,
filename=filename,
)
def user_autotune(
configs, triton_meta, filename=None, inductor_meta=None, custom_kernel=False
):
"""
Compile a user defined triton kernel
"""
if len(configs) == 0:
configs = [triton.Config({})]
else:
configs = [
triton.Config(c.get("kwargs", {}), **_pop_config_kwargs({**c}))
for c in configs
]
return cached_autotune(
None,
configs,
triton_meta=triton_meta,
heuristic_type=HeuristicType.USER_AUTOTUNE,
filename=filename,
inductor_meta=inductor_meta,
custom_kernel=custom_kernel,
)
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
"""
Compile a triton foreach kernel
"""
return cached_autotune(
None,
[triton.Config({}, num_stages=1, num_warps=num_warps)],
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,
filename=filename,
)
def grid(*numels):
    """Helper function to compute triton grids"""
    if len(numels) == 1:
        xnumel, ynumel, znumel = numels[0], None, None
    elif len(numels) == 2:
        xnumel, ynumel, znumel = numels[1], numels[0], None
    elif len(numels) == 3:
        xnumel, ynumel, znumel = numels[2], numels[1], numels[0]
    else:
        raise AssertionError(f"invalid size for numels {len(numels)}")

    def get_grid_dim(numel, block):
        if numel is None:
            return 1
        if block is None:
            return numel
        return ceildiv(numel, block)

    def grid_fn(meta):
        x_grid = get_grid_dim(xnumel, meta.get("XBLOCK", 1))
        y_grid = get_grid_dim(ynumel, meta.get("YBLOCK", None))
        max_y_grid = get_max_y_grid()
        if znumel is None:
            # With no z dimension, split an oversized y grid across y and z.
            div = ceildiv(y_grid, max_y_grid)
            y_grid = ceildiv(y_grid, div)
            z_grid = div
        else:
            z_grid = get_grid_dim(znumel, meta.get("ZBLOCK", None))
            torch._check(
                y_grid <= max_y_grid,
                lambda: f"Generated y grid beyond 2^16 ({y_grid}) not supported with z dimension present. File issue",
            )
        return (
            x_grid,
            y_grid,
            z_grid,
        )

    setattr(grid_fn, "grid_fn_str", f"grid{numels}")  # noqa: B010
    return grid_fn
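# Illustrative sketch, not part of this module: the two-dimensional case of
# grid() with plain integers, showing how a y grid larger than the limit is
# spread over the otherwise unused z dimension. max_y_grid=65535 is an assumed
# value for get_max_y_grid(), and the helper names are hypothetical.

```python
def ceildiv_sketch(a: int, b: int) -> int:
    """Ceiling division for positive integers."""
    return -(a // -b)


def grid_2d_sketch(xnumel, ynumel, xblock, yblock, max_y_grid=65535):
    """Launch grid for an (x, y) kernel; y overflow spills into z."""
    x_grid = ceildiv_sketch(xnumel, xblock)
    y_grid = ceildiv_sketch(ynumel, yblock)
    div = ceildiv_sketch(y_grid, max_y_grid)  # number of z slices needed
    return (x_grid, ceildiv_sketch(y_grid, div), div)


fits = grid_2d_sketch(xnumel=1000, ynumel=512, xblock=256, yblock=16)
spills = grid_2d_sketch(xnumel=1000, ynumel=200000, xblock=256, yblock=1)
```

# In the second call 200000 y blocks exceed the assumed limit, so they are
# arranged as 4 z slices of 50000 y blocks each.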
def cooperative_reduction_grid(xnumel):
def grid_fn(meta):
return (meta["RSPLIT"], ceildiv(xnumel, meta.get("XBLOCK", 1)), 1)
grid_fn_str = f"cooperative_reduction_grid({xnumel})"
setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
return grid_fn
def maybe_cooperative_reduction_grid(xnumel):
def grid_fn(meta):
if "RSPLIT" in meta:
return coop_grid(meta)
return normal_grid(meta)
coop_grid = cooperative_reduction_grid(xnumel)
normal_grid = grid(xnumel)
grid_fn_str = f"maybe_cooperative_reduction_grid({xnumel})"
setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
return grid_fn
def split_scan_grid(xnumel, rnumel):
def grid_fn(meta):
assert meta.get("XBLOCK", 1) == 1
return (ceildiv(rnumel, meta.get("RBLOCK", 1)), xnumel, 1)
grid_fn_str = f"split_scan_grid({xnumel}, {rnumel})"
setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
return grid_fn
def grid_combo_kernels(
*numels, num_kernels, min_blocks, is_sequential, default_meta=None
):
"""min_blocks is the minimal size of the grid x dimension"""
if not is_sequential:
# round robin dispatch
numels_agg = list(numels)
for i in range(len(numels_agg)):
if isinstance(numels_agg[i], (list, tuple)):
numels_agg[i] = max(max(numels_agg[i]), 0) # noqa: PLW3301
kernel_grid_fn = grid(*numels_agg)
if isinstance(numels[-1], (list, tuple)):
min_blocks_d = max(-min(numels[-1]), 0) * num_kernels
else:
min_blocks_d = None
if min_blocks is None:
assert min_blocks_d is not None
min_blocks = min_blocks_d
else:
assert (
min_blocks_d is None or min_blocks == min_blocks_d
), f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
else:
# sequential dispatch
seq_numels = list(numels)
# x numels are not used here; this value is just a placeholder
seq_numels[-1] = 1024
for i in range(len(seq_numels) - 1):
if isinstance(seq_numels[i], (list, tuple)):
seq_numels[i] = max(seq_numels[i])
kernel_grid_fn = grid(*seq_numels)
def get_grid_dim(numel, block):
if numel is None:
return 1
if block is None:
return numel
return ceildiv(numel, block)
def grid_fn(meta):
assert min_blocks is not None, "min_blocks must be a number"
cuda_grid = list(kernel_grid_fn(meta))
cuda_grid[0] = max(num_kernels * cuda_grid[0], min_blocks)
return tuple(cuda_grid)
def seq_grid_fn(meta):
cuda_grid = list(kernel_grid_fn(meta))
# x <= 0 means this kernel's x grid is not tunable (x_no_dim is true)
x_grid = sum(
[
-x if x <= 0 else get_grid_dim(x, meta.get("XBLOCK", 1))
for x in numels[-1]
]
)
cuda_grid[0] = x_grid
return tuple(cuda_grid)
def grid_fn_default_meta(meta):
return grid_fn(default_meta)
def seq_grid_fn_default_meta(meta):
return seq_grid_fn(default_meta)
if default_meta is None:
return grid_fn if not is_sequential else seq_grid_fn
else:
return grid_fn_default_meta if not is_sequential else seq_grid_fn_default_meta