import contextlib
import itertools
import math
import operator
import weakref
from enum import Enum
from functools import partial, reduce
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union
import torch
import torch._prims_common as utils
import torch.library
from torch import Tensor, TypedStorage
from torch._C import _get_default_device
from torch._prims.nvfuser_prims import register_nvprims
from torch._prims_common import (
    check,
    DimsSequenceType,
    DimsType,
    Number,
    NumberType,
    RETURN_TYPE,
    ShapeType,
    StrideType,
    TensorLike,
    TensorLikeType,
    type_to_dtype,
)
from torch._prims_common.wrappers import backwards_not_supported
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.overrides import handle_torch_function, has_torch_function
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
prim = torch.library.Library("prims", "DEF")
prim_impl = torch.library.Library("prims", "IMPL", "CompositeExplicitAutograd")
prim_backend_select_impl = torch.library.Library("prims", "IMPL", "BackendSelect")
prim_autograd_impl = torch.library.Library("prims", "IMPL", "Autograd")
prim_meta_impl = torch.library.Library("prims", "IMPL", "Meta")
# Experimental module containing prototype "primitive" operations.
__all__ = [
    #
    # Common datastructures and helpers
    #
    "RETURN_TYPE",
    #
    # Elementwise unary prims
    #
    "abs",
    "acos",
    "acosh",
    "asin",
    "asinh",
    "atan",
    "atanh",
    "cos",
    "cosh",
    "bessel_i0",
    "bessel_i0e",
    "bessel_i1",
    "bessel_i1e",
    "bessel_j0",
    "bessel_j1",
    "bitwise_not",
    "cbrt",
    "ceil",
    "conj_physical",
    "digamma",
    "erf",
    "erf_inv",
    "erfc",
    "exp",
    "expm1",
    "exp2",
    "fill",
    "floor",
    "imag",
    "isfinite",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "neg",
    "real",
    "reciprocal",
    "round",
    "rsqrt",
    "sign",
    "signbit",
    "sin",
    "sinh",
    "spherical_bessel_j0",
    "sqrt",
    "tan",
    "tanh",
    "trunc",
    #
    # Elementwise binary prims
    #
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    # 'complex',  # needs custom meta
    "div",
    "eq",
    "fmax",
    "fmin",
    "fmod",
    "gcd",
    "ge",
    "gt",
    "hypot",
    "igamma",
    "igammac",
    "le",
    "lt",
    "maximum",
    "minimum",
    "mul",
    "ne",
    "nextafter",
    "pow",
    "remainder",
    "shift_left",
    "shift_right_arithmetic",
    "shift_right_logical",  # not implemented
    "sub",
    "zeta",
    #
    # View prims
    #
    "as_strided",
    "broadcast_in_dim",
    "collapse_view",
    "conj",
    "expand_dims",
    "slice",
    "slice_in_dim",  # implemented using slice -- make this a ref?
    "split_dim",
    "squeeze",
    "transpose",
    "view_of",
    #
    # Shape prims
    #
    "collapse",
    "cat",
    "reshape",
    "rev",
    #
    # Conditional prims
    #
    "where",
    #
    # Data conversion and movement prims
    #
    "convert_element_type",
    "device_put",
    "item",
    "maximum_value",
    "minimum_value",
    "to_dtype",
    #
    # Inplace prims
    #
    "copy_to",
    "resize",
    # "_set",  # Commented out, see note below
    #
    # Reduction prims
    #
    "amax",
    "amin",
    "prod",
    "sum",
    "var",
    #
    # Tensor Creation Prims
    #
    "empty_strided",
    "scalar_tensor",
    "arange",
    #
    # Linear algebra (linalg) Prims
    #
    "svd",
    #
    # Randomness Prims
    #
    "normal",
    "uniform",
    #
    # FFT prims
    #
    "fft_r2c",
    "fft_c2c",
    "fft_c2r",
]


def TensorMeta(
    tensorlike: Optional[Union[NumberType, torch.Tensor]] = None,
    *,
    shape: Optional[ShapeType] = None,
    strides: Optional[StrideType] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[torch.device, str]] = None,
):
    if isinstance(tensorlike, Number):
        assert not shape and (shape is None or isinstance(shape, Sequence))
        assert not strides and (strides is None or isinstance(strides, Sequence))
        inferred_shape: Tuple[int, ...] = ()
        inferred_strides: Tuple[int, ...] = ()
        inferred_dtype = type_to_dtype(type(tensorlike))
        inferred_device = torch.device("cpu")
        # TODO: This looks wrong, a number that is wrapped into a tensor
        # needs to behave differently than a scalar tensor for type
        # promotion purposes
    elif tensorlike is not None:
        assert isinstance(tensorlike, torch.Tensor)
        inferred_shape = tuple(tensorlike.shape)
        inferred_strides = tuple(tensorlike.stride())
        inferred_dtype = tensorlike.dtype
        inferred_device = tensorlike.device
    else:
        # If no tensorlike "example" is given then all metadata
        # must be provided explicitly
        assert shape is not None
        assert strides is not None
        assert dtype is not None
        assert device is not None

    shape = inferred_shape if shape is None else tuple(shape)
    strides = inferred_strides if strides is None else tuple(strides)
    dtype = inferred_dtype if dtype is None else dtype
    device = inferred_device if device is None else device

    if isinstance(device, str):
        device = torch.device(device)

    return torch.empty_strided(shape, strides, dtype=dtype, device=device)
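

# Illustrative examples of the metadata TensorMeta infers (assuming the default
# dtype is torch.float32). The returned tensors are uninitialized; only their
# metadata is meaningful:
#   TensorMeta(torch.ones(2, 3))  -> shape (2, 3), strides (3, 1), float32, cpu
#   TensorMeta(3)                 -> shape (), strides (), torch.int64, cpu
#   TensorMeta(shape=(4,), strides=(1,), dtype=torch.bool, device="meta")
#                                 -> shape (4,), strides (1,), torch.bool, meta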


def _make_prim(
    *,
    schema: str,
    return_type: Union[RETURN_TYPE, Tuple[RETURN_TYPE, ...]],
    meta: Callable,
    impl_aten: Callable,
    doc: str,
):
    """
    Creates a primitive operation.

    Defines `schema` in the "prims" library, then registers `impl_aten` (guarded
    by `meta`), `meta`, and an autograd stub for it. Returns the op's default overload.
    """

    prim.define(schema)

    def _prim_impl(*args, **kwargs):
        # always run the meta function because aten implementation will
        # typically accept more inputs (e.g., it will do promotion and
        # broadcasting) which we want to reject
        meta(*args, **kwargs)
        return impl_aten(*args, **kwargs)

    # Right now prims don't support autograd (we can and should add an
    # argument that provides an implementation for backward here.) Because we
    # don't have derivative formulas, we must set up a custom autograd function
    # that raises an error if backward is invoked
    def _autograd_impl(*args, **kwargs):
        return backwards_not_supported(_prim)(*args, **kwargs)

    def _backend_select_impl(*args, **kwargs):
        if kwargs.get("device") and kwargs["device"].type == "meta":
            return meta(*args, **kwargs)
        else:
            return _prim_impl(*args, **kwargs)

    name = schema.split("(")[0]
    prim_impl.impl(name, _prim_impl)
    prim_autograd_impl.impl(name, _autograd_impl)
    prim_meta_impl.impl(name, meta)

    _prim_packet = getattr(torch.ops.prims, name)
    _prim = _prim_packet.default

    from torch._subclasses.fake_tensor import contains_tensor_types

    if not any(contains_tensor_types(a.type) for a in _prim._schema.arguments):
        prim_backend_select_impl.impl(name, _backend_select_impl)

    for p in (_prim_packet, _prim):
        p.__doc__ = doc
        p.return_type = return_type  # type: ignore[attr-defined]

        p.schema = schema
        p.prim_impl = _prim_impl
        p.prim_meta_impl = meta

    return _prim
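

# Usage sketch (illustrative only; "example_neg" is a hypothetical prim name that
# this module does not define):
#
#   example_neg = _make_prim(
#       schema="example_neg(Tensor self) -> Tensor",
#       return_type=RETURN_TYPE.NEW,
#       meta=partial(
#           _elementwise_meta,
#           type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
#       ),
#       impl_aten=torch.neg,
#       doc="",
#   )
#
# The returned object is torch.ops.prims.example_neg.default. On a regular device
# tensor it should run the meta check and then torch.neg (the
# CompositeExplicitAutograd kernel). On a "meta" tensor only the meta function
# runs. Differentiating through it should raise, because its Autograd kernel is
# wrapped with backwards_not_supported.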


class ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND(Enum):
    DEFAULT = (0,)
    ALWAYS_BOOL = (2,)
    COMPLEX_TO_FLOAT = (3,)


# TODO: implement dtype validation here, too, or on the corresponding refs
def _elementwise_meta(
    *args,
    type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND,
    args_with_fixed_dtypes: Optional[Tuple[TensorLikeType, ...]] = None,
) -> FakeTensor:
    """
    Meta function for elementwise operations. The output dtype is taken from the
    inputs and then adjusted according to `type_promotion`.

    Stride logic is currently incorrect.
    """

    assert len(args) > 0

    utils.check_same_dtype(*args)

    args_ = list(args)
    if args_with_fixed_dtypes is not None:
        args_.extend(args_with_fixed_dtypes)

    utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
    utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True)

    strides = utils.compute_elementwise_output_strides(*args_)
    shape = utils.extract_shape(*args_, allow_cpu_scalar_tensors=True)

    # Acquires the dtype
    dtype = None
    scalar_type = None
    for arg in args:
        if isinstance(arg, TensorLike):
            if not utils.is_cpu_scalar_tensor(arg):
                dtype = arg.dtype
                break
            else:
                dtype = arg.dtype
        elif isinstance(arg, Number):
            scalar_type = type(arg)

    if dtype is None and scalar_type is not None:
        dtype = utils.type_to_dtype(scalar_type)

    # Acquires the device (if it exists) or number
    device = None
    number = None
    for arg in args_:
        if isinstance(arg, TensorLike):
            device = arg.device
            break

        elif isinstance(arg, Number):
            if number is None:
                number = arg

    # NOTE: type promotion behavior here is mostly hidden from tests because
    # references will typically handle the type promotion properly even if this doesn't
    # (but getting it wrong will cause too many casts to be inserted in traces!)
    if device is not None:
        assert dtype is not None
        if type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT:
            dtype = dtype
        elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
            dtype = torch.bool
        elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
            if utils.is_complex_dtype(dtype):
                dtype = utils.corresponding_real_dtype(dtype)
            else:
                dtype = dtype

        return TensorMeta(device=device, shape=shape, strides=strides, dtype=dtype)

    # Number case
    # NOTE: this case is not currently exercised
    # TODO: fix number type promotion (bool, complex->float)
    assert not isinstance(number, torch.SymIntNode), "NYI"
    assert not isinstance(number, torch.SymFloatNode), "NYI"
    return TensorMeta(number)
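

# Worked example of the output dtype _elementwise_meta assigns for each promotion
# kind, given one tensor input that is not a CPU scalar tensor:
#   input dtype   DEFAULT     ALWAYS_BOOL   COMPLEX_TO_FLOAT
#   complex64     complex64   bool          float32
#   float32       float32     bool          float32
#   int64         int64       bool          int64
# The output shape comes from utils.extract_shape and the strides from
# utils.compute_elementwise_output_strides. Inputs must already agree in dtype,
# device and shape (CPU scalar tensors excepted for device and shape), because
# this meta function performs no broadcasting or type promotion of its own.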


def _complex_only_elementwise_meta(*args, **kwargs):
    utils.check(
        utils.is_complex_dtype(args[0].dtype), lambda: "Only complex dtype is supported"
    )
    return _elementwise_meta(*args, **kwargs)


def _make_elementwise_unary_prim(
    name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs
):
    """
    Creates an elementwise unary prim.
    """

    return _make_prim(
        schema=f"{name}(Tensor self) -> Tensor",
        meta=partial(_elementwise_meta, type_promotion=type_promotion),
        return_type=RETURN_TYPE.NEW,
        **kwargs,
    )


def _make_elementwise_binary_prim(
    name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs
):
    """
    Creates an elementwise binary prim.
    """

    return _make_prim(
        schema=f"{name}(Tensor self, Tensor other) -> Tensor",
        meta=partial(_elementwise_meta, type_promotion=type_promotion),
        return_type=RETURN_TYPE.NEW,
        **kwargs,
    )
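

# For example, _make_elementwise_unary_prim("sin", ...) defines the schema
#   sin(Tensor self) -> Tensor
# and _make_elementwise_binary_prim("add", ...) defines
#   add(Tensor self, Tensor other) -> Tensor
# Both helpers use RETURN_TYPE.NEW and _elementwise_meta with the given type_promotion.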


def _not_impl(*args, **kwargs):
    raise NotImplementedError

#
# Elementwise unary operations
#

abs = _make_elementwise_unary_prim(
    "abs",
    impl_aten=torch.abs,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)

acos = _make_elementwise_unary_prim(
    "acos",
    impl_aten=torch.acos,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

acosh = _make_elementwise_unary_prim(
    "acosh",
    impl_aten=torch.acosh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

asin = _make_elementwise_unary_prim(
    "asin",
    impl_aten=torch.asin,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

asinh = _make_elementwise_unary_prim(
    "asinh",
    impl_aten=torch.asinh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

atan = _make_elementwise_unary_prim(
    "atan",
    impl_aten=torch.atan,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

atanh = _make_elementwise_unary_prim(
    "atanh",
    impl_aten=torch.atanh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

cos = _make_elementwise_unary_prim(
    "cos",
    impl_aten=torch.cos,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

cosh = _make_elementwise_unary_prim(
    "cosh",
    impl_aten=torch.cosh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_j0 = _make_elementwise_unary_prim(
    "bessel_j0",
    impl_aten=torch.special.bessel_j0,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_j1 = _make_elementwise_unary_prim(
    "bessel_j1",
    impl_aten=torch.special.bessel_j1,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i0 = _make_elementwise_unary_prim(
    "bessel_i0",
    impl_aten=torch.i0,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i0e = _make_elementwise_unary_prim(
    "bessel_i0e",
    impl_aten=torch.special.i0e,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i1 = _make_elementwise_unary_prim(
    "bessel_i1",
    impl_aten=torch.special.i1,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i1e = _make_elementwise_unary_prim(
    "bessel_i1e",
    impl_aten=torch.special.i1e,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bitwise_not = _make_elementwise_unary_prim(
    "bitwise_not",
    impl_aten=torch.bitwise_not,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)


def _cbrt_aten(a: torch.Tensor) -> Tensor:
    utils.check(
        not a.is_complex(),
        lambda: "cbrt: Complex inputs not supported. Consider calling torch.pow(a, 1.0/3.0)",
    )
    # Returns the real cube root of the number.
    # Note that if a < 0, the principal value of pow(a, 1/3) is the complex number
    # exp(1/3 * log(a)) = exp(1/3 * (log(abs(a)) + pi*i)) = cbrt(abs(a)) * e^{pi/3*i},
    # not the real cube root.
    # For more info see the Notes section in
    # https://en.cppreference.com/w/cpp/numeric/math/cbrt
    return torch.copysign(torch.pow(a.abs(), 1 / 3), a)


cbrt = _make_elementwise_unary_prim(
    "cbrt",
    impl_aten=_cbrt_aten,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
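
# Worked example of the copysign approach in _cbrt_aten for a = -8.0:
#   torch.pow(a.abs(), 1 / 3)          -> 8.0 ** (1/3), approximately 2.0
#   torch.copysign(<that result>, a)   -> approximately -2.0
# By contrast, torch.pow(a, 1 / 3) on a real floating-point tensor returns nan
# for negative a instead of the real cube root.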

ceil = _make_elementwise_unary_prim(
    "ceil",
    impl_aten=torch.ceil,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)


def _conj_physical_meta(input: TensorLikeType) -> TensorLikeType:
    if not input.dtype.is_complex:
        raise RuntimeError("prims.conj_physical is only defined for complex dtypes")

    strides = utils.compute_elementwise_output_strides(input)
    return TensorMeta(input, strides=strides)


conj_physical = _make_prim(
    schema="conj_physical(Tensor self) -> Tensor",
    meta=_conj_physical_meta,
    impl_aten=torch._conj_physical,
    doc="Returns the physical conjugation of a complex tensor",
    return_type=RETURN_TYPE.NEW,
)

digamma = _make_elementwise_unary_prim(
    "digamma",
    impl_aten=torch.digamma,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erf = _make_elementwise_unary_prim(
    "erf",
    impl_aten=torch.erf,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erf_inv = _make_elementwise_unary_prim(
    "erf_inv",
    impl_aten=torch.special.erfinv,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erfc = _make_elementwise_unary_prim(
    "erfc",
    impl_aten=torch.special.erfc,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

exp = _make_elementwise_unary_prim(
    "exp",
    impl_aten=torch.exp,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

expm1 = _make_elementwise_unary_prim(
    "expm1",
    impl_aten=torch.special.expm1,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

exp2 = _make_elementwise_unary_prim(
    "exp2",
    impl_aten=torch.special.exp2,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)


def _fill_meta(a: TensorLikeType, value: NumberType) -> TensorLikeType:
    return _elementwise_meta(
        a, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
    )


# See https://github.com/pytorch/pytorch/issues/77932 for out-of-place fill request
def _fill_aten(a: Tensor, value: NumberType) -> Tensor:
    t = a * False
    with torch.no_grad():
        t.fill_(value)  # type: ignore[arg-type]
    return t


# NOTE: fill uses _make_prim directly because it has a value parameter
fill = _make_prim(
    schema="fill(Tensor self, Scalar value) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=_fill_meta,
    impl_aten=_fill_aten,
    doc="",
)
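
# Illustrative example of fill's semantics: for a = torch.tensor([1.5, 2.5]),
# fill(a, 3) returns a new float32 tensor holding 3.0 in both positions and leaves
# a unchanged. _fill_aten gets a fresh tensor with a's shape, dtype and device
# from `a * False`, then overwrites it in place under torch.no_grad().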

floor = _make_elementwise_unary_prim(
    "floor",
    impl_aten=torch.floor,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

imag = _make_prim(
    schema="imag(Tensor self) -> Tensor",
    meta=partial(
        _complex_only_elementwise_meta,
        type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    ),
    return_type=RETURN_TYPE.VIEW,
    impl_aten=torch.imag,
    doc="",
)

isfinite = _make_elementwise_unary_prim(
    "isfinite",
    impl_aten=torch.isfinite,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)

lgamma = _make_elementwise_unary_prim(
    "lgamma",
    impl_aten=torch.lgamma,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

log = _make_elementwise_unary_prim(
    "log",
    impl_aten=torch.log,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

log1p = _make_elementwise_unary_prim(
    "log1p",
    impl_aten=torch.log1p,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

log2 = _make_elementwise_unary_prim(
    "log2",
    impl_aten=torch.log2,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

log10 = _make_elementwise_unary_prim(
    "log10",
    impl_aten=torch.log10,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

real = _make_prim(
    schema="real(Tensor self) -> Tensor",
    meta=partial(
        _complex_only_elementwise_meta,
        type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    ),
    return_type=RETURN_TYPE.VIEW,
    impl_aten=torch.real,
    doc="",
)

reciprocal = _make_elementwise_unary_prim(
    "reciprocal",
    impl_aten=torch.reciprocal,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

neg = _make_elementwise_unary_prim(
    "neg",
    impl_aten=torch.neg,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

round = _make_elementwise_unary_prim(
    "round",
    impl_aten=torch.round,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

rsqrt = _make_elementwise_unary_prim(
    "rsqrt",
    impl_aten=torch.rsqrt,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

sign = _make_elementwise_unary_prim(
    "sign",
    impl_aten=torch.sign,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

signbit = _make_elementwise_unary_prim(
    "signbit",
    impl_aten=torch.signbit,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

sin = _make_elementwise_unary_prim(
    "sin",
    impl_aten=torch.sin,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

sinh = _make_elementwise_unary_prim(
    "sinh",
    impl_aten=torch.sinh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

spherical_bessel_j0 = _make_elementwise_unary_prim(
    "spherical_bessel_j0",
    impl_aten=torch.special.spherical_bessel_j0,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

sqrt = _make_elementwise_unary_prim(
    "sqrt",
    impl_aten=torch.sqrt,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

tan = _make_elementwise_unary_prim(
    "tan",
    impl_aten=torch.tan,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

tanh = _make_elementwise_unary_prim(
    "tanh",
    impl_aten=torch.tanh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

trunc = _make_elementwise_unary_prim(
    "trunc",
    impl_aten=torch.trunc,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

#
# Elementwise binary operations
#

add = _make_elementwise_binary_prim(
    name="add",
    impl_aten=torch.add,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

atan2 = _make_elementwise_binary_prim(
    name="atan2",
    impl_aten=torch.atan2,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bitwise_and = _make_elementwise_binary_prim(
    "bitwise_and",
    impl_aten=torch.bitwise_and,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bitwise_or = _make_elementwise_binary_prim(
    "bitwise_or",
    impl_aten=torch.bitwise_or,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bitwise_xor = _make_elementwise_binary_prim(
    "bitwise_xor",
    impl_aten=torch.bitwise_xor,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

# TODO: complex needs a special meta to account for its float -> complex behavior
# complex = _make_elementwise_binary_prim(
# impl_aten=torch.complex,
# doc="",
# )
# div prim performs truncation division on integer inputs
# and true division for floating and complex inputs
def _div_aten(a, b):
is_integral = isinstance(a, (bool, int)) or (
isinstance(a, torch.Tensor) and utils.is_integer_dtype(a.dtype)
)
if is_integral:
return torch.div(a, b, rounding_mode="trunc")
else:
return torch.true_divide(a, b)
div = _make_elementwise_binary_prim(
"div",
impl_aten=_div_aten,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
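# Worked example (illustrative sketch; hypothetical inputs, values traced by hand
# through _div_aten): integer inputs are divided with rounding toward zero, which
# differs from Python's floor division for negative operands.
#   _div_aten(torch.tensor([-7, 7]), 2)      # values [-3, 3]; note -7 // 2 == -4
#   _div_aten(torch.tensor([-7.0, 7.0]), 2)  # values [-3.5, 3.5] (true division)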
eq = _make_elementwise_binary_prim(
"eq",
impl_aten=torch.eq,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
fmax = _make_elementwise_binary_prim(
"fmax",
impl_aten=torch.fmax,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
fmin = _make_elementwise_binary_prim(
"fmin",
impl_aten=torch.fmin,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
fmod = _make_elementwise_binary_prim(
"fmod",
impl_aten=torch.fmod,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
gcd = _make_elementwise_binary_prim(
"gcd",
impl_aten=torch.gcd,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
ge = _make_elementwise_binary_prim(
"ge",
impl_aten=torch.ge,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
gt = _make_elementwise_binary_prim(
"gt",
impl_aten=torch.gt,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
hypot = _make_elementwise_binary_prim(
"hypot",
impl_aten=torch.hypot,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
igamma = _make_elementwise_binary_prim(
"igamma",
impl_aten=torch.special.gammainc,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
igammac = _make_elementwise_binary_prim(
"igammac",
impl_aten=torch.special.gammaincc,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
le = _make_elementwise_binary_prim(
"le",
impl_aten=torch.le,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
lt = _make_elementwise_binary_prim(
"lt",
impl_aten=torch.lt,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
# Note: the following impls exist because torch.maximum and torch.minimum do not support scalar inputs
def _maximum_aten(
a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
) -> TensorLikeType:
if isinstance(a, TensorLike) and isinstance(b, Number):
b = scalar_tensor(b, dtype=a.dtype, device=a.device)
elif isinstance(b, TensorLike) and isinstance(a, Number):
a = scalar_tensor(a, dtype=b.dtype, device=b.device)
return torch.maximum(a, b) # type: ignore[arg-type]
maximum = _make_elementwise_binary_prim(
"maximum",
impl_aten=_maximum_aten,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
def _minimum_aten(
a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
) -> TensorLikeType:
if isinstance(a, TensorLike) and isinstance(b, Number):
b = scalar_tensor(b, dtype=a.dtype, device=a.device)
elif isinstance(b, TensorLike) and isinstance(a, Number):
a = scalar_tensor(a, dtype=b.dtype, device=b.device)
return torch.minimum(a, b) # type: ignore[arg-type]
minimum = _make_elementwise_binary_prim(
"minimum",
impl_aten=_minimum_aten,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
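# Worked example (illustrative sketch; hypothetical inputs traced by hand): with
# a = torch.tensor([1.0, 5.0]) and the Python number b = 3, the helpers first wrap b
# as scalar_tensor(3, dtype=torch.float32, device=a.device), so
#   _maximum_aten(a, 3)  # values [3.0, 5.0]
#   _minimum_aten(a, 3)  # values [1.0, 3.0]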
mul = _make_elementwise_binary_prim(
"mul",
impl_aten=torch.mul,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
ne = _make_elementwise_binary_prim(
"ne",
impl_aten=torch.ne,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
nextafter = _make_elementwise_binary_prim(
"nextafter",
impl_aten=torch.nextafter,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
pow = _make_elementwise_binary_prim(
"pow",
impl_aten=torch.pow,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
remainder = _make_elementwise_binary_prim(
"remainder",
impl_aten=torch.remainder,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
shift_left = _make_elementwise_binary_prim(
"shift_left",
impl_aten=torch.bitwise_left_shift,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
shift_right_arithmetic = _make_elementwise_binary_prim(
"shift_right_arithmetic",
impl_aten=torch.bitwise_right_shift,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
shift_right_logical = _not_impl
sub = _make_elementwise_binary_prim(
"sub",
impl_aten=torch.sub,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
zeta = _make_elementwise_binary_prim(
"zeta",
impl_aten=torch.special.zeta,
doc="",
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
#
# View operations
#
# TODO: model view relationships
# TODO: model storage
def _as_strided_meta(
a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int
) -> TensorLikeType:
assert len(size) == len(stride)
assert storage_offset >= 0
utils.validate_strides(stride)
utils.validate_shape(size)
if reduce(operator.mul, size, 1) == 0:
# NOTE: as_strided to a shape with no elements is trivially valid, so this
# special case skips the storage bounds check below (and avoids having to
# acquire the storage at all)
pass
elif isinstance(a, torch.Tensor):
utils.check_in_bounds_for_storage(a.storage(), size, stride, storage_offset)
return TensorMeta(a, shape=size, strides=stride)
def _as_strided_aten(
a: Tensor, size: ShapeType, stride: StrideType, storage_offset: int
) -> Tensor:
return torch.as_strided(a, size, stride, storage_offset)
_as_strided_doc = """
Creates a view of the tensor with the given shape (size), strides (stride) and
storage offset (storage_offset).
"""
as_strided = _make_prim(
schema="as_strided(Tensor(a!) a, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor(a!)",
meta=_as_strided_meta,
impl_aten=_as_strided_aten,
return_type=RETURN_TYPE.VIEW,
doc=_as_strided_doc,
)
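# Worked example (illustrative sketch; hypothetical input): for a = torch.arange(6),
#   as_strided(a, size=(2, 2), stride=(1, 2), storage_offset=0)
# reads element [i, j] from storage offset 0 + i * 1 + j * 2, giving the values
# [[0, 2], [1, 3]]; the largest offset touched (3) lies within the 6-element storage.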
def _broadcast_in_dim_meta(
a: TensorLikeType, shape: ShapeType, broadcast_dimensions: Sequence[int]
):
# Type checks
assert isinstance(a, TensorLike)
assert isinstance(shape, Sequence)
assert isinstance(broadcast_dimensions, Sequence)
# every dimension must be accounted for
assert a.ndim == len(broadcast_dimensions)
# broadcast shape must have weakly more dimensions
assert len(shape) >= a.ndim
# broadcast_dimensions must be an ascending sequence
# (no relative reordering of dims) of integers and
# each dimension must be within the new shape
def _greater_than_reduce(acc, x):
assert isinstance(x, int)
assert x > acc
assert x < len(shape)
return x
reduce(_greater_than_reduce, broadcast_dimensions, -1)
# each dimension of a must have length one or match the length of the
# dimension of shape it is mapped to
for idx, new_idx in enumerate(broadcast_dimensions):
assert a.shape[idx] == 1 or a.shape[idx] == shape[new_idx]
new_strides = []
original_idx = 0
for idx in range(len(shape)):
if idx in broadcast_dimensions:
# Assigns a stride of zero to dimensions
# which were actually broadcast
if a.shape[original_idx] != shape[idx]:
new_strides.append(0)
else:
new_strides.append(a.stride()[original_idx])
original_idx = original_idx + 1
else:
new_strides.append(0)
return TensorMeta(a, shape=shape, strides=new_strides)
def _broadcast_in_dim_aten(a, shape, broadcast_dimensions):
s = list(shape)
for broadcast_dimension in broadcast_dimensions:
s[broadcast_dimension] = -1
v = a
for idx, x in enumerate(s):
if x != -1:
v = v.unsqueeze(idx)
return v.expand(shape)
_broadcast_in_dim_doc = """
Creates a view of a with the specified shape.
Allows adding dimensions of any length and broadcasting
dimensions of length one in a to any length.
The location of the broadcast dimensions must be specified
using the broadcast_dimensions argument. Changing the
relative order of dimensions is not supported.
"""
broadcast_in_dim = _make_prim(
schema="broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)",
meta=_broadcast_in_dim_meta,
impl_aten=_broadcast_in_dim_aten,
return_type=RETURN_TYPE.VIEW,
doc=_broadcast_in_dim_doc,
)
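# Worked example (illustrative sketch; hypothetical input traced by hand through
# _broadcast_in_dim_meta): a contiguous tensor a of shape (3,) with stride (1,) and
#   broadcast_in_dim(a, shape=(2, 3), broadcast_dimensions=(1,))
# maps a's only dimension to output dimension 1. Output dimension 0 is new, so it
# gets stride 0; output dimension 1 keeps stride 1. The result has shape (2, 3) and
# strides (0, 1), the same view that a.unsqueeze(0).expand(2, 3) produces.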
def _collapse_view_helper(
a: TensorLikeType, start: int, end: int
) -> Tuple[Optional[ShapeType], Optional[StrideType]]:
assert isinstance(a, TensorLike)
# Special-case for zero dimensional tensors
if a.ndim == 0:
shape = (1,)
strides = (1,)
else:
shape = a.shape # type: ignore[assignment]
strides = a.stride()
utils.validate_idx(len(shape), start)
utils.validate_exclusive_idx(len(shape), end)
# Verifies end is strictly greater than start
# (Collapse requires a non-empty interval)
if end <= start:
msg = "Attempting to collapse but end, {0}, is less than or equal to start, {1}!".format(
end, start
)
raise ValueError(msg)
if a.ndim == 0 or (end - 1 == start):
return shape, strides
length = shape[end - 1]
stride = strides[end - 1]
for idx in reversed(range(start, end - 1)):
if shape[idx] == 0 or shape[idx + 1] == 0:
length = 0
stride = 0
break
if shape[idx] == 1:
continue
length = length * shape[idx]
stride = min(stride, strides[idx])
if (
a.numel() > 0
and shape[idx + 1] != 1
and not (strides[idx] == strides[idx + 1] * shape[idx + 1])
):
return None, None
new_shape = shape[:start] + (length,) + shape[end:]
new_strides = strides[:start] + (stride,) + strides[end:]
# NOTE: when the input has no elements it's restrided as if it were contiguous
if a.numel() == 0:
new_strides = utils.make_contiguous_strides_for(new_shape)
return new_shape, new_strides
def _collapse_view_meta(a: TensorLikeType, start: int, end: int) -> TensorLikeType:
new_shape, new_strides = _collapse_view_helper(a, start, end)
if new_shape is None:
msg = "Attempting to view a collapsed tensor, but no such view exists!"
raise ValueError(msg)
return TensorMeta(a, shape=new_shape, strides=new_strides)
def _collapse_view_aten(a: Tensor, start: int, end: int) -> Tensor:
# Special-cases zero-dim tensors
if a.ndim == 0:
shape = (1,)
else:
shape = a.shape # type: ignore[assignment]
dim_length = 1
for idx in range(start, end):
dim_length = dim_length * shape[idx]
new_shape = shape[0:start] + (dim_length,) + shape[end:]
return a.view(new_shape)
_collapse_view_doc = """
Creates a view of a with the dimensions between
start (inclusive) and end (exclusive) merged into a
single dimension.
If it's not possible to take such a view then an error
is thrown; use collapse instead, which copies the data.
The dimensions can be merged if and only if
they are all "nested" with each other. That is, they all
have the property that
stride[i] = stride[i+1] * shape[i+1]
for all i in [start, end - 1).
"""
collapse_view = _make_prim(
schema="collapse_view(Tensor(a) a, int start, int end) -> Tensor(a)",
meta=_collapse_view_meta,
impl_aten=_collapse_view_aten,
return_type=RETURN_TYPE.VIEW,
doc=_collapse_view_doc,
)
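# Worked example (illustrative sketch; hypothetical inputs traced by hand through
# _collapse_view_helper): a contiguous tensor of shape (2, 3, 4) has strides
# (12, 4, 1). Collapsing dimensions [1, 3) checks 4 == 1 * 4, which holds, so the
# view has shape (2, 12) and strides (12, 1). By contrast, transposing a contiguous
# (2, 3) tensor gives shape (3, 2) with strides (1, 3); collapsing [0, 2) would
# require 1 == 3 * 2, so no view exists and collapse_view raises a ValueError.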
def _conj_meta(a: TensorLikeType) -> TensorLikeType:
if not a.dtype.is_complex:
raise RuntimeError("Expected complex dtype in prims.conj")
return TensorMeta(a)
_conj_doc = """
Returns a conjugated view of the original tensor.
"""
conj = _make_prim(
schema="conj(Tensor(a) a) -> Tensor(a)",
meta=_conj_meta,
impl_aten=torch.conj,
return_type=RETURN_TYPE.VIEW,
doc=_conj_doc,
)
def expand_dims(
a: TensorLikeType, dimensions: DimsSequenceType, ndim=None
) -> TensorLikeType:
"""
Creates a view of a with a.ndim + len(dimensions) dimensions, with new
dimensions of length one at the dimensions specified by dimensions.
"""
if ndim is not None:
# TODO: this is only here to support the unsqueeze ref
dims = sorted(utils.canonicalize_dims(ndim, dimensions)) # type: ignore[arg-type]
else:
dims = sorted(utils.canonicalize_dims(a.ndim, dimensions)) # type: ignore[arg-type]
if len(set(dims)) != len(dims):
msg = "Received duplicate dimensions to expand in {0}".format(str(dimensions))
raise ValueError(msg)
new_shape = list(a.shape)
for idx in dims:
new_shape.insert(idx, 1)
broadcast_dimensions = [
idx for idx in range(len(new_shape)) if idx not in dims
]
return broadcast_in_dim(a, new_shape, broadcast_dimensions)
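# Worked example (illustrative sketch; hypothetical input): for a of shape (3,),
#   expand_dims(a, (0,))
# inserts a length-one dimension at index 0, giving new_shape [1, 3], and maps a's
# only dimension to output dimension 1 (broadcast_dimensions [1]), so the result
# is a (1, 3) view of a.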
# Note: saves the Python slice object because we're about to clobber its name with the slice prim
pyslice: Type[slice] = slice # type: ignore[has-type]
def _slice_meta(
a: TensorLikeType,
start_indices: DimsSequenceType,
limit_indices: DimsSequenceType,
strides: Optional[StrideType] = None,
) -> TensorLikeType:
_strides = strides if strides is not None else [1] * len(start_indices)
if a.ndim != len(start_indices):
msg = "Attempting to slice tensor of rank {0} with start_indices of length {1}!".format(
a.ndim, len(start_indices)
)
raise ValueError(msg)
if a.ndim != len(limit_indices):
msg = "Attempting to slice tensor of rank {0} with limit_indices of length {1}!".format(
a.ndim, len(limit_indices)
)
raise ValueError(msg)
if a.ndim != len(_strides):
msg = (
"Attempting to slice tensor of rank {0} with strides of length {1}!".format(
a.ndim, len(_strides)
)
)
raise ValueError(msg)
for x, y in zip(start_indices, a.shape):
if x < 0:
msg = "Attempting to slice a tensor with a negative start index of {0}!".format(
x
)
raise ValueError(msg)
if x > y:
msg = (
"Attempting to slice a tensor but a start index in {0} is greater than"
" the length of its corresponding dimension in shape {1}".format(
start_indices, a.shape
)
)
raise ValueError(msg)
for x, y, z in zip(limit_indices, a.shape, start_indices):
if x < 0:
msg = "Attempting to slice a tensor with a negative stop index of {0}!".format(
x
)
raise ValueError(msg)
if x > y:
msg = (
"Attempting to slice a tensor but a stop index in {0} is greater than the length of"
" its corresponding dimension in shape {1}".format(
limit_indices, a.shape
)
)
raise ValueError(msg)
if x < z:
msg = (
"Attempting to slice a tensor but a start index {0} is greater than"
" its corresponding stop index {1}".format(z, x)
)
raise ValueError(msg)
for x in _strides:
if x <= 0:
msg = (
"Attempting to slice a tensor with a non-positive step of {0}!".format(
x
)
)
raise ValueError(msg)
new_shape = []
for x, y, z in zip(start_indices, limit_indices, _strides):
# Number of elements in range(x, y, z), i.e. ceil((y - x) / z) when y >= x
new_shape.append(1 + (y - x - 1) // z)
new_strides = []
for x, y in zip(a.stride(), _strides):
new_strides.append(x * y)
return TensorMeta(a, shape=new_shape, strides=new_strides)
def _slice_aten(
a: Tensor,
start_indices: DimsSequenceType,
limit_indices: DimsSequenceType,
strides: Optional[StrideType] = None,
) -> Tensor:
_strides = strides if strides is not None else [1] * len(start_indices)
slices = []
for start, stop, step in zip(start_indices, limit_indices, _strides):
slices.append(pyslice(start, stop, step))
return operator.getitem(a, slices) # type: ignore[call-overload]
_slice_doc = """
Creates a view of a "bounding box" within the tensor.
The bounding box is specified independently in each of the tensor's dimensions.
start_indices and limit_indices describe the box's boundaries for their corresponding
dimensions; start indices are inclusive and limit indices are exclusive. If strides is
specified then it gives the step size between elements in each corresponding dimension.
This operation is analogous to slicing in NumPy, but does not permit slices where
the stop indices are less than the start indices.
"""
slice = _make_prim(
schema="slice(Tensor(a) a, SymInt[] start_indices, SymInt[] limit_indices, SymInt[]? strides=None) -> Tensor(a)",
meta=_slice_meta,
impl_aten=_slice_aten,
return_type=RETURN_TYPE.VIEW,
doc=_slice_doc,
)
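# Worked example (illustrative sketch; hypothetical input): for a contiguous 1-D
# tensor a of length 10,
#   slice(a, start_indices=[1], limit_indices=[8], strides=[3])
# selects the same elements as the Python slice a[1:8:3], namely indices 1, 4 and 7.
# The result therefore has length 3 (ceil((8 - 1) / 3)) and stride 1 * 3 = 3.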
def _slice_in_dim_meta(
a: TensorLikeType,
start_index: int,
limit_index: int,
stride: int = 1,
axis: int = 0,
) -> TensorLikeType:
if axis < 0:
msg = "slice_in_dim: received a negative axis {0}".format(axis)
raise ValueError(msg)
if axis >= a.ndim:
msg = "slice_in_dim: axis {0} is greater than or equal to the rank {1} of the tensor".format(
axis, a.ndim
)
raise ValueError(msg)
if start_index < 0:
msg = "slice_in_dim: received a negative start_index {0}".format(start_index)
raise ValueError(msg)
if start_index > a.shape[axis]:
msg = "slice_in_dim: start_index {0} is greater than the length {1} of dimension {2}".format(
start_index, a.shape[axis], axis
)
raise ValueError(msg)
if limit_index > a.shape[axis]:
msg = "slice_in_dim: limit_index {0} is greater than the length {1} of dimension {2}".format(
limit_index, a.shape[axis], axis
)
raise ValueError(msg)
if limit_index < start_index:
msg = "slice_in_dim: received a limit_index {0} less than the start_index {1}".format(
limit_index, start_index
)
raise ValueError(msg)
if stride <= 0:
msg = "slice_in_dim: received a non-positive stride of {0}!".format(stride)
raise ValueError(msg)
start_indices = [0] * a.ndim
limit_indices = list(a.shape)
strides = [1] * a.ndim
start_indices[axis] = start_index
limit_indices[axis] = limit_index
strides[axis] = stride
return _slice_meta(a, start_indices, limit_indices, strides)
def _slice_in_dim_aten(
a: Tensor,
start_index: int,
limit_index: int,
stride: int = 1,
axis: int = 0,
) -> Tensor:
start_indices = [0] * a.ndim
limit_indices = list(a.shape)
strides = [1] * a.ndim
start_indices[axis] = start_index
limit_indices[axis] = limit_index
strides[axis] = stride
return slice(a, start_indices, limit_indices, strides)
_slice_in_dim_doc = """
Convenience wrapper for slicing just one dimension using slice.
"""
# TODO: make stride SymInt
slice_in_dim = _make_prim(
schema="slice_in_dim(Tensor(a) a, SymInt start_index, SymInt limit_index, int stride=1, int axis=0) -> Tensor(a)",
meta=_slice_in_dim_meta,
impl_aten=_slice_in_dim_aten,
return_type=RETURN_TYPE.VIEW,
doc=_slice_in_dim_doc,
)
def _split_dim_meta(a: TensorLikeType, dim: int, outer_length: int) -> TensorLikeType:
assert isinstance(a, TensorLike)
utils.validate_idx(a.ndim, dim)
utils.validate_dim_length(outer_length)
# Verifies the dim can be split with the specified outer_length
inner_length = a.shape[dim] // outer_length
if (a.shape[dim] % outer_length) != 0:
msg = "Attempting to split dimension of length {0}, but outer length of {1} divides it with a remainder!".format(
a.shape[dim], outer_length
)
raise ValueError(msg)
new_shape: List[int] = []
new_strides: List[int] = []
for idx in range(a.ndim):
if idx == dim:
new_shape.extend((outer_length, inner_length))
new_strides.extend((a.stride()[idx] * inner_length, a.stride()[idx]))
else:
new_shape.append(a.shape[idx])
new_strides.append(a.stride()[idx])
return TensorMeta(a, shape=new_shape, strides=new_strides)
def _split_dim_aten(a: Tensor, dim: int, outer_length: int) -> Tensor:
inner_length = a.shape[dim] // outer_length
new_shape = a.shape[0:dim] + (outer_length, inner_length) + a.shape[dim + 1 :]
return a.view(new_shape)
_split_dim_doc = """
Creates a view of a with the given dimension (of length l) split
into two dimensions, with the outer of the two having
length outer_length and the inner of the two having computed
length inner_length such that outer_length * inner_length = l.
"""
# TODO: consider renaming split_dim_view
split_dim = _make_prim(
schema="split_dim(Tensor(a) a, int dim, SymInt outer_length) -> Tensor(a)",
meta=_split_dim_meta,
impl_aten=_split_dim_aten,
return_type=RETURN_TYPE.VIEW,
doc=_split_dim_doc,
)
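# Worked example (illustrative sketch; hypothetical input traced by hand through
# _split_dim_meta): a contiguous tensor of shape (6, 5) has strides (5, 1).
# split_dim(a, dim=0, outer_length=2) computes inner_length = 6 // 2 = 3 and returns
# shape (2, 3, 5) with strides (5 * 3, 5, 1) = (15, 5, 1), again contiguous strides.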
# Note: allows dimensions to be specified redundantly
def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType:
assert isinstance(a, TensorLike)
for idx in dimensions:
utils.validate_idx(a.ndim, idx)
assert a.shape[idx] == 1
new_shape = []
new_strides = []
for idx in range(len(a.shape)):
if idx in dimensions:
continue
new_shape.append(a.shape[idx])
new_strides.append(a.stride()[idx])
return TensorMeta(a, shape=new_shape, strides=new_strides)
def _squeeze_aten(a: Tensor, dimensions: Sequence) -> Tensor:
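# Deduplicates dimensions (the meta allows redundant entries) and squeezes from
# the highest dimension down so the remaining indices stay valid
for idx in sorted(set(dimensions), reverse=True):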
a = torch.squeeze(a, dim=idx)
return a
_squeeze_doc = """
Creates a view of the tensor with the specified dimensions removed.
The removed dimensions must each have length one.
"""
squeeze = _make_prim(
schema="squeeze(Tensor(a) a, int[] dimensions) -> Tensor(a)",
meta=_squeeze_meta,
impl_aten=_squeeze_aten,
return_type=RETURN_TYPE.VIEW,
doc=_squeeze_doc,
)
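# Worked example (illustrative sketch; hypothetical input): a contiguous tensor of
# shape (1, 3, 1) has strides (3, 1, 1); squeeze(a, [0, 2]) drops both length-one
# dimensions and returns a view of shape (3,) with stride (1,).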
def _transpose_meta(a: TensorLikeType, permutation: DimsSequenceType) -> TensorLikeType:
if a.ndim != len(permutation):
msg = "Attempting to permute a tensor of rank {0}, but received a permutation of length {1}!".format(
a.ndim, len(permutation)
)
raise ValueError(msg)
if not utils.is_valid_permutation(a.ndim, permutation):
msg = "Received an invalid permutation, {0}!".format(permutation)
raise ValueError(msg)
new_shape = [0] * a.ndim
new_strides = [0] * a.ndim
for idx, dim in enumerate(permutation):
new_shape[idx] = a.shape[dim]
new_strides[idx] = a.stride()[dim]
return TensorMeta(a, shape=tuple(new_shape), strides=tuple(new_strides))
def _transpose_aten(a: Tensor, permutation: DimsSequenceType) -> Tensor:
return torch.permute(a, permutation)
_transpose_doc = """
Creates a view of the tensor with its dimensions permuted.
The length of the permutation must be the rank of the tensor,
and element i of the permutation gives the input dimension
that becomes dimension i of the result.
"""
transpose = _make_prim(
schema="transpose(Tensor(a) a, int[] permutation) -> Tensor(a)",
meta=_transpose_meta,
impl_aten=_transpose_aten,
return_type=RETURN_TYPE.VIEW,
doc=_transpose_doc,
)
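# Worked example (illustrative sketch; hypothetical input): for a contiguous tensor of
# shape (2, 3) with strides (3, 1), transpose(a, [1, 0]) returns a view of shape
# (3, 2) with strides (1, 3); no data is moved.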
def _view_of_meta(a: TensorLikeType) -> TensorLikeType:
return TensorMeta(a)
def _view_of_aten(a: Tensor) -> Tensor:
return a.view(a.shape)
_view_of_doc = """
Creates a view of the tensor.
"""
view_of = _make_prim(
schema="view_of(Tensor(a) a) -> Tensor",
meta=_view_of_meta,
impl_aten=_view_of_aten,
return_type=RETURN_TYPE.VIEW,
doc=_view_of_doc,
)
#
# Shape operations
#
def collapse(a: Tensor, start: int, end: int) -> Tensor:
"""
Wrapper around reshape that collapses a span of dimensions.
See collapse_view for the corresponding view operation.
"""
dim_length = 1
for idx in range(start, end):
dim_length = dim_length * a.shape[idx]
new_shape = a.shape[0:start] + (dim_length,) + a.shape[end:]
return reshape(a, new_shape)
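# Worked example (illustrative sketch; hypothetical input): collapse(a, 0, 2) on a
# tensor of shape (2, 3, 4) multiplies the lengths of dimensions 0 and 1 and reshapes
# to (6, 4). Because it goes through the reshape prim, it returns a contiguous copy
# even when collapse_view would have succeeded.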
# TODO: review stride logic
def _cat_meta(tensors: Sequence[TensorLikeType], dim: int) -> TensorLikeType:
# Verifies same shape (except in the concat dimension)
shape = tensors[0].shape
concat_length = 0
for tensor_idx, tensor in enumerate(tensors):
for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
if idx == dim:
concat_length = concat_length + length
elif length != common_length:
raise RuntimeError(
f"Sizes of tensors must match except in dimension {dim}. "
f"Expected {common_length} but got {length} for tensor number "
f"{tensor_idx} in the list"
)
new_shape = list(tensors[0].shape).copy()
new_shape[dim] = concat_length
return TensorMeta(
tensors[0],
shape=new_shape,
strides=utils.make_contiguous_strides_for(new_shape),
)
def _cat_aten(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: int) -> Tensor:
return torch.cat(tensors, dim)
_cat_doc = """
Concatenates tensors along the specified dimension.
The tensors' shapes must have the same rank and same length for other dimensions.
"""
cat = _make_prim(
schema="cat(Tensor[] tensors, int dim) -> Tensor",
meta=_cat_meta,
impl_aten=_cat_aten,
return_type=RETURN_TYPE.NEW,
doc=_cat_doc,
)
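# Worked example (illustrative sketch; hypothetical inputs traced by hand through
# _cat_meta): concatenating tensors of shapes (2, 3) and (4, 3) along dim=0 sums the
# lengths in that dimension, giving a result of shape (6, 3) with contiguous strides
# (3, 1). Shapes (2, 3) and (2, 4) would raise a RuntimeError because they differ
# outside dimension 0.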
def _reshape_meta(a: TensorLikeType, shape: ShapeType):
assert isinstance(a, TensorLike)
utils.validate_shape(shape)
# Validates the tensor and the requested shape have the
# same number of elements
numel = reduce(operator.mul, shape, 1)
if numel != a.numel():
msg = "Attempting to reshape a tensor with {0} elements to a shape with {1} elements!".format(
a.numel(), numel
)
raise ValueError(msg)
return TensorMeta(a, shape=shape, strides=utils.make_contiguous_strides_for(shape))
def _reshape_aten(a: Tensor, shape: ShapeType) -> Tensor:
return a.reshape(shape).contiguous().clone()
_reshape_doc = """
Creates a contiguous tensor with the specified shape
containing a copy of the data in a.
"""
reshape = _make_prim(
schema="reshape(Tensor a, SymInt[] shape) -> Tensor",
meta=_reshape_meta,
impl_aten=_reshape_aten,
return_type=RETURN_TYPE.NEW,
doc=_reshape_doc,
)
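# Worked example (illustrative sketch; hypothetical input): a tensor of shape (2, 6)
# can be reshaped to (3, 4) because both describe 12 elements; the result is a
# contiguous copy with strides (4, 1). Requesting (5, 3), which has 15 elements,
# raises a ValueError.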
def _rev_meta(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
utils.validate_dimension_indices(a.ndim, dims)
return TensorMeta(a)
_rev_doc = """
Reverses the order of elements along the given dimensions.
"""
rev = _make_prim(
schema="rev(Tensor a, int[] dims) -> Tensor",
meta=_rev_meta,
impl_aten=torch.flip,
return_type=RETURN_TYPE.NEW,
doc=_rev_doc,
)
#
# Conditional prims
#
def _where_meta(
pred: TensorLikeType, a: TensorLikeType, b: TensorLikeType
) -> TensorLikeType:
return _elementwise_meta(
a,
b,
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
args_with_fixed_dtypes=(pred,),
)
_where_doc = """
Selects elements from a and b according to pred.
Where pred is true the result contains the element from a, and
where pred is false the result contains the element from b.
"""
where = _make_prim(
schema="where(Tensor pred, Tensor a, Tensor b) -> Tensor",
meta=_where_meta,
impl_aten=torch.where,
return_type=RETURN_TYPE.NEW,
doc=_where_doc,
)
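# Worked example (illustrative sketch; hypothetical inputs): with
# pred = torch.tensor([True, False, True]), a = torch.tensor([1, 2, 3]) and
# b = torch.tensor([10, 20, 30]), where(pred, a, b) selects the values [1, 20, 3].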
#
# Type conversions
#
def _convert_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
# Type checks
assert isinstance(a, TensorLike)
assert isinstance(dtype, torch.dtype)
strides = utils.compute_elementwise_output_strides(a)
return TensorMeta(a, strides=strides, dtype=dtype)
def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
# Propagates requires grad when possible
if not utils.is_grad_dtype(dtype):
requires_grad = False
else:
# TODO: update meta objects so this can be acquired directly
try:
requires_grad = a.requires_grad
except Exception:
requires_grad = False
result = torch.empty_like(
a, device=a.device, dtype=dtype, requires_grad=requires_grad
)
with torch.no_grad():
return copy_to(result, a)
_convert_element_type_doc = """
Creates a copy of a tensor with the given dtype.
"""
convert_element_type = _make_prim(
schema="convert_element_type(Tensor a, ScalarType dtype) -> Tensor",
meta=_convert_element_type_meta,
impl_aten=_convert_element_type_aten,
return_type=RETURN_TYPE.NEW,
doc=_convert_element_type_doc,
)
def _device_put_meta(
a: TensorLikeType, device: Union[str, torch.device]
) -> TensorLikeType:
assert isinstance(a, TensorLike)
assert isinstance(device, (str, torch.device))
return TensorMeta(a, device=utils.canonicalize_device(device))
def _device_put_aten(a: Tensor, device: Union[str, torch.device]) -> Tensor:
return a.to(device)
_device_put_doc = """
Creates a copy of a tensor on the given device.
"""
device_put = _make_prim(
schema="device_put(Tensor a, Device device) -> Tensor",
meta=_device_put_meta,
impl_aten=_device_put_aten,
return_type=RETURN_TYPE.NEW,
doc=_device_put_doc,
)
# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _item_meta(a: TensorLikeType) -> FakeTensor:
number_type = utils.dtype_to_type(a.dtype)
return TensorMeta(number_type(-1))
_item_doc = """
Converts a tensor with one element to a Python number.
"""
# TODO: create a new return type for scalars?
# FIXME: currently returns integers for boolean tensors
# https://github.com/pytorch/pytorch/issues/78071
item = _make_prim(
schema="item(Tensor a) -> Scalar",
meta=_item_meta,
impl_aten=torch.Tensor.item,
return_type=RETURN_TYPE.NEW,
doc=_item_doc,
)
# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _maximum_value_meta(dtype: torch.dtype) -> FakeTensor:
number_type = utils.dtype_to_type(dtype)
return TensorMeta(number_type(-1))
def _maximum_value_aten(dtype: torch.dtype):
if dtype == torch.bool:
return True
elif dtype.is_complex or dtype.is_floating_point:
return torch.finfo(dtype).max
else:
return torch.iinfo(dtype).max
_maximum_value_doc = """
Return the maximum finite value for a dtype.
"""
# TODO: create a new return type for scalars?
# FIXME: currently returns integers for boolean tensors
# https://github.com/pytorch/pytorch/issues/78071
maximum_value = _make_prim(
schema="maximum_value(ScalarType dtype) -> Scalar",
meta=_maximum_value_meta,
impl_aten=_maximum_value_aten,
return_type=RETURN_TYPE.NEW,
doc=_maximum_value_doc,
)
# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _minimum_value_meta(dtype: torch.dtype) -> FakeTensor:
number_type = utils.dtype_to_type(dtype)
return TensorMeta(number_type(-1))
def _minimum_value_aten(dtype: torch.dtype):
if dtype == torch.bool:
return False
elif dtype.is_complex or dtype.is_floating_point:
return torch.finfo(dtype).min
else:
return torch.iinfo(dtype).min
_minimum_value_doc = """
Return the minimum finite value for a dtype.
"""
# TODO: create a new return type for scalars?
# FIXME: currently returns integers for boolean tensors
# https://github.com/pytorch/pytorch/issues/78071
minimum_value = _make_prim(
schema="minimum_value(ScalarType dtype) -> Scalar",
meta=_minimum_value_meta,
impl_aten=_minimum_value_aten,
return_type=RETURN_TYPE.NEW,
doc=_minimum_value_doc,
)
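# Worked example (illustrative sketch; traced through _maximum_value_aten and
# _minimum_value_aten): for torch.int8 they return torch.iinfo(torch.int8).max == 127
# and torch.iinfo(torch.int8).min == -128; for torch.bool they return True and False;
# for floating-point and complex dtypes they return torch.finfo(dtype).max and
# torch.finfo(dtype).min, the largest and most negative finite values.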
# TODO: FIXME: strides are incorrect
def _to_dtype_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
strides = utils.make_contiguous_strides_for(a.shape)
return TensorMeta(a, strides=strides, dtype=dtype)
def _to_dtype_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
return a.to(dtype)
_to_dtype_doc = """
Creates a contiguous copy of a tensor with the given dtype.
"""
to_dtype = _make_prim(
schema=("to_dtype(Tensor a, ScalarType dtype) -> Tensor"),
meta=_to_dtype_meta,
impl_aten=_to_dtype_aten,
return_type=RETURN_TYPE.NEW,
doc=_to_dtype_doc,
)
#
# Inplace operators
#
def _copy_to_meta(a: TensorLikeType, b: TensorLikeType):
assert isinstance(a, TensorLike)
assert isinstance(b, TensorLike)
# Validates the cast is safe
# TODO: move this as an option on the reference
# a_typ = utils.dtype_to_type(a.dtype)
# b_typ = utils.dtype_to_type(b.dtype)
# if a_typ is not utils.get_higher_type(a_typ, b_typ):
# raise RuntimeError(str(b.dtype), " can't be cast safely to ", str(a.dtype), "!")
# Validates the tensors have the same number of elements
if a.numel() != b.numel():
msg = "Attempting to copy {0} elements to a tensor with {1} elements!".format(
b.numel(), a.numel()
)
raise RuntimeError(msg)
return a
def _copy_to_aten(a: Tensor, b: Tensor) -> Tensor:
return a.copy_(b)
_copy_to_doc = """
Copies the data in b to a and returns the modified a.
"""
# TODO: Remove safe casting and implement on reference instead
copy_to = _make_prim(
schema="copy_to(Tensor(a!) a, Tensor b) -> Tensor(a!)",
meta=_copy_to_meta,
impl_aten=_copy_to_aten,
return_type=RETURN_TYPE.INPLACE,
doc=_copy_to_doc,
)
def _resize_meta(a: TensorLikeType, shape: ShapeType):
return a.resize_(shape)
def _resize_aten(a: Tensor, shape: ShapeType) -> Tensor:
return a.resize_(shape)
_resize_doc = """
Gives a tensor with no elements a new shape, returning the modified tensor.
The tensor's strides are contiguous and its values are uninitialized.
"""
# TODO: review support for arbitrary resizes
resize = _make_prim(
schema="resize(Tensor(a!) a, SymInt[] shape) -> Tensor(a!)",
meta=_resize_meta,
impl_aten=_resize_aten,
return_type=RETURN_TYPE.INPLACE,
doc=_resize_doc,
)
def _reduction_meta(inp, dims, *, output_dtype=None):
"""
Meta function for single-output reduction operations.
NOTE: the stride logic is incorrect; contiguous strides are always returned.
"""
assert isinstance(inp, TensorLike)
if output_dtype is None:
output_dtype = inp.dtype
output_shape = utils.compute_reduction_output_shape(inp.shape, dims)
return TensorMeta(
shape=output_shape,
strides=utils.make_contiguous_strides_for(output_shape),
dtype=output_dtype,
device=inp.device,
)
def _var_reduction_meta(inp, dims, *, correction):
if utils.is_complex_dtype(inp.dtype):
output_dtype = utils.corresponding_real_dtype(inp.dtype)
else:
output_dtype = inp.dtype
return _reduction_meta(inp, dims, output_dtype=output_dtype)
_sum_doc = """
Computes the sum of elements in the input tensor over the list of dimensions
specified in the dims argument.
"""
_prod_doc = """
Computes the product of elements in the input tensor over the list of dimensions
specified in the dims argument.
"""
_amax_doc = """
Computes the maximum value of elements in the input tensor over the list of dimensions
specified in the dims argument.
"""
_amin_doc = """
Computes the minimum value of elements in the input tensor over the list of dimensions
specified in the dims argument.
"""
_var_doc = """
Computes the variance of elements in the input tensor over the list of dimensions
specified in the dims argument, dividing by N - correction where N is the number of
elements reduced (correction=0 gives the biased variance).
"""
def _make_reduction_prim(name: str, impl_aten, doc):
    """Creates a reduction prim."""
    return _make_prim(
        schema=f"{name}(Tensor inp, int[]? dims, *, ScalarType? output_dtype=None) -> Tensor",
        meta=_reduction_meta,
        impl_aten=impl_aten,
        return_type=RETURN_TYPE.NEW,
        doc=doc,
    )


def _make_var_reduction_prim(name: str, impl_aten, doc):
    """Creates a variance reduction prim, which also takes a correction argument."""
    return _make_prim(
        schema=f"{name}(Tensor inp, int[]? dims, *, int correction, ScalarType? output_dtype=None) -> Tensor",
        meta=_var_reduction_meta,
        impl_aten=impl_aten,
        return_type=RETURN_TYPE.NEW,
        doc=doc,
    )
sum = _make_reduction_prim(
    name="sum",
    impl_aten=torch.sum,
    doc=_sum_doc,
)


def _prod_aten(
    inp: TensorLikeType,
    dims: Optional[DimsSequenceType],
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    if dims is not None:
        for d in sorted(dims, reverse=True):
            assert d >= 0
            inp = torch.prod(inp, d, dtype=dtype)
        return inp
    else:
        return torch.prod(inp, dims, dtype=dtype)
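# Illustrative sketch (hypothetical helper, tracks shapes only): why
# _prod_aten reduces `dims` in descending order. Each single-dimension
# torch.prod call removes one dimension, so removing the highest index first
# keeps the indices of the dimensions still to be reduced valid.
from typing import List, Sequence, Tuple


def _sketch_sequential_reduce_shape(
    shape: Sequence[int], dims: Sequence[int]
) -> Tuple[int, ...]:
    remaining: List[int] = list(shape)
    for d in sorted(dims, reverse=True):
        assert d >= 0
        # Deleting dimension d only shifts dimensions after d, and those have
        # already been reduced.
        del remaining[d]
    return tuple(remaining)


# Reducing in ascending order instead would break: after removing dim 0 of a
# (2, 3, 4) shape, the original dim 2 is now at index 1.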
prod = _make_reduction_prim(
    name="prod",
    impl_aten=_prod_aten,
    doc=_prod_doc,
)

var = _make_var_reduction_prim(
    name="var",
    impl_aten=torch.var,
    doc=_var_doc,
)

amax = _make_reduction_prim(
    name="amax",
    impl_aten=torch.amax,
    doc=_amax_doc,
)

amin = _make_reduction_prim(
    name="amin",
    impl_aten=torch.amin,
    doc=_amin_doc,
)
_arange_doc = """
Constructs a 1-D tensor with values from the interval [start, end) taken
with common difference `step` beginning from `start`.
"""


# TODO: layout, pin_memory, memory_format
# TODO: model requires_grad on TensorMeta
def _arange_meta(
    start: NumberType,
    end: NumberType,
    step: NumberType,
    *,
    dtype: Optional[torch.dtype],
    device: Optional[torch.device],
    requires_grad: bool,
) -> TensorLikeType:
    assert not (
        isinstance(start, complex)
        and isinstance(end, complex)
        and isinstance(step, complex)
    )
    utils.check(
        step != 0,
        lambda: "step must be nonzero",
    )
    utils.check(
        math.isfinite(start) and math.isfinite(end),
        lambda: f"unsupported range: {start} -> {end}",
    )
    utils.check(
        (step > 0 and end >= start) or (step < 0 and end <= start),
        lambda: "upper bound and lower bound inconsistent with step sign",
    )
    if dtype is not None:
        pass
    elif all(isinstance(arg, int) for arg in (start, end, step)):
        dtype = torch.int64
    else:
        dtype = torch.get_default_dtype()
    device = _get_default_device() if device is None else device
    shape = (math.ceil((end - start) / step),)
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
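# Illustrative sketch (hypothetical helper, pure Python): the length and
# dtype-inference rules applied by _arange_meta. The number of values in
# [start, end) with common difference `step` is ceil((end - start) / step).
# Dtypes are plain strings here, and "float32" stands in for
# torch.get_default_dtype(), assuming the default dtype has not been changed.
import math
from typing import Tuple, Union

_SketchNumber = Union[int, float]


def _sketch_arange_length_and_dtype(
    start: _SketchNumber, end: _SketchNumber, step: _SketchNumber
) -> Tuple[int, str]:
    if step == 0:
        raise ValueError("step must be nonzero")
    if not ((step > 0 and end >= start) or (step < 0 and end <= start)):
        raise ValueError("upper bound and lower bound inconsistent with step sign")
    length = math.ceil((end - start) / step)
    # All-integer arguments give int64; anything else uses the default dtype.
    if all(isinstance(v, int) for v in (start, end, step)):
        return length, "int64"
    return length, "float32"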
def _arange_aten(
    start: NumberType,
    end: NumberType,
    step: NumberType,
    *,
    dtype: Optional[torch.dtype],
    device: Optional[torch.device],
    requires_grad: bool,
) -> TensorLikeType:
    # mypy: Not all union combinations were tried because there are too many unions
    return torch.arange(  # type: ignore[call-overload, misc]
        start,  # type: ignore[arg-type]
        end,  # type: ignore[arg-type]
        step,  # type: ignore[arg-type]
        dtype=dtype,
        device=device,
        layout=torch.strided,
        pin_memory=False,
        requires_grad=requires_grad,
    )


# TODO: maybe prims should not have requires_grad arg
# see: https://github.com/pytorch/pytorch/pull/77542/files#r873943255
arange = _make_prim(
    schema="arange(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype, Device? device, bool requires_grad) -> Tensor",  # noqa: B950
    return_type=RETURN_TYPE.NEW,
    meta=_arange_meta,
    impl_aten=_arange_aten,
    doc=_arange_doc,
)
# TODO: layout, pin_memory, memory_format
# TODO: model requires_grad on TensorMeta
def _empty_meta(
    shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool
) -> TensorLikeType:
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)


def _empty_aten(
    shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool
) -> Tensor:
    return torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad)


_empty_doc = """
Creates a tensor with uninitialized values and the specified shape, dtype, and device.
"""

empty = _make_prim(
    schema="empty(SymInt[] shape, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    meta=_empty_meta,
    impl_aten=_empty_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_empty_doc,
)
def _empty_strided_meta(
    shape: ShapeType,
    strides: StrideType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)


_empty_strided_doc = """
Creates a tensor with uninitialized values and the specified shape, strides, dtype, and device.
"""

# TODO: add layout, pin_memory
empty_strided = _make_prim(
    schema="empty_strided(SymInt[] shape, SymInt[] strides, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=_empty_strided_meta,
    impl_aten=torch.empty_strided,
    doc=_empty_strided_doc,
)
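# Illustrative sketch (hypothetical helpers, pure Python): how the `strides`
# given to empty_strided map an index to a storage offset, and the minimum
# number of storage elements a tensor with that shape and those strides needs.
# Assumes non-negative strides and a storage offset of zero.
from typing import Sequence


def _sketch_storage_offset(index: Sequence[int], strides: Sequence[int]) -> int:
    offset = 0
    for i, stride in zip(index, strides):
        offset += i * stride
    return offset


def _sketch_required_storage_size(
    shape: Sequence[int], strides: Sequence[int]
) -> int:
    if 0 in shape:
        return 0
    # The last element lives at the offset of index (size - 1, ...), so one
    # more storage slot than that offset is required.
    last = 0
    for size, stride in zip(shape, strides):
        last += (size - 1) * stride
    return last + 1


# Row-major (3, 1) and column-major (1, 2) strides for a (2, 3) shape both need
# 6 elements; a stride of 0 (as in an expanded tensor) needs fewer.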
def _full_meta(
    shape: ShapeType,
    fill_value: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)


def _full_aten(
    shape: ShapeType,
    fill_value: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> Tensor:
    # Note that Mypy thinks torch.full can't accept a complex fill_value
    return torch.full(
        shape, fill_value, dtype=dtype, device=device, requires_grad=requires_grad  # type: ignore[arg-type]
    )


_full_doc = """
Creates a tensor filled with the given fill value, and with the specified shape, dtype, and device.
"""

# TODO: add layout
full = _make_prim(
    schema="full(SymInt[] shape, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    meta=_full_meta,
    impl_aten=_full_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_full_doc,
)
def _full_like_meta(
    a: TensorLikeType,
    fill_value: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    strides = utils.compute_elementwise_output_strides(a)
    if a.numel() == 0:
        strides = a.stride()
    return TensorMeta(a, strides=strides, dtype=dtype, device=device)
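# Illustrative sketch (hypothetical helper): a simplified, single-tensor
# version of the idea behind utils.compute_elementwise_output_strides, which
# _full_like_meta uses so that the output keeps the input's memory layout.
# Dimensions are ordered from innermost (smallest stride) to outermost, ties
# are broken by treating the later dimension as more inner, and dense strides
# are assigned in that order. The real utility handles more cases (for
# example size-1 dimensions) than this sketch does.
from typing import Sequence, Tuple


def _sketch_dense_strides_like(
    shape: Sequence[int], strides: Sequence[int]
) -> Tuple[int, ...]:
    ndim = len(shape)
    # Innermost dimension first.
    order = sorted(range(ndim), key=lambda d: (strides[d], -d))
    result = [0] * ndim
    multiplier = 1
    for d in order:
        result[d] = multiplier
        multiplier *= max(shape[d], 1)
    return tuple(result)


# A transposed (3, 2) input with strides (1, 3) keeps that layout, while a
# non-dense (2, 2) slice with strides (4, 1) is compacted to (2, 1).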
def _full_like_aten(
    a: Tensor,
    fill_value: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> Tensor:
    # Note that Mypy thinks torch.full_like can't accept a complex fill_value
    return torch.full_like(
        a, fill_value, dtype=dtype, device=device, requires_grad=requires_grad  # type: ignore[arg-type]
    )


_full_like_doc = """
Creates a tensor filled with the given fill value and with the same shape, dtype, and device as
the given tensor by default. The dtype and device can be overridden by specifying them
explicitly.
"""

full_like = _make_prim(
    schema="full_like(Tensor a, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    meta=_full_like_meta,
    impl_aten=_full_like_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_full_like_doc,
)
def _scalar_tensor_meta(
    scalar: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
) -> TensorLikeType:
    shape: ShapeType = []
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(scalar, shape=shape, strides=strides, dtype=dtype, device=device)


def _scalar_tensor_aten(
    scalar: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
) -> Tensor:
    if isinstance(scalar, complex) and (
        dtype is None or not utils.is_complex_dtype(dtype)
    ):
        raise TypeError("Complex scalar requires complex tensor dtype.")
    # Note that Mypy thinks torch.scalar_tensor can't accept a complex scalar
    return torch.scalar_tensor(scalar, dtype=dtype, device=device)  # type: ignore[arg-type]


_scalar_tensor_doc = """
Wraps a Number into a Tensor with the specified dtype and device.
"""

# TODO: add layout and pin_memory support
scalar_tensor = _make_prim(
    schema="scalar_tensor(Scalar s, *, ScalarType? dtype=None, Device? device=None) -> Tensor",
    meta=_scalar_tensor_meta,
    impl_aten=_scalar_tensor_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_scalar_tensor_doc,
)
#
# Linear algebra (linalg) prims
#


def _svd_meta(
    A: TensorLikeType, *, full_matrices: bool
) -> Tuple[TensorLikeType, TensorLikeType, TensorLikeType]:
    utils.check_is_matrix(A, "linalg.svd")
    utils.check_fp_or_complex(A.dtype, "linalg.svd", allow_low_precision_dtypes=False)

    A_shape = A.shape
    batch = A_shape[:-2]
    m, n = A_shape[-2:]
    k = min(m, n)

    shape_U = batch + (m, m if full_matrices else k)
    strides_U = utils.make_contiguous_strides_for(shape_U, row_major=False)
    U = TensorMeta(shape=shape_U, strides=strides_U, dtype=A.dtype, device=A.device)

    shape_S = batch + (k,)
    strides_S = utils.make_contiguous_strides_for(shape_S)
    S = TensorMeta(
        shape=shape_S,
        strides=strides_S,
        dtype=utils.corresponding_real_dtype(A.dtype) if A.is_complex() else A.dtype,
        device=A.device,
    )

    shape_Vh = batch + (n if full_matrices else k, n)
    # The CPU backend returns V, but the cuSolver backend returns V^H
    # TODO The MAGMA backend returns V, so this is wrong if used with the MAGMA backend
    is_cuda = A.device.type == "cuda"
    strides_Vh = utils.make_contiguous_strides_for(shape_Vh, row_major=is_cuda)
    Vh = TensorMeta(shape=shape_Vh, strides=strides_Vh, dtype=A.dtype, device=A.device)
    return U, S, Vh
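# Illustrative sketch (hypothetical helpers, pure Python): the output shapes
# _svd_meta computes for an input of shape (*batch, m, n) with k = min(m, n),
# and the column-major strides it requests for U via row_major=False. The
# column-major rule below is an assumption about
# utils.make_contiguous_strides_for: start from row-major strides and make
# only the last two (matrix) dimensions column-major.
from typing import Sequence, Tuple

_SketchShape = Tuple[int, ...]


def _sketch_svd_shapes(
    shape: Sequence[int], full_matrices: bool
) -> Tuple[_SketchShape, _SketchShape, _SketchShape]:
    batch = tuple(shape[:-2])
    m, n = shape[-2], shape[-1]
    k = min(m, n)
    shape_U = batch + (m, m if full_matrices else k)
    shape_S = batch + (k,)
    shape_Vh = batch + (n if full_matrices else k, n)
    return shape_U, shape_S, shape_Vh


def _sketch_column_major_strides(shape: Sequence[int]) -> _SketchShape:
    strides = []
    multiplier = 1
    for size in reversed(shape):
        strides.append(multiplier)
        multiplier *= max(size, 1)
    row_major = tuple(reversed(strides))
    if len(shape) < 2:
        return row_major
    # Within each matrix: stride 1 down a column, (clamped) row count across a row.
    return row_major[:-2] + (1, max(shape[-2], 1))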
def _svd_aten(
    A: TensorLikeType, *, full_matrices: bool
) -> Tuple[Tensor, Tensor, Tensor]:
    return torch.linalg.svd(A, full_matrices=full_matrices)


_svd_doc = """
Returns the SVD of a matrix or batch of matrices.
The `full_matrices` flag controls whether the full or reduced SVD decomposition is returned.
"""

svd = _make_prim(
    schema="svd(Tensor A, *, bool full_matrices) -> (Tensor U, Tensor S, Tensor Vh)",
    meta=_svd_meta,
    impl_aten=_svd_aten,
    return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW, RETURN_TYPE.NEW),
    doc=_svd_doc,
)
#
# Randomness Prims
#


# TODO: add generator support
# NOTE: there is currently no way of acquiring the "default" torch generator
def _normal_meta(
    shape: ShapeType,
    *,
    mean: Union[float, complex],
    std: float,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    utils.check(
        std >= 0.0,
        lambda: f"expected non-negative standard deviation, but got std={std}",
    )
    utils.check(
        utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
        lambda: f"expected a floating-point or complex dtype, but got dtype={dtype}",
    )
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)


def _normal_aten(
    shape: ShapeType,
    *,
    mean: Union[float, complex],
    std: float,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> Tensor:
    a = torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad)
    with torch.no_grad():
        # NOTE: normal_ is incorrectly annotated to expect mean to be a float
        a.normal_(mean, std)  # type: ignore[arg-type]
    return a


_normal_doc = """
Constructs a tensor filled with values drawn from a normal distribution with the specified mean
and standard deviation.
Only supports floating-point and complex dtypes.
"""
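# Illustrative sketch (hypothetical helper, pure Python, real-valued only):
# the standard-deviation check _normal_meta performs, followed by filling a
# flat buffer with normally distributed values, roughly what _normal_aten does
# with Tensor.normal_. Python's random.Random stands in for a torch generator
# and is not how the prim samples.
import random
from typing import List, Optional


def _sketch_normal_samples(
    numel: int, *, mean: float, std: float, seed: Optional[int] = None
) -> List[float]:
    if std < 0.0:
        raise ValueError(
            f"expected non-negative standard deviation, but got std={std}"
        )
    rng = random.Random(seed)
    return [rng.gauss(mean, std) for _ in range(numel)]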
normal = _make_prim(
    schema=(
        "normal(SymInt[] shape, *, Scalar mean, Scalar std, ScalarType dtype, Device device, bool requires_grad) -> Tensor"
    ),
    return_type=RETURN_TYPE.NEW,
    meta=_normal_meta,
    impl_aten=_normal_aten,
    doc=_normal_doc,
)


def _uniform_meta(
    shape: ShapeType,
    *,
    low: float,
    high: float,
    dtype: torch.dtype,
    device: torch.device,
) -> TensorLikeType:
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)


def _uniform_aten(
    shape: ShapeType,
    *,
    low: float,
    high: float,
    dtype: torch.dtype,
    device: torch.device,
) -> Tensor:
    a = torch.empty(shape, dtype=dtype, device=device)
    a.uniform_(low, high)
    return a


_uniform_doc = """
Constructs a tensor filled with values drawn from a continuous uniform distribution
between low and high.
"""

# TODO: we should more seriously review randomness modeling and prims
uniform = _make_prim(
    schema=(
        "uniform(SymInt[] shape, *, Scalar low, Scalar high, ScalarType dtype, Device device) -> Tensor"
    ),
    return_type=RETURN_TYPE.NEW,
    meta=_uniform_meta,
    impl_aten=_uniform_aten,
    doc=_uniform_doc,
)
#
# FFT prims
#


def _fft_r2c_meta(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    onesided: bool,
) -> TensorLikeType:
    dim = utils.canonicalize_dims(input.ndim, dim)
    utils.validate_no_repeating_dims(dim)

    shape = list(input.shape)
    if onesided:
        last_dim = dim[-1]
        shape[last_dim] = shape[last_dim] // 2 + 1

    dtype = utils.corresponding_complex_dtype(input.dtype)
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device)
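# Illustrative sketch (hypothetical helper, pure Python): the output-shape rule
# in _fft_r2c_meta. Negative dims are wrapped, repeated dims are rejected, and
# with onesided=True only the last transformed dimension shrinks to
# n // 2 + 1, because the remaining outputs of a real-input FFT are complex
# conjugates of the ones kept. The exception types are choices of this sketch.
from typing import List, Sequence


def _sketch_fft_r2c_shape(
    shape: Sequence[int], dim: Sequence[int], onesided: bool
) -> List[int]:
    ndim = len(shape)
    canonical = [d + ndim if d < 0 else d for d in dim]
    for d in canonical:
        if not 0 <= d < ndim:
            raise IndexError(f"dimension out of range: {d}")
    if len(set(canonical)) != len(canonical):
        raise ValueError("dims must not repeat")
    out = list(shape)
    if onesided:
        last = canonical[-1]
        out[last] = out[last] // 2 + 1
    return out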
def _fft_r2c_aten(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    onesided: bool,
) -> TensorLikeType:
    normalization = 0  # No normalization
    return torch._fft_r2c(input, dim, normalization, onesided)


_fft_r2c_doc = """
Performs a real-to-complex Fast Fourier Transform.
"""

fft_r2c = _make_prim(
    schema="fft_r2c(Tensor self, *, int[] dim, bool onesided) -> Tensor",
    meta=_fft_r2c_meta,
    impl_aten=_fft_r2c_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_fft_r2c_doc,
)


def _fft_c2c_meta(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    forward: bool,
) -> TensorLikeType:
    dim = utils.canonicalize_dims(input.ndim, dim)
    utils.validate_no_repeating_dims(dim)

    shape = input.shape
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(
        shape=shape, strides=strides, dtype=input.dtype, device=input.device
    )


def _fft_c2c_aten(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    forward: bool,
) -> TensorLikeType:
    normalization = 0  # No normalization
    return torch._fft_c2c(input, dim, normalization, forward)


_fft_c2c_doc = """
Performs a complex-to-complex Fast Fourier Transform, or its inverse when `forward` is False.
"""

fft_c2c = _make_prim(
    schema="fft_c2c(Tensor self, *, int[] dim, bool forward) -> Tensor",
    meta=_fft_c2c_meta,
    impl_aten=_fft_c2c_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_fft_c2c_doc,
)


def _fft_c2r_meta(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    last_dim_size: int,
) -> TensorLikeType:
    dim = utils.canonicalize_dims(input.ndim, dim)
    utils.validate_no_repeating_dims(dim)

    shape = list(input.shape)
    shape[dim[-1]] = last_dim_size
    dtype = utils.corresponding_real_dtype(input.dtype)
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device)
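# Illustrative sketch (hypothetical helpers, pure Python): why fft_c2r takes an
# explicit last_dim_size instead of inferring it. A onesided real-to-complex
# FFT maps a last dimension of length n to n // 2 + 1, so an even length
# 2 * (m - 1) and the odd length 2 * (m - 1) + 1 both give onesided length m.
# The inverse therefore cannot recover n from the input shape alone.
from typing import List


def _sketch_onesided_length(n: int) -> int:
    return n // 2 + 1


def _sketch_c2r_candidate_sizes(onesided_length: int) -> List[int]:
    # The two real lengths that produce the given onesided length.
    base = 2 * (onesided_length - 1)
    return [base, base + 1]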
def _fft_c2r_aten(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    last_dim_size: int,
) -> TensorLikeType:
    normalization = 0  # No normalization
    return torch._fft_c2r(input, dim, normalization, last_dim_size)


_fft_c2r_doc = """
Performs a complex-to-real inverse Fast Fourier Transform.
"""

fft_c2r = _make_prim(
    schema="fft_c2r(Tensor self, *, int[] dim, SymInt last_dim_size) -> Tensor",
    meta=_fft_c2r_meta,
    impl_aten=_fft_c2r_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_fft_c2r_doc,
)

register_nvprims()