# Defines derivative formulas and Python signatures of methods on Variable
#
# Note about possibly confusing nomenclature: An 'output gradient' is the
# gradient of an output of a forward function. Output gradients are used as
# the inputs to backward functions. `grads` is a vector of output gradients,
# and `grad == grads[0]`, in all the derivative formulas in this file.
# An 'input gradient' is the gradient of an input to a forward function.
# Input gradients are the outputs of backward functions, corresponding to the
# input names included in the derivative formulas defined in this file.
# Also, every time we talk about computing a "gradient", we actually mean
# computing the vector-Jacobian product using the given 'output gradient' as the vector.
#
# Each entry consists of:
# - A 'name', which specifies the ATen name of the function you
# are defining derivatives for, and an argument specification.
# - An optional 'dispatch' entry which can be used to specify
# per-autograd dispatch key derivatives. If this entry is not
# specified, then the gradient entries will be taken as the
# default gradients (i.e. registered for every backward dispatch
# key). (see _test_autograd_multiple_dispatch for an example
# of how to register separate derivatives for different dispatch keys).
# The list of allowed dispatch keys (in addition to 'Default' which
# represents the Autograd alias key) is torchgen/model.py:AUTOGRAD_KEYS.
# - One or more gradient entries, mapping differentiable input
# names to a formula specifying how to compute their gradient.
# Note that a single gradient entry can specify the gradient
# formula for multiple input names, by specifying a key
# "input1, input2" (see atan2 for an example).
# - An argument can be flagged as 'non_differentiable'.
# - Optional entry with key 'output_differentiability' and value a list of the
# same length as the number of outputs from the forward function. The list
# should contain only booleans, specifying whether each output Tensor
# is differentiable.
# If it is not specified for a function that returns multiple elements but
# uses `grad` instead of `grads[idx]`, then all but the first output will
# be marked as non-differentiable.
# If none of the outputs are differentiable, you can also add the function
# name to `gen_variable_type.py`'s `DONT_REQUIRE_DERIVATIVE` list.
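#
# As a hypothetical illustration (`my_op` and `my_op_backward` are made-up
# names, not real ATen functions), an entry for a function returning a
# differentiable `values` output and a non-differentiable `indices` output
# might combine these keys as follows. Because only the first output is
# differentiable, the formula refers to its gradient as `grad`, and the shared
# key "self, other" lets one helper return both input gradients:
#
#   - name: my_op(Tensor self, Tensor other) -> (Tensor values, Tensor indices)
#     output_differentiability: [True, False]
#     self, other: my_op_backward(grad, self, other, grad_input_mask)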
#
# There are two cases for Tensor and TensorList arguments here:
# - If the argument is differentiable, in the sense that a gradient with respect
# to it could exist, you should either:
# - Specify the formula for that gradient
# - Specify not_implemented("function_name") as a formula to say that this is not
# implemented yet (but might be in the future; users can request it on an issue)
# - If the argument is not differentiable, for example because it does not have a
# floating point dtype or because the function is not differentiable with respect
# to it, you should either:
# - Not specify any formula for this argument
# - Specify explicitly that this argument is "non_differentiable". Note that in this case,
# we trust you that this argument will never have requires_grad=True and it will be silently
# ignored if it does.
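#
# For example, a hypothetical entry (`my_gather` is a made-up name) whose
# `self` gradient is not written yet and whose integer `index` argument can
# never require grad might read:
#
#   - name: my_gather(Tensor self, Tensor index) -> Tensor
#     self: not_implemented("my_gather")
#     index: non_differentiable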
#
# If a function has out-of-place and in-place variants, then the derivative
# definition for the in-place variant is optional. It will default to the
# definition for the out-of-place variant. Note that _out variants are never
# differentiable.
#
# Gradient expressions are standard C++ expressions operating on ATen
# variables. In a gradient expression, the following variables are in
# scope:
#
# - 'grad', the gradient of the output (often spelled grad_output
# in Python) which we are going to left-multiply.
#
# When a function returns multiple *differentiable* outputs,
# you can refer to the gradient of each output using 'grads',
# e.g., 'grads[0]', 'grads[1]'.
#
# When a function returns multiple *differentiable* outputs that
# are named, you can refer to the gradient of each output using
# 'grad_{name}', e.g., 'grad_x', 'grad_y'.
#
# When a function returns *one* differentiable output (the
# first output) and some more nondifferentiable outputs,
# you MUST refer to the gradient of the differentiable output with
# 'grad' (this case is special-cased in our code generation).
#
# Note that the number of differentiable outputs can be modified by the
# 'output_differentiability' entry (see above).
#
# Across a differentiable function's derivatives set, it is not
# permitted to mix the use of "grad", "grads", and
# "grad_{name}". You must be consistent for that differentiable
# function.
#
# - Any of the input arguments, tensor or non-tensor, including
# argument names that only appear in Declarations.yaml, e.g. 'output'.
#
# - 'result', representing the result of evaluating the forward
# expression for ATen native function declarations. If the forward
# expression outputs a tuple, use 'resultX' instead to access the
# X-th entry.
#
# - 'grad_input_mask', a std::array<bool, n>, specifies which input
# gradients are actually needed. For example, in the entry
# `input0, input1: foo(grad_input_mask)`, `grad_input_mask` is a size
# two array, where `grad_input_mask[0]` is true if `input0` requires
# grad, and `grad_input_mask[1]` is true if `input1` requires grad.
#
# (NB: if your function computes gradient for a list of tensors,
# the `grad_input_mask` will only have a single entry for the list
# specifying whether at least one tensor from the list requires
# grad. If we want to support more fine-grained signalling,
# we'll need some alternate variable which is not a std::array)
#
# - 'retain_variables', a bool which is true if a user has specified
# that saved variables should be retained in case backward is
# run again later. This allows an optimization where we can
# destroy saved buffers if we know variables are not going to be retained,
# e.g., it is used by _cudnn_rnn
#
# If you need a complex expression, e.g., with local variables,
# write a _backward function in torch/csrc/autograd/FunctionsManual.cpp
# and invoke it from here. By the way, go read
# https://github.com/zdevito/ATen/issues/163; this describes an
# important hazard that occurs when porting backwards from Python to C++
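#
# A minimal sketch of such a helper, assuming real-valued inputs of the same
# shape (complex inputs would need .conj() and broadcasting would need a
# reduction back to each input's size); `my_mul` and `my_mul_backward` are
# hypothetical names. The entry
#   self, other: my_mul_backward(grad, self, other, grad_input_mask)
# could call a C++ function that skips the gradients nobody needs:
#
#   // Hypothetical backward for y = self * other (real, same-shape inputs).
#   std::tuple<Tensor, Tensor> my_mul_backward(
#       const Tensor& grad, const Tensor& self, const Tensor& other,
#       std::array<bool, 2> grad_input_mask) {
#     if (!grad.defined()) {
#       return std::make_tuple(Tensor(), Tensor());  // zero in, zeros out
#     }
#     Tensor grad_self, grad_other;  // default-constructed: undefined
#     if (grad_input_mask[0]) {
#       grad_self = grad * other;    // d(self * other)/d(self) = other
#     }
#     if (grad_input_mask[1]) {
#       grad_other = grad * self;    // d(self * other)/d(other) = self
#     }
#     return std::make_tuple(grad_self, grad_other);
#   }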
#
# Double backwards gradient expressions can be somewhat confusing;
# the most important thing to remember is: (1) you need to define a
# derivative formula for every input, including inputs named things
# like 'grad_output', and (2) the gradient to multiply with is always
# called 'grad' (even though it really is a grad-grad).
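#
# For instance, assuming a hypothetical backward function
# `my_sin_backward(grad_output, self)` that computes grad_output * self.cos()
# for real inputs, its double-backward entry needs a formula for both inputs,
# and `grad` there is the incoming grad-grad:
#
#   - name: my_sin_backward(Tensor grad_output, Tensor self) -> Tensor
#     grad_output: grad * self.cos()
#     self: -grad * grad_output * self.sin()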
#
# You can also add a forward derivative definition by defining a formula for
# a returned value (in general "result" if the name is not specified). This
# formula works the same way as the backward one and advanced implementations
# should also be placed in the FunctionsManual file.
# This formula should compute a single Jacobian vector product using the (primal)
# value of the argument "foo_p", its forward grad "foo_t" and the result of the
# function as "result".
# Note that the forward derivative can be automatically generated in two cases:
# - if your function is linear (NOT affine or multi-linear), then you can
# specify so by just using the string "auto_linear" for the formula.
# - if your function is applied element-wise (and has a single input), you
# can specify so by just using the string "auto_element_wise" for the formula.
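#
# As a hypothetical illustration (`my_mul` and `my_scale` are made-up names;
# only the forward formulas are shown, backward formulas omitted), a product
# self * other could spell out the product rule, while a function that is
# linear in its single Tensor input could rely on "auto_linear":
#
#   - name: my_mul(Tensor self, Tensor other) -> Tensor
#     result: self_t * other_p + self_p * other_t
#   - name: my_scale(Tensor self, Scalar factor) -> Tensor
#     result: auto_linear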
#
# Note that to avoid unpacking overhead, functions taking TensorList as inputs
# will always have their forward grad formula called. The formula is responsible
# for checking whether any computation is needed and should return an undefined
# Tensor when there is nothing to do. You can check "cat_jvp" for a full example.
#
# NB: There are a number of gradient definitions in here which are bogus
# (implemented using zeros_like). These gradients are (hopefully) not
# used by our frontend. You MUST check the frontend code; search for
# OpName.apply to see if it's still using a legacy Python style API.
#
# Note: Returning views.
# The following cases exist:
# - If a function returns no view, it can have arbitrary outputs.
# - If a function returns at least one Tensor that is a differentiable view
# of one of its inputs:
# - If there is only one differentiable output, this Tensor is marked as a
# differentiable view. (alias or transpose for example)
# - If there is more than one differentiable output, by default all the views are
# marked as differentiable views and created with allow_rebase_history=false,
# meaning that any inplace operation on them will raise an error. (unbind for example)
#
# Notes about undefined output gradients:
# All backward functions must support all combinations of undefined output
# gradient Tensors, where `grads[i].defined() == false`. Depending on the
# number of input and output grads your derivative formula uses, code
# generation may automatically add some level of undefined grad support,
# according to these three cases:
#
# * 1 input grad and 1 output grad:
# Complete undefined grad support is automatically added, so you
# shouldn't have to think about it, unless there is a bug in the code
# generation.
#
# * 1 input grad and multiple output grads:
# Undefined grad support is automatically added ONLY in the case where
# all output grads are undefined. You will have to add explicit support
# for cases where a subset of output grads is undefined.
#
# * multiple input grads:
# No automatic support, so you will need to add it.
#
# If your derivative formula uses more than one output grad, it is usually
# preferable to add undefined grad support in the backward function itself
# (if you're using one), rather than in the derivative formula in this file.
#
# Undefined Tensors are created with the default constructor `at::Tensor()`.
# It is an efficient way to represent a Tensor filled with zeros because
# the Tensor holds no sizing information and no Storage data is allocated.
# But consequently, Tensor operations cannot be performed on them.
# Therefore, your backward function should treat an undefined output grad as
# a zero, and it needs to be a special case.
#
# If all output grads are undefined, then it should be correct for the
# backward function to return undefined input grads. Since we use the chain
# rule, output grads equal to zero should result in input grads equal to zero,
# unless there is some rare special case.
#
# If a subset of output grads is undefined, then it may be acceptable for
# the backward function to return undefined input grads--it depends on the
# specific function, so you'll have to determine that yourself. If returning
# an undefined Tensor is correct for a given input grad, it is also logically
# correct to return a defined grad full of zeros, but that would not be
# preferable since it would be less efficient.
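#
# A minimal sketch of explicit support, assuming a hypothetical function with
# two outputs, y0 = sin(self) and y1 = cos(self), on real inputs
# (`my_sincos_backward` is a made-up name). Each undefined output grad is
# treated as zero, and the result stays undefined only when both are:
#
#   // Hypothetical backward; grad0 and grad1 may each be undefined.
#   Tensor my_sincos_backward(
#       const Tensor& grad0, const Tensor& grad1, const Tensor& self) {
#     Tensor grad_self;  // undefined until some contribution is added
#     if (grad0.defined()) {
#       grad_self = grad0 * self.cos();        // d sin(x)/dx = cos(x)
#     }
#     if (grad1.defined()) {
#       Tensor contrib = -grad1 * self.sin();  // d cos(x)/dx = -sin(x)
#       grad_self = grad_self.defined() ? grad_self + contrib : contrib;
#     }
#     return grad_self;
#   }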
#
# NB: The parameter names here MUST be consistent with the parameter names
# in Declarations.yaml
- name: abs(Tensor self) -> Tensor
self: grad * self.sgn()
result: handle_r_to_c(result.scalar_type(), self_t.conj() * self_p.sgn())
- name: acos(Tensor self) -> Tensor
self: grad * -((-self * self + 1).rsqrt()).conj()
result: auto_element_wise
- name: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
self: handle_r_to_c(self.scalar_type(), grad)
other: handle_r_to_c(other.scalar_type(), maybe_multiply(grad, alpha.conj()))
result: self_t + maybe_multiply(other_t, alpha)
- name: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
self: handle_r_to_c(self.scalar_type(), grad)
result: self_t.clone()
- name: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
self: maybe_multiply(grad, beta.conj())
batch1: maybe_multiply(grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) }).bmm(batch2.transpose(1, 2).conj()), alpha.conj())
batch2: maybe_multiply(batch1.transpose(1, 2).conj().bmm(grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) })), alpha.conj())
result: maybe_multiply(self_t, beta) + maybe_multiply(batch1_t.bmm(batch2_p).sum(0), alpha) + maybe_multiply(batch1_p.bmm(batch2_t).sum(0), alpha)
- name: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
self: handle_r_to_c(self.scalar_type(), grad)
tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (value / tensor2).conj())
tensor2: handle_r_to_c(tensor2.scalar_type(), -grad * (value * tensor1 / (tensor2 * tensor2)).conj())
result: self_t + maybe_multiply(tensor1_t / tensor2_p, value) - maybe_multiply(tensor2_t * (tensor1_p / tensor2_p) / tensor2_p, value)
- name: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
self: handle_r_to_c(self.scalar_type(), grad)
tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (tensor2 * value).conj())
tensor2: handle_r_to_c(tensor2.scalar_type(), grad * (tensor1 * value).conj())
result: self_t + maybe_multiply(tensor1_t * tensor2_p, value) + maybe_multiply(tensor2_t * tensor1_p, value)
- name: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
self: maybe_multiply(grad, beta.conj())
mat1: mm_mat1_backward(grad, mat2, mat1.sym_sizes(), mat1.sym_strides(), mat1.layout(), alpha)
mat2: mm_mat2_backward(grad, mat1, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), alpha)
result: maybe_multiply(self_t, beta) + maybe_multiply(mat1_t.mm(mat2_p), alpha) + maybe_multiply(mat1_p.mm(mat2_t), alpha)
- name: _sparse_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
self: maybe_multiply(grad, beta)
mat1: mm_mat1_sparse_backward(grad, mat1, mat2, alpha)
mat2: mm_mat2_backward(grad, mat1, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), alpha)
- name: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor
self: maybe_multiply(grad, beta.conj())
mat: maybe_multiply(grad.ger(vec.conj()), alpha.conj())
vec: maybe_multiply(mat.t().conj().mv(grad), alpha.conj())
result: maybe_multiply(self_t, beta) + maybe_multiply(mat_t.mv(vec_p), alpha) + maybe_multiply(mat_p.mv(vec_t), alpha)
- name: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
self: maybe_multiply(grad, beta.conj())
vec1: maybe_multiply(grad.mv(vec2.conj()), alpha.conj())
vec2: maybe_multiply(grad.t().mv(vec1.conj()), alpha.conj())
result: maybe_multiply(self_t, beta) + maybe_multiply(vec1_t.outer(vec2_p), alpha) + maybe_multiply(vec1_p.outer(vec2_t), alpha)
- name: affine_grid_generator(Tensor theta, int[] size, bool align_corners) -> Tensor
theta: affine_grid_generator_backward(grad, size, align_corners)
- name: alias(Tensor(a) self) -> Tensor(a)
self: grad
result: self_t
- name: angle(Tensor self) -> Tensor
self: angle_backward(grad, self)
result: handle_r_to_c(result.scalar_type(), angle_backward(self_t.conj(), self_p).conj())
# The four items below are necessary because TensorIterator doesn't work on
# Variables (codegen does not unwrap the input Tensor for all() and any()).
- name: any(Tensor self) -> Tensor
output_differentiability: [False]
- name: any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
output_differentiability: [False]
- name: all(Tensor self) -> Tensor
output_differentiability: [False]
- name: all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
output_differentiability: [False]
- name: acosh(Tensor self) -> Tensor
  # Save one rsqrt in the real case by using the identity sqrt(x*y) = sqrt(x)*sqrt(y), which holds for real positive x and y but not in the complex case
self: "self.is_complex() ? grad * ((self + 1).rsqrt() * (self - 1).rsqrt()).conj() : grad * (self * self - 1).rsqrt()"
result: auto_element_wise
- name: acosh_(Tensor(a!) self) -> Tensor(a!)
self: not_implemented("inplace version of acosh")
- name: asinh(Tensor self) -> Tensor
self: grad * (self.pow(2) + 1).rsqrt().conj()
result: auto_element_wise
- name: asinh_(Tensor(a!) self) -> Tensor(a!)
self: not_implemented("inplace version of asinh")
- name: atanh(Tensor self) -> Tensor
self: grad * 1 / (1 - self.pow(2)).conj()
result: auto_element_wise
- name: atanh_(Tensor(a!) self) -> Tensor(a!)
self: not_implemented("inplace version of atanh")
- name: as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)
self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset)
result: auto_linear
- name: as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!)
self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset)
result: auto_linear
- name: asin(Tensor self) -> Tensor
self: grad * (-self * self + 1).rsqrt().conj()
result: auto_element_wise
- name: atan(Tensor self) -> Tensor
self: grad / (self * self + 1).conj()
result: auto_element_wise
- name: atan2(Tensor self, Tensor other) -> Tensor
self, other: atan2_backward(grad, self, other, grad_input_mask)
result: (-self_p * other_t + other_p * self_t) / (self_p.pow(2) + other_p.pow(2))
- name: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
self: maybe_multiply(grad, beta.conj())
batch1: maybe_multiply(grad.bmm(batch2.transpose(1, 2).conj()), alpha.conj())
batch2: maybe_multiply(batch1.transpose(1, 2).conj().bmm(grad), alpha.conj())
result: maybe_multiply(self_t, beta) + maybe_multiply(batch1_t.bmm(batch2_p), alpha) + maybe_multiply(batch1_p.bmm(batch2_t), alpha)
- name: bernoulli(Tensor self, *, Generator? generator=None) -> Tensor
self: zeros_like(grad)
result: auto_element_wise
- name: bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!)
self: zeros_like(grad)
p: zeros_like(p)
result: self_t.zero_()
- name: bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!)
self: zeros_like(grad)
result: self_t.zero_()
- name: bmm(Tensor self, Tensor mat2) -> Tensor
self: grad.bmm(mat2.transpose(1, 2).conj())
mat2: self.transpose(1, 2).conj().bmm(grad)
result: self_t.bmm(mat2_p) + self_p.bmm(mat2_t)
- name: matmul(Tensor self, Tensor other) -> Tensor
self, other: matmul_backward(grad, self, other, grad_input_mask)
- name: cat(Tensor[] tensors, int dim=0) -> Tensor
tensors: cat_tensors_backward(grad, to_args_sizes(tensors), to_args_scalartypes(tensors), dim)
result: cat_jvp(tensors, dim)
- name: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)
self: zeros_like(grad)
result: self_t.zero_()
- name: ceil(Tensor self) -> Tensor
self: zeros_like(grad)
result: auto_element_wise
- name: cholesky(Tensor self, bool upper=False) -> Tensor
self: cholesky_backward(grad, upper, result)
- name: linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info)
self: cholesky_backward(grad, upper, L)
L: cholesky_jvp(self_t, L, upper)
- name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor
self, input2: cholesky_solve_backward(grad, self, input2, result, upper)
result: cholesky_solve_jvp(result, input2_p, input2_t, self_t, upper)
- name: cholesky_inverse(Tensor self, bool upper=False) -> Tensor
self: cholesky_inverse_backward(grad, self, upper, result)
result: cholesky_inverse_jvp(self_p, self_t, result, upper)
# For clamp, the gradient is not defined at the boundaries. Empirically, however, it is helpful
# to get a gradient at min and max, so we return the subgradient 1 in these cases.
- name: clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor
self: clamp_backward(grad, self, min, max)
min, max: clamp_backward_min_max(grad, self, min, max, grad_input_mask)
result: clamp_jvp(self_p, self_t, min_p, min_t, max_p, max_t)
- name: clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
self: clamp_backward(grad, self, min, max)
result: auto_element_wise
- name: clamp_min(Tensor self, Scalar min) -> Tensor
self: where(self >= min, grad, at::scalar_tensor(0., grad.options()))
result: auto_element_wise
- name: clamp_min.Tensor(Tensor self, Tensor min) -> Tensor
self: where(self >= min, grad, at::scalar_tensor(0., grad.options()))
min: where(self < min, grad, at::scalar_tensor(0., grad.options()))
result: where(self_p >= min_p, self_t, min_t)
- name: clamp_max(Tensor self, Scalar max) -> Tensor
self: where(self <= max, grad, at::scalar_tensor(0., grad.options()))
result: auto_element_wise
- name: clamp_max.Tensor(Tensor self, Tensor max) -> Tensor
self: where(self <= max, grad, at::scalar_tensor(0., grad.options()))
max: where(self > max, grad, at::scalar_tensor(0., grad.options()))
result: where(self_p <= max_p, self_t, max_t)
- name: clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
self: grad
result: auto_linear
- name: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
self: _to_copy_backward(grad, self.options())
result: _to_copy(self_t, dtype, layout, device, pin_memory, non_blocking, memory_format)
# The condition is: if dtype is not nullopt, then isDifferentiableType(*dtype)
# (If dtype IS nullopt, we rely on the regular check that any input requires grad).
output_differentiability: ["!dtype || isDifferentiableType(*dtype)"]
- name: _coalesce(Tensor self) -> Tensor
self: grad
- name: complex(Tensor real, Tensor imag) -> Tensor
real: at::real(grad)
imag: at::imag(grad)
result: at::complex(real_t, imag_t)
- name: polar(Tensor abs, Tensor angle) -> Tensor
abs, angle: polar_backward(grad, result)
result: at::complex(abs_t*angle_p.cos() - angle_t*abs_p*angle_p.sin(), abs_t*angle_p.sin() + angle_t*abs_p*angle_p.cos())
- name: _conj(Tensor(a) self) -> Tensor(a)
self: grad.conj()
result: self_t.conj()
- name: _neg_view(Tensor(a) self) -> Tensor(a)
self: grad.neg()
result: self_t._neg_view()
- name: _conj_physical(Tensor self) -> Tensor
self: grad.conj_physical()
result: self_t.conj_physical()
- name: conj_physical_(Tensor(a!) self) -> Tensor(a!)
self: grad.conj_physical()
result: self_t.conj_physical_()
- name: copysign.Tensor(Tensor self, Tensor other) -> Tensor
self: copysign_tensor_self_backward(grad, self, result)
other: zeros_like(other)
result: copysign_tensor_self_backward(self_t, self_p, result)
- name: copysign.Scalar(Tensor self, Scalar other) -> Tensor
self: copysign_tensor_self_backward(grad, self, result)
result: auto_element_wise
- name: cos(Tensor self) -> Tensor
self: grad * -self.sin().conj()
result: auto_element_wise
- name: cosh(Tensor self) -> Tensor
self: grad * self.sinh().conj()
result: auto_element_wise
- name: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor
output_differentiability: [False]
- name: count_nonzero(Tensor self, int? dim=None) -> Tensor
output_differentiability: [False]
- name: linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor
self: at::linalg_cross(other.conj(), grad, dim)
other: at::linalg_cross(grad, self.conj(), dim)
result: "at::linalg_cross(self_t, other_p, dim) + at::linalg_cross(self_p, other_t, dim)"
- name: logcumsumexp(Tensor self, int dim) -> Tensor
self: logcumsumexp_backward(grad, self, result, dim)
- name: cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
self: cumprod_backward(grad.to(self.scalar_type()), self, dim, result)
result: "cumprod_jvp(self_t, self_p, result, dim).to(dtype.has_value() ? *dtype : self_p.scalar_type())"
- name: cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
self: cumsum_backward(grad.to(self.scalar_type()), dim)
result: auto_linear
- name: cummax(Tensor self, int dim) -> (Tensor values, Tensor indices)
self: cummaxmin_backward(grad, self, indices, dim)
values: self_t.gather(dim, indices)
- name: cummin(Tensor self, int dim) -> (Tensor values, Tensor indices)
self: cummaxmin_backward(grad, self, indices, dim)
values: self_t.gather(dim, indices)
- name: conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor
self, weight, bias: "grad.defined() ? conv_tbc_backward(grad, self, weight, bias, pad) : std::tuple<Tensor, Tensor, Tensor>()"
- name: _ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor)
log_probs: _ctc_loss_backward(grad, log_probs, targets, input_lengths, target_lengths, result0, result1, blank, zero_infinity)
- name: _ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor)
log_probs: _ctc_loss_backward(grad, log_probs, targets, input_lengths, target_lengths, result0, result1, blank, zero_infinity)
- name: deg2rad(Tensor self) -> Tensor
self: deg2rad_backward(grad)
result: auto_element_wise
- name: _linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots)
A: linalg_det_backward(grad, result, A, LU, pivots)
result: linalg_det_jvp(A_t, result, LU, pivots, A_p.is_contiguous() && !A_p.is_complex())
output_differentiability: [True, False, False]
- name: _linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet, Tensor LU, Tensor pivots)
A: slogdet_backward(grad_sign, grad_logabsdet, A, sign, LU, pivots)
sign, logabsdet: slogdet_jvp(LU, pivots, A_t, sign, A_p.is_contiguous() && !A_p.is_complex())
output_differentiability: [True, True, False, False]
- name: block_diag(Tensor[] tensors) -> Tensor
tensors: block_diag_backward(grad, to_args_sizes(tensors), to_args_scalartypes(tensors))
result: block_diag_jvp(tensors)
- name: diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor
self: grad.diagonal(offset, dim1, dim2)
result: auto_linear
- name: diag(Tensor self, int diagonal=0) -> Tensor
self: diag_backward_symint(grad, self.sym_sizes(), diagonal)
result: auto_linear
- name: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)
self: diagonal_backward_symint(grad, self.sym_sizes(), offset, dim1, dim2)
result: auto_linear
- name: diagonal_backward(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2) -> Tensor
grad_output: grad.diagonal(offset, dim1, dim2)
result: auto_linear
- name: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor
self: norm_backward(grad, self - other, p, result)
other: -norm_backward(grad, self - other, p, result)
result: norm_jvp(self_p - other_p, self_t - other_t, p, result, {}, false)
# The backward formula is computed in this order to improve the numerical stability
# of higher-order derivatives; see https://github.com/pytorch/pytorch/issues/43414
# Note that we don't use "result" because saving it would be BC-breaking when it is used in an in-place operation later
- name: div.Tensor(Tensor self, Tensor other) -> Tensor
self: div_tensor_self_backward(grad, other, self.scalar_type())
other: div_tensor_other_backward(grad, self, other)
result: (self_t - other_t * result) / other_p
- name: div.Scalar(Tensor self, Scalar other) -> Tensor
self: div_tensor_self_backward(grad, at::lift_fresh(at::scalar_to_tensor(other)), self.scalar_type())
result: self_t / other
- name: div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
self: div_tensor_self_backward(grad, other, self.scalar_type(), rounding_mode)
other: div_tensor_other_backward(grad, self, other, rounding_mode)
result: "rounding_mode.has_value() ? result.new_zeros(result.sizes()) : self_t / other_p - other_t * (self_p / other_p) / other_p"
- name: div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor
self: div_tensor_self_backward(grad, at::lift_fresh(at::scalar_to_tensor(other)), self.scalar_type(), rounding_mode)
result: "rounding_mode.has_value() ? result.new_zeros(result.sizes()) : self_t / other"
- name: dot(Tensor self, Tensor tensor) -> Tensor
self: grad * tensor.conj()
tensor: grad * self.conj()
result: at::dot(self_t, tensor_p) + at::dot(self_p, tensor_t)
- name: vdot(Tensor self, Tensor other) -> Tensor
self: grad.conj() * other
other: grad * self
result: at::vdot(self_t, other_p) + at::vdot(self_p, other_t)
- name: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor)
self: _fused_dropout_backward(grad, result1, p)
- name: native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)
input: "GradMode::is_enabled() ? infinitely_differentiable_native_dropout_backward(grad, result1, (!train.has_value() || !train.value() ? 1 : (p == 1 ? 0.0 : 1.0 / (1.0 - p)))) : native_dropout_backward(grad, result1, (!train.has_value() || !train.value() ? 1 : (p == 1 ? 0.0 : 1.0 / (1.0 - p))))"
result0: "(!train.has_value() || train.value()) ? (p == 1 ? 0.0 : 1.0 / (1.0 - p)) * input_t * result1 : input_t"
- name: native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor
grad_output: "native_dropout_double_backward(grad, grad_output, mask, scale)"
mask: 'not_implemented("native_dropout_backward: mask")'
- name: eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
self: zeros_like(self)
result: self_t.zero_()
- name: eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
self: zeros_like(self)
other: zeros_like(other)
result: self_t.zero_()
- name: erf(Tensor self) -> Tensor
self: 2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad
result: auto_element_wise
- name: erfc(Tensor self) -> Tensor
self: -2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad
result: auto_element_wise
- name: special_erfcx(Tensor self) -> Tensor
self: (2.0 * self * result - 2.0 / sqrt(M_PI)) * grad
result: auto_element_wise
- name: erfinv(Tensor self) -> Tensor
self: 0.5 * sqrt(M_PI) * exp(self.erfinv().pow(2)) * grad
result: auto_element_wise
- name: exp(Tensor self) -> Tensor
self: grad * result.conj()
result: auto_element_wise
- name: exp2(Tensor self) -> Tensor
self: grad * result * M_LN2
result: auto_element_wise
- name: expm1(Tensor self) -> Tensor
self: grad * (result + 1)
result: auto_element_wise
# TODO: this derivative is not SymInt safe, need sum_to support
- name: expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
self: at::sum_to(grad, self.sym_sizes())
result: auto_linear
- name: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)
self: zeros_like(grad)
result: self_t.zero_()
- name: fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask)
- name: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask)
- name: _fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor
self, scale, zero_point: "grad.defined() ? _fake_quantize_learnable_per_tensor_affine_backward(grad, self, scale, zero_point, quant_min, quant_max, grad_factor) : std::tuple<Tensor, Tensor, Tensor>()"
- name: fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
self: fake_quantize_per_channel_affine_cachemask_backward(grad, mask)
- name: _fake_quantize_learnable_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor
self, scale, zero_point: "grad.defined() ? _fake_quantize_learnable_per_channel_affine_backward(grad, self, scale, zero_point, axis, quant_min, quant_max, grad_factor) : std::tuple<Tensor, Tensor, Tensor>()"
- name: _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask)
self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask)
- name: fill.Scalar(Tensor self, Scalar value) -> Tensor
self: zeros_like(grad)
result: at::fill(self_t, 0)
- name: fill.Tensor(Tensor self, Tensor value) -> Tensor
self: zeros_like(grad)
value: grad.sum()
result: at::fill(self_t, value_t)
- name: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
self: zeros_like(grad)
result: self_t.fill_(0)
- name: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)
self: zeros_like(grad)
value: grad.sum()
result: self_t.fill_(value_t)
- name: floor(Tensor self) -> Tensor
self: zeros_like(grad)
result: auto_element_wise
- name: fmod.Scalar(Tensor self, Scalar other) -> Tensor
self: grad
result: auto_element_wise
- name: fmod.Tensor(Tensor self, Tensor other) -> Tensor
self: grad
other: -grad * self.div(other, /*rounding_mode=*/"trunc")
result: self_t - other_t * self_p.div(other_p, /*rounding_mode=*/"trunc")
- name: frac(Tensor self) -> Tensor
self: grad
result: self_t
- name: frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent)
self: grad / exponent.exp2()
mantissa: self_t / exponent.exp2()
- name: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor
self: gather_backward(grad, self, dim, index, sparse_grad)
index: non_differentiable
result: auto_linear
- name: ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
self: zeros_like(self)
result: self_t.zero_()
- name: ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
self: zeros_like(self)
other: zeros_like(other)
result: self_t.zero_()
- name: geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!)
self: zeros_like(grad)
result: self_t.zero_()
- name: geqrf(Tensor self) -> (Tensor a, Tensor tau)
self: not_implemented("geqrf")
- name: indices(Tensor(a) self) -> Tensor(a)
output_differentiability: [False]
- name: _indices(Tensor(a) self) -> Tensor(a)
output_differentiability: [False]
- name: grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor
input, grid: "grad.defined() ? grid_sampler_2d_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners, grad_input_mask) : std::tuple<Tensor, Tensor>()"
- name: grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor
input, grid: "grad.defined() ? grid_sampler_3d_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners, grad_input_mask) : std::tuple<Tensor, Tensor>()"
# See NOTE [ grid_sample CPU fallback ]
- name: _grid_sampler_2d_cpu_fallback(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor
input, grid: "grad.defined() ? _grid_sampler_2d_cpu_fallback_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners) : std::tuple<Tensor, Tensor>()"
- name: gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
self: zeros_like(self)
result: self_t.zero_()
- name: gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
self: zeros_like(self)
other: zeros_like(other)
result: self_t.zero_()
- name: hardsigmoid(Tensor self) -> Tensor
self: hardsigmoid_backward(grad, self)
result: auto_element_wise
- name: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor
output_differentiability: [False]
- name: hardswish(Tensor self) -> Tensor
self: hardswish_backward(grad, self)
result: auto_element_wise
- name: hardswish_backward(Tensor grad_output, Tensor self) -> Tensor
grad_output: hardswish_backward(grad, self)
self: at::where(at::logical_and(-3.0 < self, self < 3.0), grad * grad_output / 3.0, at::zeros({}, self.options()))
result: "hardswish_backward(grad_output_t, self_p)
+ at::where(at::logical_and(-3.0 < self_p, self_p < 3.0), self_t * grad_output_p / 3.0, at::zeros({}, self_p.options()))"
- name: hypot(Tensor self, Tensor other) -> Tensor
self: grad * self / result
other: grad * other / result
result: self_t * self_p / result + other_t * other_p / result
- name: i0(Tensor self) -> Tensor
self: grad * at::special_i1(self)
result: auto_element_wise
- name: special_i0e(Tensor self) -> Tensor
self: grad * (at::special_i1e(self) - self.sgn() * result)
result: auto_element_wise
- name: special_i1(Tensor self) -> Tensor
self: i1_backward(grad, self, result)
result: auto_element_wise
- name: special_i1e(Tensor self) -> Tensor
self: i1e_backward(grad, self, result)
result: auto_element_wise
- name: igamma(Tensor self, Tensor other) -> Tensor
self: 'not_implemented("igamma: input")'
other: grad * exp((self - 1) * log(other) - other - lgamma(self))
- name: igammac(Tensor self, Tensor other) -> Tensor
self: 'not_implemented("igammac: input")'
other: -grad * exp((self - 1) * log(other) - other - lgamma(self))
- name: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
self: index_backward(grad.new_zeros(self.sizes(), self.options()), indices, grad)
result: auto_linear
- name: index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor
self: grad
# The source.dim() == 0 case is necessary to support scalar tensors of the form
# source.dim() == 0 with index.dim() == 1 and index.size() == (1,).
# This is because source is not broadcastable to index, as source.dim() < index.dim().
source: "maybe_multiply(source.dim() > 0 ? grad.index_select(dim, index).expand_as(source) : grad.index_select(dim, index.squeeze(0)), alpha)"
index: non_differentiable
result: at::index_add(self_t, dim, index, maybe_multiply(source_t, alpha))
- name: index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor
self, source: index_reduce_backward(grad, self, dim, index, source, reduce, include_self, result)
index: non_differentiable
- name: index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor
self: grad.index_fill(dim, index, 0)
# The source.dim() == 0 case is necessary to support scalar tensors of the form
# source.dim() == 0 with index.dim() == 1 and index.size() == (1,).
# This is because source is not broadcastable to index, as source.dim() < index.dim().
source: "source.dim() > 0 ? grad.index_select(dim, index).expand_as(source) : grad.index_select(dim, index.squeeze(0))"
index: non_differentiable
result: self_t.index_copy(dim, index, source_t)
- name: index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
self: grad.index_fill(dim, index, 0)
index: non_differentiable
result: self_t.index_fill(dim, index, 0)
- name: index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor
self: grad.index_fill(dim, index, 0)
value: grad.index_select(dim, std::get<0>(at::_unique(index, /*sorted=*/false))).sum()
index: non_differentiable
result: self_t.index_fill(dim, index, value_t)
- name: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
self: "accumulate ? grad : grad.index_put(indices, zeros_like(values), false)"
values: grad.index(indices)
result: self_t.index_put(indices, values_t, accumulate)
- name: _index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)
self: "accumulate ? grad : grad.index_put(indices, zeros_like(values), false)"
values: grad.index(indices)
result: at::_index_put_impl_(self_t, indices, values_t, accumulate, unsafe)
- name: index_select(Tensor self, int dim, Tensor index) -> Tensor
self: index_select_backward(grad, self.sizes(), dim, index)
index: non_differentiable
result: auto_linear
- name: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
A: -at::matmul(inverse.mH(), at::matmul(grad, inverse.mH()))
inverse: -at::matmul(at::matmul(inverse, A_t), inverse)
output_differentiability: [True, False]
- name: linalg_pinv.atol_rtol_tensor(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor
self: pinv_backward(grad, result, self)
result: pinv_jvp(self_p, result, self_t)
- name: isnan(Tensor self) -> Tensor
self: non_differentiable
- name: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
- name: le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
self: zeros_like(self)
result: self_t.zero_()
- name: le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
self: zeros_like(self)
other: zeros_like(other)
result: self_t.zero_()
- name: lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor
self: "weight.isComplex() ? grad * (1 - weight.conj().toComplexDouble()) : grad * (1 - weight.toDouble())"
end: grad * weight.conj()
result: at::lerp(self_t, end_t, weight)
- name: lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor
self: grad * (1 - weight).conj()
end: grad * weight.conj()
weight: grad * (end - self).conj()
result: at::lerp(self_t, end_t, weight_p) + weight_t * (end_p - self_p)
- name: lgamma(Tensor self) -> Tensor
self: grad * digamma(self)
result: auto_element_wise
- name: digamma(Tensor self) -> Tensor
self: grad * polygamma(1, self)
result: auto_element_wise
- name: polygamma(int n, Tensor self) -> Tensor
self: grad * polygamma(n + 1, self)
result: auto_element_wise
- name: polygamma_(Tensor(a!) self, int n) -> Tensor(a!)
self: grad * polygamma(n + 1, self)
result: self_t.mul_(polygamma(n + 1, original_self_p))
- name: log(Tensor self) -> Tensor
self: grad.div(self.conj())
result: auto_element_wise
- name: log10(Tensor self) -> Tensor
self: grad / (self.conj() * 2.3025850929940456)
result: auto_element_wise
- name: log1p(Tensor self) -> Tensor
self: log1p_backward(grad, self)
result: auto_element_wise
- name: log2(Tensor self) -> Tensor
self: grad / (self.conj() * 0.6931471805599453)
result: auto_element_wise
- name: logaddexp(Tensor self, Tensor other) -> Tensor
self: grad / (1 + exp(other - self))
other: grad / (1 + exp(self - other))
result: self_t / (1 + exp(other_p - self_p)) + other_t / (1 + exp(self_p - other_p))
- name: logaddexp2(Tensor self, Tensor other) -> Tensor
self: grad / (1 + pow(2, other - self))
other: grad / (1 + pow(2, self - other))
result: self_t / (1 + pow(2, other_p - self_p)) + other_t / (1 + pow(2, self_p - other_p))
# Note [Gradient formula for xlogy at x = 0, y <= 0]
# x * log(y) is not defined for y <= 0, so we cannot even talk about differentiability there.
# xlogy(0, y) = 0 by definition, but this does not make the function differentiable at a point
# (0, y) with y <= 0, since it is not defined in a neighbourhood of that point.
# When a function is non-differentiable, we sometimes return "a relatively sensible value".
# In this case, as per the discussion in https://github.com/pytorch/pytorch/issues/80770, we choose
# this value to be zero, which is the directional derivative along the line {x = 0}.
- name: xlogy.Tensor(Tensor self, Tensor other) -> Tensor
self: at::xlogy(grad, other).masked_fill((self == 0.) & (other <= 0.), 0.)
other: grad * self / other
result: at::xlogy(self_t, other_p).masked_fill((self_p == 0.) & (other_p <= 0.), 0.) + other_t * self_p / other_p
- name: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor
other: grad * self / other
result: auto_element_wise
- name: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor
self: "other.toDouble() > 0.
? at::xlogy(grad, other)
: at::xlogy(grad, other).masked_fill(self == 0., 0.)"
result: auto_element_wise
# See Note [Gradient formula for xlogy at x = 0, y <= 0]
# The same reasoning applies here, with y <= -1 in place of y <= 0
- name: special_xlog1py(Tensor self, Tensor other) -> Tensor
self: at::special_xlog1py(grad, other).masked_fill((self == 0.) & (other <= -1.), 0.)
other: grad * self / (other + 1)
result: at::special_xlog1py(self_t, other_p).masked_fill((self_p == 0.) & (other_p <= -1.), 0.) + other_t * self_p / (other_p + 1)
- name: special_xlog1py.self_scalar(Scalar self, Tensor other) -> Tensor
other: grad * self / (other + 1)
result: auto_element_wise
- name: special_xlog1py.other_scalar(Tensor self, Scalar other) -> Tensor
self: "other.toDouble() > -1.
? at::special_xlog1py(grad, other)
: at::special_xlog1py(grad, other).masked_fill(self == 0., 0.)"
result: auto_element_wise
- name: special_zeta(Tensor self, Tensor other) -> Tensor
self: not_implemented("zeta")
other: grad * -self * special_zeta(self + 1., other)
- name: special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor
other: grad * -self * special_zeta(self.toDouble() + 1., other)
- name: special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor
self: not_implemented("zeta")
- name: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!)
self: zeros_like(grad)
result: self_t.zero_()
- name: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
self: logsumexp_backward(grad, self, result, dim, keepdim)
result: logsumexp_jvp(self_p, self_t, dim, keepdim)
- name: linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
self, b: linalg_lstsq_backward(grad, self, b, rcond, driver, grad_input_mask)
solution: linalg_lstsq_jvp(self_p, b_p, self_t, b_t)
output_differentiability: [True, False, False, False]
- name: lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
self: zeros_like(self)
result: self_t.zero_()
- name: lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
self: zeros_like(self)
other: zeros_like(other)
result: self_t.zero_()
- name: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)
A: lu_factor_ex_backward(grad, LU, pivots, pivot)
LU: lu_factor_ex_jvp(A_t, LU, pivots, pivot)
output_differentiability: [True, False, False]
- name: linalg_lu(Tensor A, *, bool pivot=True) -> (Tensor P, Tensor L, Tensor U)
A: linalg_lu_backward(grad_L, grad_U, P, L, U, pivot)
L: std::get<0>(linalg_lu_jvp(A_t, P, L, U, pivot))
U: std::get<1>(linalg_lu_jvp(A_t, P, L, U, pivot))
output_differentiability: [False, True, True]
- name: linalg_lu_solve(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False) -> Tensor
LU: linalg_lu_solve_LU(grad, LU, pivots, result, left, adjoint)
B: "at::linalg_lu_solve(LU, pivots, grad, left, !adjoint)"
result: linalg_lu_solve_jvp(result, LU_p, pivots, LU_t, B_t, left, adjoint)
- name: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
LU_data: lu_unpack_backward(grad_L, grad_U, LU_data.size(-2), LU_data.size(-1))
LU_pivots: non_differentiable
L: "LU_data_t.size(-2) >= LU_data_t.size(-1) ? LU_data_t.tril(-1) : LU_data_t.narrow(-1, 0, LU_data_t.size(-2)).tril(-1)"
U: "LU_data_t.size(-1) >= LU_data_t.size(-2) ? LU_data_t.triu() : LU_data_t.narrow(-2, 0, LU_data_t.size(-1)).triu()"
output_differentiability: [False, True, True]
- name: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor
self: grad.masked_fill(mask, 0)
mask: non_differentiable
result: self_t.masked_fill(mask, 0)
- name: masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor
self: grad.masked_fill(mask, 0)
value: masked_fill_backward(grad, mask)
mask: non_differentiable
result: self_t.masked_fill(mask, value_t)
- name: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor
self: grad.masked_fill(mask, 0)
source: masked_scatter_backward(grad, mask, source.sizes())
mask: non_differentiable
result: self_t.masked_scatter(mask, source_t)
- name: masked_select(Tensor self, Tensor mask) -> Tensor
self: masked_select_backward(grad, self, mask)
mask: non_differentiable
result: auto_linear
- name: linalg_matrix_exp(Tensor self) -> Tensor
self: linalg_matrix_exp_differential(self, grad, /*adjoint*/ true)
result: linalg_matrix_exp_differential(self_p, self_t, /*adjoint*/ false)
- name: max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
- name: max(Tensor self) -> Tensor
self: evenly_distribute_backward(grad, self, result)
result: evenly_read_jvp(self_t, self_p, result)
- name: maximum(Tensor self, Tensor other) -> Tensor
self: at::where(self == other, grad / 2, grad).masked_fill_(self < other, 0)
other: at::where(self == other, grad / 2, grad).masked_fill_(self > other, 0)
result: other_t + at::where(self_p == other_p, at::scalar_tensor(0.5, result.options()), (self_p > other_p).to(result.scalar_type())) * (self_t - other_t)
- name: fmax(Tensor self, Tensor other) -> Tensor
self: grad.masked_fill((self >= other).logical_or_(other.isnan()).logical_not_(), 0)
other: grad.masked_fill((self >= other).logical_or_(other.isnan()), 0)
result: other_t + (self_p > other_p).logical_or_(other_p.isnan()) * (self_t - other_t)
- name: mean(Tensor self, *, ScalarType? dtype=None) -> Tensor
self: grad.expand(self.sizes()) / self.numel()
result: auto_linear
- name: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
self: mean_backward(grad, self.sym_sizes(), dim, self.sym_numel(), keepdim)
result: auto_linear
- name: median(Tensor self) -> Tensor
self: evenly_distribute_backward(grad, self, result)
result: evenly_read_jvp(self_t, self_p, result)
- name: nanmedian(Tensor self) -> Tensor
self: evenly_distribute_backward(grad, self, result)
result: evenly_read_jvp(self_t, self_p, result)
# This is in theory incorrect in the following case, where median = b and the
# middle position of the sorted list falls strictly between two `b`s:
#   sorted list: [..., a, b, b, ..., b, b, c, ...]
#                                |
#                                ^the middle position
# The gradient exists and is essentially 0 in this case.
#
# In the case where the middle position is at the boundary of the `b` range, e.g.,
#   sorted list: [..., a, b, b, ..., b, b, c, ...]
#                                       |
#                                       ^the middle position
# the backward implementation is correct in the sense that it returns the
# subgradient on one side.
- name: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
- name: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
- name: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
- name: min(Tensor self) -> Tensor
self: evenly_distribute_backward(grad, self, result)
result: evenly_read_jvp(self_t, self_p, result)
- name: minimum(Tensor self, Tensor other) -> Tensor
self: at::where(self == other, grad / 2, grad).masked_fill_(self > other, 0)
other: at::where(self == other, grad / 2, grad).masked_fill_(self < other, 0)
result: other_t + at::where(self_p == other_p, at::scalar_tensor(0.5, result.options()), (self_p < other_p).to(result.scalar_type())) * (self_t - other_t)
- name: fmin(Tensor self, Tensor other) -> Tensor
self: grad.masked_fill((self <= other).logical_or_(other.isnan()).logical_not_(), 0)
other: grad.masked_fill((self <= other).logical_or_(other.isnan()), 0)
result: other_t + (self_p <= other_p).logical_or_(other_p.isnan()) * (self_t - other_t)
- name: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim)
result: amaxamin_jvp(self_p, self_t, result, dim, keepdim)
- name: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim)
result: amaxamin_jvp(self_p, self_t, result, dim, keepdim)
- name: mm(Tensor self, Tensor mat2) -> Tensor
self: mm_mat1_backward(grad, mat2, self.sym_sizes(), self.sym_strides(), self.layout(), 1)
mat2: mm_mat2_backward(grad, self, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), 1)
result: at::mm(self_t, mat2_p) + at::mm(self_p, mat2_t)
- name: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
- name: mul.Tensor(Tensor self, Tensor other) -> Tensor
self: mul_tensor_backward(grad, other, self.scalar_type())
other: mul_tensor_backward(grad, self, other.scalar_type())
result: other_t * self_p + self_t * other_p
- name: mul.Scalar(Tensor self, Scalar other) -> Tensor
self: mul_tensor_backward(grad, at::lift_fresh(at::scalar_to_tensor(other)), self.scalar_type())
result: self_t * other
- name: mv(Tensor self, Tensor vec) -> Tensor
self: grad.ger(vec.conj())
vec: self.conj().t().mv(grad)
result: mv(self_t, vec_p) + mv(self_p, vec_t)
- name: mvlgamma(Tensor self, int p) -> Tensor
self: mvlgamma_backward(grad, self, p)
result: auto_element_wise
- name: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
self: grad * at::isfinite(self)
result: auto_element_wise
- name: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps)
- name: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, train, eps, save_mean, save_invstd, grad_input_mask)
save_mean: not_implemented("native_batch_norm_backward save_mean")
save_invstd: not_implemented("native_batch_norm_backward save_invstd")
- name: native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? native_layer_norm_backward_symint(grad, input, normalized_shape, result1, result2, weight, bias, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
result0: layer_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, normalized_shape)
- name: native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
input, weight, grad_out: layer_norm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, mean, rstd, normalized_shape, grad_input_mask)
bias: Tensor()
mean: not_implemented("native_layer_norm_backward mean")
rstd: not_implemented("native_layer_norm_backward rstd")
- name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].is_contiguous() ? grads[0] : grads[0].contiguous(), input.is_contiguous() ? input : input.contiguous(), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())"
result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group)
result1: group_norm_mean_jvp(input_t, result1, group)
result2: group_norm_invstd_jvp(input_p, input_t, result1, result2, group)
- name: ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
self: zeros_like(self)
result: self_t.zero_()
- name: ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
self: zeros_like(self)
other: zeros_like(other)
result: self_t.zero_()
- name: neg(Tensor self) -> Tensor
self: grad.neg()
result: auto_element_wise
- name: nextafter(Tensor self, Tensor other) -> Tensor
self: not_implemented("nextafter")
other: not_implemented("nextafter")
- name: norm.Scalar(Tensor self, Scalar p=2) -> Tensor
self: norm_backward(grad, self, p, result)
result: norm_jvp(self_p, self_t, p, result)
- name: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor
self: norm_backward(grad, self, p, result, dim, keepdim)
result: norm_jvp(self_p, self_t, p, result, dim, keepdim)
- name: norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor
self: norm_backward(grad, self.to(grad.scalar_type()), p, result)
result: norm_jvp(self_p, self_t, p, result)
- name: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
self: norm_backward(grad, self.to(grad.scalar_type()), p, result, dim, keepdim)
result: norm_jvp(self_p, self_t, p, result, dim, keepdim)
- name: linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
self: linalg_vector_norm_backward(grad, self, ord, result, dim, keepdim)
result: linalg_vector_norm_jvp(self_p, self_t, ord, result, dim, keepdim)
- name: _pdist_forward(Tensor self, float p=2) -> Tensor
self: _pdist_backward(grad, self, p, result)
- name: _pdist_backward(Tensor grad, Tensor self, float p, Tensor pdist) -> Tensor
grad: not_implemented("_pdist_backward")
self: not_implemented("_pdist_backward")
pdist: not_implemented("_pdist_backward")
- name: _euclidean_dist(Tensor x1, Tensor x2) -> Tensor
x1, x2: _euclidean_dist_backward(grad, x1, x2, result)
- name: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
x1: _cdist_backward(grad.contiguous(), x1, x2, p, result)
x2: _cdist_backward(grad.mT().contiguous(), x2, x1, p, result.mT().contiguous())
- name: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor
grad: not_implemented("_cdist_backward")
x1: not_implemented("_cdist_backward")
x2: not_implemented("_cdist_backward")
cdist: not_implemented("_cdist_backward")
- name: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)
self: zeros_like(grad)
result: self_t.zero_()
- name: normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor
mean: at::zeros(mean.sizes(), grad.options())
result: auto_element_wise
- name: normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor
std: at::zeros(std.sizes(), grad.options())
result: auto_element_wise
- name: normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor
mean: at::zeros(mean.sizes(), grad.options())
std: at::zeros(std.sizes(), grad.options())
result: zeros_like(mean_t)
- name: linalg_householder_product(Tensor input, Tensor tau) -> Tensor
input, tau: householder_product_backward(grad, result, input, tau)
result: householder_product_jvp(input_t, tau_t, result, input_p, tau_p)
- name: ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor
self: not_implemented("ormqr")
input2: not_implemented("ormqr")
input3: not_implemented("ormqr")
- name: permute(Tensor(a) self, int[] dims) -> Tensor(a)
self: permute_backwards(grad, dims)
result: auto_linear
- name: poisson(Tensor self, Generator? generator=None) -> Tensor
self: zeros_like(self)
result: auto_element_wise
- name: pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
self: pow_backward(grad, self, exponent)
result: auto_element_wise
- name: pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
self: pow_backward_self(grad, self, exponent)
exponent: pow_backward_exponent(grad, self, exponent, result)
result: (pow_backward_self(self_t.conj(), self_p, exponent_p) + pow_backward_exponent(exponent_t.conj(), self_p, exponent_p, result)).conj()
- name: pow.Scalar(Scalar self, Tensor exponent) -> Tensor
exponent: pow_backward_exponent(grad, self, exponent, result)
result: auto_element_wise
- name: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor
self: prod_backward(grad, self.to(grad.scalar_type()), result)
result: (prod_backward(at::ones({}, result.options()).expand_as(result), self_p.to(result.scalar_type()), result) * self_t.conj()).sum().conj()
- name: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
self: prod_backward(grad, self.to(grad.scalar_type()), result, dim, keepdim)
result: (prod_backward(at::ones({}, result.options()).expand_as(result), self_p.to(result.scalar_type()), result, dim, keepdim) * self_t.conj()).sum(dim, keepdim).conj()
- name: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor
self: "accumulate ? grad : grad.put(index, zeros_like(source), false)"
index: non_differentiable
source: grad.take(index).reshape_as(source)
result: self_t.put(index, source_t, accumulate)
- name: linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R)
A: linalg_qr_backward(grad_Q, grad_R, Q, R, mode)
Q, R: linalg_qr_jvp(A_t, Q, R, mode)
- name: rad2deg(Tensor self) -> Tensor
self: rad2deg_backward(grad)
result: auto_element_wise
- name: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)
self: zeros_like(grad)
result: self_t.zero_()
- name: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)
self: zeros_like(grad)
result: self_t.zero_()
- name: random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)
self: zeros_like(grad)
result: self_t.zero_()
- name: reciprocal(Tensor self) -> Tensor
self: -grad * (result * result).conj()
result: auto_element_wise
- name: remainder.Scalar(Tensor self, Scalar other) -> Tensor
self: grad
result: auto_element_wise
- name: remainder.Tensor(Tensor self, Tensor other) -> Tensor
self: grad
other: -grad * self.div(other, /*rounding_mode=*/"floor")
result: self_t - other_t * self_p.div(other_p, /*rounding_mode=*/"floor")
- name: renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor
self: renorm_backward(grad, self, p, dim, maxnorm)
- name: repeat(Tensor self, SymInt[] repeats) -> Tensor
self: repeat_backward(grad, repeats, self.sym_sizes())
result: auto_linear
- name: special_entr(Tensor self) -> Tensor
self: grad * (-(1 + self.log()))
result: auto_element_wise
- name: special_ndtri(Tensor self) -> Tensor
self: grad * std::sqrt(2 * M_PI) * (result.square() / 2).exp()
result: auto_element_wise
- name: special_log_ndtr(Tensor self) -> Tensor
self: grad / std::sqrt(2 * M_PI) * (result + self.pow(2) / 2).neg().exp()
result: auto_element_wise
# [Note: Sometimes view derivatives]
# The following situation applies to other operations as well.
# TODO: This note is currently referenced only by to_dense. Make it more
# generic if it ever gets referenced elsewhere.
#
# DO NOT define a backward for reshape!
# reshape is special in that it sometimes returns a view and sometimes does not.
# Defining a backward would make codegen emit the forward call as
#     as_variable(baseType->reshape(self)),
# which makes it hard (if not impossible) to detect when the result is actually a view.
# - name: reshape(Tensor self, IntArrayRef shape)
- name: _reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a)
self: grad.reshape_symint(self.sym_sizes())
result: auto_linear
- name: round(Tensor self) -> Tensor
self: zeros_like(grad)
result: auto_element_wise
- name: round.decimals(Tensor self, *, int decimals) -> Tensor
self: zeros_like(grad)
result: auto_element_wise
- name: rsqrt(Tensor self) -> Tensor
self: -0.5 * grad * result.pow(3).conj()
result: auto_element_wise
- name: scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
self: grad.scatter(dim, index, 0)
index: non_differentiable
src: grad.gather(dim, index)
result: self_t.scatter(dim, index, src_t)
- name: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
self: grad.scatter(dim, index, 0)
index: non_differentiable
result: self_t.scatter(dim, index, 0)
- name: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
self: grad
index: non_differentiable
src: grad.gather(dim, index)
result: scatter_add(self_t, dim, index, src_t)
- name: select.int(Tensor(a) self, int dim, int index) -> Tensor(a)
dispatch:
Default:
self: select_backward(grad, self.sizes(), dim, index)
result: auto_linear
AutogradNestedTensor:
self: _nested_select_backward(grad, self, dim, index)
- name: select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, int index) -> Tensor
grad_output: grad.select(dim, index)
result: auto_linear
- name: sigmoid(Tensor self) -> Tensor
self: sigmoid_backward(grad, result)
result: auto_element_wise
- name: logit(Tensor self, float? eps=None) -> Tensor
self: "GradMode::is_enabled() ? infinitely_differentiable_logit_backward(grad, self, eps) : logit_backward(grad, self, eps)"
result: auto_element_wise
- name: sign(Tensor self) -> Tensor
self: zeros_like(grad)
result: auto_element_wise
- name: sgn(Tensor self) -> Tensor
self: sgn_backward(self, grad, result)
  # Cannot use auto_element_wise here because the Jacobian is *not* Hermitian (in fact, it is symmetric).
  # The function is not holomorphic, so there is no reason for its Jacobian to be Hermitian.
  # The name auto_element_wise is somewhat misleading in the complex case.
result: sgn_backward(self_p, self_t, result)
- name: sin(Tensor self) -> Tensor
self: grad * self.cos().conj()
result: auto_element_wise
- name: sinc(Tensor self) -> Tensor
self: sinc_backward(grad, self)
result: auto_element_wise
- name: sinh(Tensor self) -> Tensor
self: grad * self.cosh().conj()
result: auto_element_wise
- name: slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
self: slice_backward_wrapper(grad, self.sym_sizes(), dim, start, end, step)
result: auto_linear
- name: slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor
grad_output: grad.slice_symint(dim, start, end, step)
result: auto_linear
- name: slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
self: slice_scatter_symint(grad, zeros_like(src), dim, start, end, step)
src: grad.slice_symint(dim, start, end, step)
result: auto_linear
- name: select_scatter(Tensor self, Tensor src, int dim, int index) -> Tensor
self: select_scatter(grad, zeros_like(src), dim, index)
src: grad.select(dim, index)
result: auto_linear
- name: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor
self: diagonal_scatter(grad, zeros_like(src), offset, dim1, dim2)
src: grad.diagonal(offset, dim1, dim2)
result: auto_linear
- name: as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor
self: as_strided_scatter_backward(grad, TensorGeometry(self), TensorGeometry(src), size, stride, storage_offset)
# See Note [as_strided_scatter backward support]
src: grad.contiguous().as_strided_symint(size, stride, storage_offset)
result: auto_linear
- name: _linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info)
A, B: linalg_solve_backward(grad, result, A, LU, pivots, left, grad_input_mask[1])
result: "linalg_solve_jvp(A_t, B_t, result, LU, pivots, left, A_p.is_contiguous() && !A_p.is_complex())"
output_differentiability: [True, False, False, False] # LU is an auxiliary tensor not exposed to the user
- name: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true)
output_differentiability: [True, False]
values: gather_with_keepdimed_indices(self_t, dim, indices, true)
- name: sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true)
output_differentiability: [True, False]
values: gather_with_keepdimed_indices(self_t, dim, indices, true)
- name: split.Tensor(Tensor(a -> *) self, int split_size, int dim=0) -> Tensor(a)[]
self: split_backward(grads, split_size, dim, self.sizes(), self.options())
result: auto_linear
- name: unsafe_split.Tensor(Tensor self, int split_size, int dim=0) -> Tensor[]
self: split_backward(grads, split_size, dim, self.sizes(), self.options())
result: auto_linear
- name: split_with_sizes(Tensor(a -> *) self, int[] split_sizes, int dim=0) -> Tensor(a)[]
self: split_with_sizes_backward(grads, split_sizes, dim, self.sizes(), self.options())
result: auto_linear
- name: unsafe_split_with_sizes(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]
self: split_with_sizes_backward(grads, split_sizes, dim, self.sizes(), self.options())
result: auto_linear
- name: sqrt(Tensor self) -> Tensor
self: grad / (2 * result.conj())
result: auto_element_wise
- name: squeeze(Tensor(a) self) -> Tensor(a)
self: unsqueeze_to(grad, self.sizes())
result: auto_linear
- name: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)
self: unsqueeze_to(grad, dim, self.sizes())
result: auto_linear
- name: squeeze_(Tensor(a!) self) -> Tensor(a!)
self: unsqueeze_to(grad, self.sizes())
result: auto_linear
- name: squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!)
self: unsqueeze_to(grad, dim, self.sizes())
result: auto_linear
- name: std.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> Tensor
self: std_backward(result, grad, self, dim, correction, keepdim)
# pointwise (variance) + sum + sqrt
result: (at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) / (2. * result)).masked_fill_(result == 0, 0)
- name: std_mean.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> (Tensor, Tensor)
self: std_mean_backward(grads[0], grads[1], self, result0, dim, correction, keepdim)
result0: (at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) / (2. * result0)).masked_fill_(result0 == 0, 0)
# linear
result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim)
- name: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
self: handle_r_to_c(self.scalar_type(), grad)
other: handle_r_to_c(other.scalar_type(), maybe_multiply(-grad, alpha.conj()))
result: self_t - maybe_multiply(other_t, alpha)
- name: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
self: handle_r_to_c(self.scalar_type(), grad)
result: auto_element_wise
- name: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
self: handle_r_to_c(self.scalar_type(), maybe_multiply(-grad, alpha.conj()))
other: handle_r_to_c(other.scalar_type(), grad)
result: -maybe_multiply(self_t, alpha) + other_t
- name: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
self: handle_r_to_c(self.scalar_type(), maybe_multiply(-grad, alpha.conj()))
result: auto_element_wise
- name: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
self: grad.expand_symint(self.sym_sizes())
result: auto_linear
- name: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
dispatch:
Default:
self: sum_backward(grad, self.sym_sizes(), dim, keepdim)
result: auto_linear
AutogradNestedTensor:
# TODO: replace this function once semantics for nested tensor expand have been settled on
self: _nested_sum_backward(grad, self, dim, keepdim)
- name: nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim)
result: at::where(self_p.isnan(), 0, self_t).sum(dim, keepdim, dtype)
# We never call _linalg_svd with compute_uv=False in an autograd context, so we don't even consider it here
- name: _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)
A: "svd_backward(full_matrices && grad_U.defined() ? grad_U.narrow(-1, 0, S.size(-1)) : grad_U,
grad_S,
full_matrices && grad_Vh.defined() ? grad_Vh.narrow(-2, 0, S.size(-1)) : grad_Vh,
full_matrices ? U.narrow(-1, 0, S.size(-1)) : U,
S,
full_matrices ? Vh.narrow(-2, 0, S.size(-1)) : Vh)"
U, S, Vh: linalg_svd_jvp(A_t, U, S, Vh, full_matrices)
- name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors)
self: linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors_return, /*is_hermitian=*/true, /*symeig_eigenvector=*/eigenvectors)
- name: _linalg_eigh(Tensor A, str UPLO="L", bool compute_v=True) -> (Tensor eigenvalues, Tensor eigenvectors)
A: linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors, /*is_hermitian=*/true)
eigenvalues, eigenvectors: linalg_eig_jvp(A_t, eigenvalues, eigenvectors, /*is_hermitian=*/true)
- name: linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors)
self: handle_r_to_c(self.scalar_type(), linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors, /*is_hermitian=*/false))
eigenvalues, eigenvectors: linalg_eig_jvp(self_t, eigenvalues, eigenvectors, /*is_hermitian=*/false)
- name: t(Tensor(a) self) -> Tensor(a)
self: grad.t()
result: auto_linear
- name: t_(Tensor(a!) self) -> Tensor(a!)
self: grad.t()
result: auto_linear
- name: one_hot(Tensor self, int num_classes=-1) -> Tensor
self: non_differentiable
- name: flip(Tensor self, int[] dims) -> Tensor
self: grad.flip(dims)
result: auto_linear
- name: roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor
self: grad.roll(fmap(reverse_list(shifts), [](int64_t i){return -i;}), reverse_list(dims))
result: auto_linear
- name: rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor
self: grad.rot90(-k, dims)
result: auto_linear
- name: take(Tensor self, Tensor index) -> Tensor
self: take_backward(grad, self, index)
index: non_differentiable
result: auto_linear
- name: tan(Tensor self) -> Tensor
self: grad * (1 + result.pow(2)).conj()
result: auto_element_wise
- name: tanh(Tensor self) -> Tensor
self: tanh_backward(grad, result)
result: auto_element_wise
- name: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true)
output_differentiability: [True, False]
values: gather(self_t, dim, indices)
- name: trace(Tensor self) -> Tensor
self: trace_backward(grad, self.sizes())
result: auto_linear
- name: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
self: grad.transpose(dim0, dim1)
result: auto_linear
- name: transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)
self: grad.transpose(dim0, dim1)
result: auto_linear
- name: triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient)
self, A: triangular_solve_backward(grad_solution, grad_cloned_coefficient, self, A, solution, upper, transpose, unitriangular, grad_input_mask)
solution: triangular_solve_jvp(solution, A_p, A_t, self_t, upper, transpose, unitriangular)
cloned_coefficient: A_t
- name: linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor
self, B: linalg_solve_triangular_backward(grad, self, result, upper, left, unitriangular, grad_input_mask)
result: linalg_solve_triangular_forward_AD(self_t, B_t, self_p, result, upper, left, unitriangular)
- name: tril(Tensor self, int diagonal=0) -> Tensor
self: grad.tril(diagonal)
result: auto_linear
- name: triu(Tensor self, int diagonal=0) -> Tensor
self: grad.triu(diagonal)
result: auto_linear
- name: trunc(Tensor self) -> Tensor
self: zeros_like(grad)
result: auto_element_wise
# DO NOT define a backward for to_dense
# See [Note: Sometimes view derivatives]
# - name: to_dense(Tensor self, ScalarType? dtype=None) -> Tensor
#
- name: _to_dense(Tensor self, ScalarType? dtype=None) -> Tensor
self: to_dense_backward(grad, self)
- name: to_sparse(Tensor self) -> Tensor
self: grad.to_dense()
- name: to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor
self: grad.to_dense()
- name: to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor
self: to_mkldnn_backward(grad, self)
- name: unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)
self: unfold_backward(grad, self.sizes(), dimension, size, step)
result: auto_linear
- name: unfold_backward(Tensor grad_in, int[] input_sizes, int dim, int size, int step) -> Tensor
grad_in: grad.unfold(dim, size, step)
result: auto_linear
- name: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)
self: zeros_like(grad)
result: self_t.zero_()
- name: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
output_differentiability: [True, False]
self: not_implemented("_unique")
- name: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
output_differentiability: [True, False, False]
self: not_implemented("unique_dim")
- name: unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor)
output_differentiability: [True, False, False]
self: not_implemented("unique_consecutive")
- name: unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
output_differentiability: [True, False, False]
self: not_implemented("unique_dim_consecutive")
- name: _unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
output_differentiability: [True, False, False]
self: not_implemented("_unique2")
- name: _unsafe_view(Tensor self, SymInt[] size) -> Tensor
self: grad.reshape_symint(self.sym_sizes())
result: auto_linear
- name: lift(Tensor self) -> Tensor
self: grad
result: auto_linear
- name: lift_fresh(Tensor(a) self) -> Tensor(a)
self: grad
result: auto_linear
- name: unsqueeze(Tensor(a) self, int dim) -> Tensor(a)
self: grad.squeeze(dim)
result: auto_linear
- name: unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!)
self: grad.squeeze(dim)
result: auto_linear
- name: var.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> Tensor
self: var_backward(grad, self, dim, correction, keepdim)
# pointwise + sum
result: at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim))
- name: var_mean.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> (Tensor, Tensor)
self: var_mean_backward(grads[0], grads[1], self, dim, correction, keepdim)
result0: at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim))
# linear
result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim)
- name: view(Tensor(a) self, SymInt[] size) -> Tensor(a)
dispatch:
Default:
self: grad.reshape_symint(self.sym_sizes())
result: auto_linear
AutogradNestedTensor:
self: grad.reshape_as(self)
result: auto_linear
- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
output_differentiability: [False]
- name: view_as_real(Tensor(a) self) -> Tensor(a)
self: at::view_as_complex(grad.contiguous()) # gx0 + 1j * gx1
result: at::view_as_real(self_t)
- name: view_as_complex(Tensor(a) self) -> Tensor(a)
self: at::view_as_real(grad.contiguous().resolve_conj()) # [gx, gy]
result: at::view_as_complex(self_t)
- name: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor
condition: non_differentiable
self: where(condition, grad, 0)
other: where(condition, 0, grad)
result: where(condition, self_t, other_t)
# weight_norm_cuda_interface_backward does not have an explicitly defined derivative, so if we do happen
# to be running backward with create_graph=True, fall back to a backward function that uses
# differentiable ops.
- name: _weight_norm_interface(Tensor v, Tensor g, int dim=0) -> (Tensor, Tensor)
v, g: "grad.defined() ? (GradMode::is_enabled() ? _weight_norm_differentiable_backward(grad.contiguous(), v, g, result1, dim) : _weight_norm_interface_backward(grad.contiguous(), v, g, result1, dim)) : std::tuple<Tensor, Tensor>()"
- name: zero_(Tensor(a!) self) -> Tensor(a!)
self: zeros_like(grad)
result: auto_linear
- name: sparse_mask(Tensor self, Tensor mask) -> Tensor
self: grad.to_dense().sparse_mask(mask).to_dense()
mask: non_differentiable
- name: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, int[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
values: sparse_constructor_values_backward(grad, indices)
- name: _sparse_sum.dim(Tensor self, int[1] dim) -> Tensor
self: at::_sparse_sum_backward(grad, self, dim)
- name: _standard_gamma(Tensor self, Generator? generator=None) -> Tensor
self: grad * _standard_gamma_grad(self, result)
- name: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor
self: not_implemented("_standard_gamma_grad")
- name: values(Tensor(a) self) -> Tensor(a)
dispatch:
Default:
self: at::_sparse_coo_tensor_unsafe(self.indices(), grad, self.sizes())._coalesced_(true)
AutogradNestedTensor:
self: at::_nested_view_from_buffer(grad.contiguous(), self._nested_tensor_size(), self._nested_tensor_strides(), self._nested_tensor_offsets())
# Why is _values() not differentiable?
# See NOTE [ Sparse: autograd and API ]
- name: _values(Tensor(a) self) -> Tensor(a)
output_differentiability: [False]
# NN
- name: _trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor
i1, i2, i3: _trilinear_backward(grad, i1, i2, i3, expand1, expand2, expand3, sumdim, grad_input_mask)
result: "_trilinear(i1_t, i2_p, i3_p, expand1, expand2, expand3, sumdim, unroll_dim) +
_trilinear(i1_p, i2_t, i3_p, expand1, expand2, expand3, sumdim, unroll_dim) +
_trilinear(i1_p, i2_p, i3_t, expand1, expand2, expand3, sumdim, unroll_dim)"
- name: constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> Tensor
self: constant_pad_nd_backward(grad, pad)
result: constant_pad_nd(self_t, pad, 0)
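# constant_pad_nd forward-mode note (informal): the padded border holds the constant `value`,
# which does not depend on self, so the tangent is self_t padded with 0 rather than with `value`.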
- name: binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor
self: binary_cross_entropy_backward(grad, self, target, weight, reduction)
target: binary_cross_entropy_target_backward(grad, self, target, weight, reduction)
result: "apply_loss_reduction(
binary_cross_entropy_backward(self_t, self_p, target_p, weight, at::Reduction::None)
+ binary_cross_entropy_target_backward(target_t, self_p, target_p, weight, at::Reduction::None),
reduction)"
- name: binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor
self: binary_cross_entropy_double_backward(grad_output, grad, self, target, weight, reduction)
target: binary_cross_entropy_double_backward_target(grad, grad_output, self, target, weight, reduction)
grad_output: binary_cross_entropy_double_backward_grad_output(grad, self, target, weight, reduction)
result: " binary_cross_entropy_double_backward(grad_output_p, self_t, self_p, target_p, weight, reduction)
+ binary_cross_entropy_double_backward_target(target_t, grad_output_p, self_p, target_p, weight, reduction)
+ binary_cross_entropy_double_backward_grad_output(grad_output_t, self_p, target_p, weight, reduction)"
- name: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor
self: binary_cross_entropy_with_logits_backward(grad, self, target, weight, pos_weight, reduction)
target: binary_cross_entropy_with_logits_target_backward(grad, self, target, weight, pos_weight, reduction)
result: "apply_loss_reduction(
binary_cross_entropy_with_logits_backward(self_t, self_p, target_p, weight, pos_weight, at::Reduction::None)
+ binary_cross_entropy_with_logits_target_backward(target_t, self_p, target_p, weight, pos_weight, at::Reduction::None),
reduction)"
- name: embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
indices: non_differentiable
weight: embedding_backward_symint(grad, indices, weight.sym_size(0), padding_idx, scale_grad_by_freq, sparse)
result: auto_linear
- name: embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor
grad_output: embedding_dense_double_backward(grad, indices, padding_idx)
indices: non_differentiable
result: auto_linear
- name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)
indices: non_differentiable
offsets: non_differentiable
weight: _embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse, per_sample_weights, padding_idx)
per_sample_weights: _embedding_bag_per_sample_weights_backward(grad, weight, indices, offsets, result1, mode, padding_idx)
- name: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
indices: non_differentiable
offset2bag: non_differentiable
bag_size: non_differentiable
maximum_indices: non_differentiable
- name: embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)
indices: non_differentiable
self: not_implemented("embedding_renorm")
- name: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
self: mse_loss_backward(grad, self, target, reduction)
target: mse_loss_backward(grad, target, self, reduction)
result: apply_loss_reduction(mse_loss_backward(self_t.conj(), self_p, target_p, at::Reduction::None).conj() + mse_loss_backward(target_t.conj(), target_p, self_p, at::Reduction::None).conj(), reduction)
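# mse_loss, informal sketch of why the target gradient reuses mse_loss_backward with the arguments
# swapped: the unreduced loss is (self - target)^2, so d/dtarget = -2 (self - target) = 2 (target - self),
# which is the self-gradient evaluated with self and target exchanged. smooth_l1_loss and huber_loss
# follow the same pattern, since both depend only on self - target and have odd derivatives.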
- name: multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean) -> Tensor
self: multi_margin_loss_backward(grad, self, target, p, margin, weight, reduction)
target: non_differentiable
- name: multilabel_margin_loss_forward(Tensor self, Tensor target, int reduction) -> (Tensor output, Tensor is_target)
self: multilabel_margin_loss_backward(grad, self, target, reduction, is_target)
target: non_differentiable
- name: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)
self: nll_loss_backward(grad, self, target, weight, reduction, ignore_index, total_weight)
target: non_differentiable
output: std::get<0>(nll_loss_forward(self_t, target, weight, reduction, ignore_index))
- name: nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)
self: nll_loss2d_backward(grad, self, target, weight, reduction, ignore_index, total_weight)
target: non_differentiable
output: std::get<0>(nll_loss2d_forward(self_t, target, weight, reduction, ignore_index))
- name: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor
self: smooth_l1_loss_backward(grad, self, target, reduction, beta)
target: smooth_l1_loss_backward(grad, target, self, reduction, beta)
result: apply_loss_reduction(smooth_l1_loss_backward(self_t.conj(), self_p, target_p, at::Reduction::None, beta).conj() + smooth_l1_loss_backward(target_t.conj(), target_p, self_p, at::Reduction::None, beta).conj(), reduction)
- name: huber_loss(Tensor self, Tensor target, int reduction=Mean, float delta=1.0) -> Tensor
self: huber_loss_backward(grad, self, target, reduction, delta)
target: huber_loss_backward(grad, target, self, reduction, delta)
result: apply_loss_reduction(huber_loss_backward(self_t.conj(), self_p, target_p, at::Reduction::None, delta).conj() + huber_loss_backward(target_t.conj(), target_p, self_p, at::Reduction::None, delta).conj(), reduction)
- name: soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
self: soft_margin_loss_backward(grad, self, target, reduction)
result: apply_loss_reduction(soft_margin_loss_backward(self_t.conj(), self_p, target, at::Reduction::None).conj(), reduction)
- name: relu(Tensor self) -> Tensor
self: threshold_backward(grad, result, 0)
result: auto_element_wise
- name: silu(Tensor self) -> Tensor
self: "GradMode::is_enabled() ? infinitely_differentiable_silu_backward(grad, self) : silu_backward(grad, self)"
result: auto_element_wise
- name: mish(Tensor self) -> Tensor
self: "GradMode::is_enabled() ? infinitely_differentiable_mish_backward(grad, self) : mish_backward(grad, self)"
result: auto_element_wise
- name: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor
self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ false, self)
result: auto_element_wise
- name: elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!)
self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ true, result)
result: self_t.copy_(elu_backward(original_self_t, alpha, scale, input_scale, /* is_result */ true, result))
- name: celu(Tensor self, Scalar alpha=1.0) -> Tensor
self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ false, self)
result: auto_element_wise
- name: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!)
self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result)
result: self_t.copy_(elu_backward(original_self_t, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result))
- name: gelu(Tensor self, *, str approximate='none') -> Tensor
self: gelu_backward(grad, self, approximate)
result: auto_element_wise
- name: gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor
grad_output: gelu_backward(grad, self, approximate)
self: gelu_double_backward(grad, grad_output, self, approximate)
result: gelu_backward(grad_output_t, self_p, approximate) + gelu_double_backward(self_t, grad_output_p, self_p, approximate)
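# gelu_backward forward-mode derivative (informal sketch): gelu_backward(g, x) = g * gelu'(x) is linear
# in g, so by the product rule its tangent is gelu'(x) * g_t + g * gelu''(x) * x_t. The first term is
# gelu_backward(grad_output_t, self_p, ...) and the second is gelu_double_backward(self_t, grad_output_p, self_p, ...).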
- name: glu(Tensor self, int dim=-1) -> Tensor
# TODO: glu_backward could reuse the forward result, and so could
# forward AD and forward-over-reverse AD
self: glu_backward(grad, self, dim)
result: glu_jvp(result, self_p, self_t, dim)
- name: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor
self: hardshrink_backward(grad, self, lambd)
result: auto_element_wise
- name: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor
grad_out: hardshrink_backward(grad, self, lambd)
self: zeros_like(grad)
result: at::where((self_p > lambd).logical_or(self_p < -lambd), grad_out_t, at::zeros({}, result.options()).expand_as(result))
- name: hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor
self: hardtanh_backward(grad, self, min_val, max_val)
result: auto_element_wise
- name: leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor
self: leaky_relu_backward(grad, self, negative_slope, false)
result: auto_element_wise
- name: leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!)
self: leaky_relu_backward(grad, result, negative_slope, true)
result: self_t.copy_(leaky_relu_backward(original_self_t.conj(), result, negative_slope, true).conj())
- name: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer)
self: log_sigmoid_backward(grad, self, buffer)
output: log_sigmoid_backward(self_t.conj(), self_p, buffer).conj()
- name: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
self: _log_softmax_backward_data(grad, result, dim, self.scalar_type())
result: self_t - logsumexp_jvp(self_p, self_t, {dim}, true)
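# _log_softmax forward-mode derivative (informal sketch): log_softmax(x) = x - logsumexp(x, dim), and the
# JVP of logsumexp is sum_j softmax(x)_j * t_j along dim, which logsumexp_jvp(self_p, self_t, {dim}, true)
# computes; the trailing `true` keeps the reduced dim so the result broadcasts against self_t.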
- name: _sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
self: _sparse_log_softmax_backward_data(grad, result, dim, self)
- name: _masked_softmax(Tensor self, Tensor mask, int? dim=None, int? mask_type=None) -> Tensor
self: _masked_softmax_backward(grad, result, mask, dim)
mask: non_differentiable
- name: prelu(Tensor self, Tensor weight) -> Tensor
self, weight: "grad.defined() ? prelu_backward(grad, self, weight) : std::tuple<Tensor, Tensor>()"
result: prelu_jvp(self_p, self_t, weight_p, weight_t)
- name: prelu_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor)
grad_output, self, weight: prelu_double_backward(grads[0], grads[1], grad_output, self, weight)
result0: prelu_backward_self_jvp(self_p, weight_p, weight_t, grad_output_p, grad_output_t)
result1: prelu_backward_weight_jvp(weight_p, self_p, self_t, grad_output_p, grad_output_t)
- name: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false)
result: auto_element_wise
- name: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
self: rrelu_with_noise_backward(grad, result, noise, lower, upper, training, true)
- name: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor
self: _softmax_backward_data(grad, result, dim, self.scalar_type())
result: result * (self_t - logsumexp_jvp(self_p, self_t, {dim}, true))
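# _softmax forward-mode derivative (informal sketch): with y = softmax(x) = exp(log_softmax(x)),
# dy = y * d(log_softmax(x)) = y * (t - sum_j y_j t_j), i.e. the log_softmax JVP scaled by the result.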
- name: _sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
self: _sparse_softmax_backward_data(grad, result, dim, self)
- name: _sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor
self: sparse_sparse_matmul_backward(grad, self, other, 0)
other: sparse_sparse_matmul_backward(grad, self, other, 1)
- name: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor
self: softplus_backward(grad, self, beta, threshold)
result: auto_element_wise
- name: softshrink(Tensor self, Scalar lambd=0.5) -> Tensor
self: softshrink_backward(grad, self, lambd)
result: auto_element_wise
- name: threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor
self: threshold_backward(grad, self, threshold)
result: auto_element_wise
- name: threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!)
self: threshold_backward(grad, self, threshold)
result: self_t.copy_(threshold_backward(self_t.conj(), original_self_p, threshold).conj())
- name: reflection_pad1d(Tensor self, int[2] padding) -> Tensor
self: reflection_pad1d_backward(grad, self, padding)
result: auto_linear
- name: reflection_pad2d(Tensor self, int[4] padding) -> Tensor
self: reflection_pad2d_backward(grad, self, padding)
result: auto_linear
- name: reflection_pad3d(Tensor self, int[6] padding) -> Tensor
self: reflection_pad3d_backward(grad, self, padding)
result: auto_linear
- name: replication_pad1d(Tensor self, int[2] padding) -> Tensor
self: replication_pad1d_backward(grad, self, padding)
result: auto_linear
- name: replication_pad2d(Tensor self, int[4] padding) -> Tensor
self: replication_pad2d_backward(grad, self, padding)
result: auto_linear
- name: replication_pad3d(Tensor self, int[6] padding) -> Tensor
self: replication_pad3d_backward(grad, self, padding)
result: auto_linear
- name: upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor
self: upsample_linear1d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales)
result: auto_linear
- name: upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
self: upsample_bilinear2d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w)
result: auto_linear
- name: _upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
self: _upsample_bilinear2d_aa_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w)
result: auto_linear
- name: upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
self: upsample_bicubic2d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w)
result: auto_linear
- name: _upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
self: _upsample_bicubic2d_aa_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w)
- name: upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
self: upsample_trilinear3d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_d, scales_h, scales_w)
result: auto_linear
- name: upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor
self: upsample_nearest1d_backward_symint(grad, output_size, self.sym_sizes(), scales)
result: auto_linear
- name: _upsample_nearest_exact1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor
self: _upsample_nearest_exact1d_backward_symint(grad, output_size, self.sym_sizes(), scales)
result: auto_linear
- name: upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor
self: upsample_nearest2d_backward_symint(grad, output_size, self.sym_sizes(), scales_h, scales_w)
result: auto_linear
- name: _upsample_nearest_exact2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor
self: _upsample_nearest_exact2d_backward_symint(grad, output_size, self.sym_sizes(), scales_h, scales_w)
result: auto_linear
- name: upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
self: upsample_nearest3d_backward_symint(grad, output_size, self.sym_sizes(), scales_d, scales_h, scales_w)
result: auto_linear
- name: _upsample_nearest_exact3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
self: _upsample_nearest_exact3d_backward_symint(grad, output_size, self.sym_sizes(), scales_d, scales_h, scales_w)
result: auto_linear
- name: upsample_linear1d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
input: upsample_linear1d_backward_symint(grad, output_size, input.sym_sizes(), align_corners, scale_factors)
result: auto_linear
- name: upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
input: upsample_bilinear2d_backward_symint(grad, output_size, input.sym_sizes(), align_corners, scale_factors)
result: auto_linear
- name: _upsample_bilinear2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
input: _upsample_bilinear2d_aa_backward_symint(grad, output_size, input.sym_sizes(), align_corners, scale_factors)
result: auto_linear
- name: upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
input: upsample_trilinear3d_backward_symint(grad, output_size, input.sym_sizes(), align_corners, scale_factors)
result: auto_linear
- name: upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
input: upsample_bicubic2d_backward_symint(grad, output_size, input.sym_sizes(), align_corners, scale_factors)
result: auto_linear
- name: _upsample_bicubic2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
input: _upsample_bicubic2d_aa_backward_symint(grad, output_size, input.sym_sizes(), align_corners, scale_factors)
- name: upsample_nearest1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
input: upsample_nearest1d_backward_symint(grad, output_size, input.sym_sizes(), scale_factors)
result: auto_linear
- name: _upsample_nearest_exact1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
input: _upsample_nearest_exact1d_backward_symint(grad, output_size, input.sym_sizes(), scale_factors)
result: auto_linear
- name: upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
input: upsample_nearest2d_backward_symint(grad, output_size, input.sym_sizes(), scale_factors)
result: auto_linear
- name: _upsample_nearest_exact2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
input: _upsample_nearest_exact2d_backward_symint(grad, output_size, input.sym_sizes(), scale_factors)
result: auto_linear
- name: upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
input: upsample_nearest3d_backward_symint(grad, output_size, input.sym_sizes(), scale_factors)
result: auto_linear
- name: _upsample_nearest_exact3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
input: _upsample_nearest_exact3d_backward_symint(grad, output_size, input.sym_sizes(), scale_factors)
result: auto_linear
- name: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
self: pixel_unshuffle(grad, upscale_factor)
result: auto_linear
- name: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor
self: pixel_shuffle(grad, downscale_factor)
result: auto_linear
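# pixel_shuffle and pixel_unshuffle only rearrange elements and invert each other for the same factor,
# so the backward of each is the other (the adjoint of a permutation is its inverse).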
- name: _adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor
self: _adaptive_avg_pool2d_backward(grad, self)
result: auto_linear
- name: _adaptive_avg_pool3d(Tensor self, int[3] output_size) -> Tensor
self: _adaptive_avg_pool3d_backward(grad, self)
result: auto_linear
- name: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)
self: adaptive_max_pool2d_backward(grad, self, result1)
result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1)
output_differentiability: [True, False]
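# Max-pooling forward-mode derivative (informal sketch): each output element equals the input element at
# its argmax, so its tangent is the input tangent at that same position. result1 holds those argmax
# positions as indices into the flattened spatial dims, hence the gather over self_t.flatten(-2)
# (flatten(-3) for 3d). The fractional_max_pool and max_pool*_with_indices entries use the same pattern.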
- name: adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor)
self: adaptive_max_pool3d_backward(grad, self, result1)
result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1)
output_differentiability: [True, False]
- name: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor
self: avg_pool2d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
result: auto_linear
- name: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor
self: avg_pool3d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
result: auto_linear
- name: fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor)
self: fractional_max_pool2d_backward(grad, self, kernel_size, output_size, result1)
result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1)
output_differentiability: [True, False]
- name: fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor)
self: fractional_max_pool3d_backward(grad, self, kernel_size, output_size, result1)
result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1)
output_differentiability: [True, False]
- name: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
input, weight, bias: linear_backward(input, grad, weight, grad_input_mask)
# MPS
- name: _mps_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
self: mps_max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode)
- name: _mps_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor
self, weight, bias: "grad.defined() ? mps_convolution_backward(self, grad, weight, padding, stride, dilation, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
- name: mps_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, grad_input_mask)
- name: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
self: max_pool2d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, result1)
result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1)
output_differentiability: [True, False]
- name: max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
self: max_pool3d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, result1)
result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1)
output_differentiability: [True, False]
- name: max_unpool2d(Tensor self, Tensor indices, int[2] output_size) -> Tensor
self: max_pool_double_backward(grad, indices, 2)
indices: non_differentiable
result: auto_linear
- name: max_unpool3d(Tensor self, Tensor indices, int[3] output_size, int[3] stride, int[3] padding) -> Tensor
self: max_pool_double_backward(grad, indices, 3)
indices: non_differentiable
result: auto_linear
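# max_unpool (informal sketch): the forward scatters each input element to the output position named by
# `indices`, so its adjoint gathers `grad` back at those positions. max_pool_double_backward is reused
# here on the assumption that it performs exactly that gather over the last 2 (or 3) dims.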
- name: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor
input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
result: convolution_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, stride, padding, dilation, transposed, output_padding, groups)
# TorchScript serializes calls to _convolution, so this entry must remain until those calls are changed to use convolution.
# Note that the benchmark, deterministic, cudnn_enabled, and allow_tf32 flags are queried from the global context
# by convolution_backward instead of being passed along from the forward pass.
- name: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor
input, weight, bias: "grad.defined() ? convolution_backward(grad, input, weight, bias->sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
result: _convolution_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32)
- name: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
grad_output, input, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask)
result0: std::get<0>(convolution_backward_symint(grad_output_p, input_p, weight_t, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {true, false, false})) + std::get<0>(convolution_backward_symint(grad_output_t, input_p, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {true, false, false}))
result1: std::get<1>(convolution_backward_symint(grad_output_p, input_t, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {false, true, false})) + std::get<1>(convolution_backward_symint(grad_output_t, input_p, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {false, true, false}))
result2: convolution_backward_jvp_grad_bias(grad_output_t, result2)
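# convolution_backward forward-mode derivative (informal sketch): grad_input is bilinear in
# (grad_output, weight) and grad_weight is bilinear in (grad_output, input), so each tangent is the sum
# of two terms, each with one argument replaced by its tangent. grad_bias is a sum of grad_output over
# every dim except the channel dim, so its tangent depends only on grad_output_t; result2 is presumably
# passed along to supply the expected output shape.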
- name: convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor
input, weight, bias: "grad.defined() ? convolution_backward_overrideable(grad, input, weight, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
- name: convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
grad_output, input, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask)
- name: slow_conv_transpose2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int[2] dilation=1) -> Tensor
self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
- name: slow_conv_transpose3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] output_padding=0, int[3] dilation=1) -> Tensor
self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
- name: _slow_conv2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding) -> Tensor
self, weight, bias: "grad.defined() ? _slow_conv2d_backward(grad, self, weight, kernel_size, stride, padding, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
- name: _slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1}}, false, {{0, 0}}, 1, grad_input_mask)
- name: _conv_depthwise2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding, int[2] dilation) -> Tensor
self, weight, bias: "grad.defined() ? convolution_backward(grad.contiguous(), self, weight, bias->sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0}}, /*groups=*/ self.size(1), grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
- name: conv_depthwise3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding, int[3] dilation) -> Tensor
self, weight, bias: "grad.defined() ? convolution_backward(grad.contiguous(), self, weight, bias->sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0, 0}}, /*groups=*/ self.size(1), grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
- name: slow_conv3d_forward(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding) -> Tensor
self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, /*dilation=*/ {{1, 1, 1}}, false, /*output_padding=*/ {{0, 0, 0}}, 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
- name: slow_conv_dilated2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1) -> Tensor
self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
- name: slow_conv_dilated3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1) -> Tensor
self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
- name: col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor
self: im2col(grad, kernel_size, dilation, padding, stride)
result: auto_linear
- name: im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor
self: col2im_symint(grad, {self.sym_size(-2), self.sym_size(-1)}, kernel_size, dilation, padding, stride)
result: auto_linear
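# im2col and col2im are adjoint linear maps: im2col copies each (possibly overlapping) patch into a
# column, and col2im scatters columns back, summing where patches overlap. Hence each one's backward
# is the other with the same kernel_size, dilation, padding, and stride.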
- name: _adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor
grad_output: _adaptive_avg_pool2d_symint(grad, {grad_output.sym_size(-2), grad_output.sym_size(-1)})
self: zeros_like(self)
result: _adaptive_avg_pool2d_backward(grad_output_t, self_p)
- name: _adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor
grad_output: _adaptive_avg_pool3d(grad, { grad_output.size(-3), grad_output.size(-2), grad_output.size(-1) })
self: zeros_like(self)
result: _adaptive_avg_pool3d_backward(grad_output_t, self_p)
- name: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor
grad_output: max_pool_double_backward(grad, indices, 2)
self: zeros_like(self)
result: auto_linear
- name: adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor
grad_output: max_pool_double_backward(grad, indices, 3)
self: zeros_like(self)
result: auto_linear
- name: avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor
grad_output: avg_pool2d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
self: zeros_like(self)
result: avg_pool2d_backward(grad_output_t, self_p, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
- name: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor
grad_output: avg_pool3d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
self: zeros_like(self)
result: avg_pool3d_backward(grad_output_t, self_p, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
- name: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor
grad_output: elu_backward(grad, alpha, scale, input_scale, is_result, self_or_result)
self_or_result: elu_double_backward(grad, grad_output, alpha, scale, input_scale, is_result, self_or_result)
result: elu_backward(grad_output_t, alpha, scale, input_scale, is_result, self_or_result_p) + elu_double_backward(self_or_result_t, grad_output_p, alpha, scale, input_scale, is_result, self_or_result_p)
- name: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor
grad_output: max_pool_double_backward(grad, indices, 2)
self: zeros_like(self)
result: auto_linear
- name: fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor
grad_output: max_pool_double_backward(grad, indices, 3)
self: zeros_like(self)
result: auto_linear
- name: glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor
grad_output: glu_double_backward_grad_output(grad, self, dim)
self: glu_double_backward(grad, grad_output, self, dim)
result: glu_backward_jvp(result, grad_output_p, self_p, grad_output_t, self_t, dim)
- name: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor
grad_output: hardtanh_backward(grad, self, min_val, max_val)
self: zeros_like(grad)
result: at::where((self_p > min_val).logical_and(self_p < max_val), grad_output_t, at::zeros({}, result.options()).expand_as(result))
- name: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor
grad_output: log_sigmoid_backward(grad, self, buffer)
self: log_sigmoid_double_backward(grad * grad_output, self)
- name: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
grad_output: grad.to(output.dtype()) - (grad.to(output.dtype()) * output.exp()).sum(dim, true)
output: (-grad_output.sum(dim, true) * output.exp() * grad.to(output.dtype())).to(output.dtype())
- name: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor
# self_is_result is always false here because the double backward call is out-of-place, so self is the input itself
grad_output: leaky_relu_backward(grad, self, negative_slope, false)
self: zeros_like(grad)
# leaky_relu_backward(grad_output, self, negative_slope, false)
# computes grad_output * at::where(self_p > 0, 1, negative_slope)
# so the jvp formula is the following:
# grad_output_t * at::where(self_p > 0, self_p.new_ones([]), negative_slope);
#
# leaky_relu_backward(grad_output, result, negative_slope, true)
# computes grad_output * at::where(result > 0, 1, negative_slope)
# under the assumption that `negative_slope` is positive (otherwise,
# it is not possible to compute the gradient).
#
# so the jvp formula is the following:
# grad_output_t * at::where(result_p > 0, result_p.new_ones([]), negative_slope);
# with the assumption that negative_slope is positive.
#
# Combining the two gives the following optimized kernel, which
# also checks the assumption that negative_slope is positive when self_is_result
# is True:
result: leaky_relu_backward(grad_output_t, self_p, negative_slope, self_is_result)
- name: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor
grad_output: max_pool_double_backward(grad, indices, 2)
self: zeros_like(self)
indices: non_differentiable
result: auto_linear
- name: max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor
grad_output: max_pool_double_backward(grad, indices, 3)
self: zeros_like(self)
indices: non_differentiable
result: auto_linear
- name: mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
grad_output: mse_loss_backward(grad, self, target, reduction)
self: mse_loss_double_backward(grad * grad_output, self, reduction)
target: -mse_loss_double_backward(grad * grad_output, target, reduction)
result: " mse_loss_double_backward(self_t * grad_output_p, self_p, reduction)
- mse_loss_double_backward(target_t * grad_output_p, target_p, reduction)
+ mse_loss_backward(grad_output_t, self_p, target_p, reduction)
"
- name: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight) -> Tensor
grad_output: nll_loss(grad, target, weight, reduction, ignore_index)
self: zeros_like(grad)
target: non_differentiable
- name: nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight) -> Tensor
grad_output: nll_loss2d(grad, target, weight, reduction, ignore_index)
self: zeros_like(grad)
target: non_differentiable
- name: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor
# self_is_result is always false here because the double backward call is out-of-place, so self is the input itself
grad_output: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false)
self: zeros_like(grad)
result: rrelu_with_noise_backward(grad_output_t, self_p, noise, lower, upper, training, false)
- name: reflection_pad1d_backward(Tensor grad_output, Tensor self, int[2] padding) -> Tensor
grad_output: reflection_pad1d(grad, padding)
self: zeros_like(self)
result: reflection_pad1d_backward(grad_output_t, self_p, padding)
- name: reflection_pad2d_backward(Tensor grad_output, Tensor self, int[4] padding) -> Tensor
grad_output: reflection_pad2d(grad, padding)
self: zeros_like(self)
result: reflection_pad2d_backward(grad_output_t, self_p, padding)
- name: reflection_pad3d_backward(Tensor grad_output, Tensor self, int[6] padding) -> Tensor
grad_output: reflection_pad3d(grad, padding)
self: zeros_like(self)
result: reflection_pad3d_backward(grad_output_t, self_p, padding)
- name: replication_pad1d_backward(Tensor grad_output, Tensor self, int[2] padding) -> Tensor
grad_output: replication_pad1d(grad, padding)
self: zeros_like(self)
result: replication_pad1d_backward(grad_output_t, self_p, padding)
- name: replication_pad2d_backward(Tensor grad_output, Tensor self, int[4] padding) -> Tensor
grad_output: replication_pad2d(grad, padding)
self: zeros_like(self)
result: replication_pad2d_backward(grad_output_t, self_p, padding)
- name: replication_pad3d_backward(Tensor grad_output, Tensor self, int[6] padding) -> Tensor
grad_output: replication_pad3d(grad, padding)
self: zeros_like(self)
result: replication_pad3d_backward(grad_output_t, self_p, padding)
- name: sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
self: maybe_multiply(grad, beta.conj())
mat1: maybe_multiply(grad.sparse_mask(self).mm(mat2.mH()), alpha.conj())
mat2: maybe_multiply(mat1.mH().mm(grad.sparse_mask(self)), alpha.conj())
- name: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor
grad_output: smooth_l1_loss_backward(grad, self, target, reduction, beta)
self: smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta)
target: -smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta)
result: " smooth_l1_loss_double_backward(self_t * grad_output_p, self_p, target_p, reduction, beta)
- smooth_l1_loss_double_backward(target_t * grad_output_p, self_p, target_p, reduction, beta)
+ smooth_l1_loss_backward(grad_output_t, self_p, target_p, reduction, beta)
"
- name: huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor
grad_output: huber_loss_double_backward_grad_output(grad, grad_output, self, target, reduction, delta)
self: huber_loss_double_backward(grad * grad_output, self, target, reduction, delta)
target: -huber_loss_double_backward(grad * grad_output, self, target, reduction, delta)
- name: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor
grad_output: softplus_backward(grad, self, beta, threshold)
self: softplus_double_backward(grad * grad_output, self, beta, threshold)
result: "softplus_backward(grad_output_t, self_p, beta, threshold)
+ softplus_double_backward(self_t * grad_output_p, self_p, beta, threshold)"
- name: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
grad_output: _softmax_backward_data(grad.to(output.dtype()), output, dim, input_dtype)
output: softmax_double_backward(grad.to(output.dtype()), grad_output, dim, output).to(output.dtype())
- name: soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
grad_output: soft_margin_loss_double_backward_grad_output(grad, grad_output, self, target, reduction)
self: soft_margin_loss_double_backward(grad * grad_output, self, target, reduction)
- name: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor
grad_output: softshrink_backward(grad, self, lambd)
self: zeros_like(grad)
result: at::where((self_p > lambd).logical_or(self_p < -lambd), grad_output_t, at::zeros({}, result.options()).expand_as(result))
- name: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor
grad_output: threshold_backward(grad, self, threshold)
self: zeros_like(grad)
result: zeros_like(self_t) + threshold_backward(grad_output_t, self_p, threshold)
- name: upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor
grad_output: upsample_linear1d_symint(grad, output_size, align_corners, scales)
result: auto_linear
- name: upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
grad_output: upsample_bilinear2d_symint(grad, output_size, align_corners, scales_h, scales_w)
result: auto_linear
- name: _upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
grad_output: _upsample_bilinear2d_aa_symint(grad, output_size, align_corners, scales_h, scales_w)
result: auto_linear
- name: upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
grad_output: upsample_bicubic2d_symint(grad, output_size, align_corners, scales_h, scales_w)
result: auto_linear
- name: _upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
grad_output: _upsample_bicubic2d_aa_symint(grad, output_size, align_corners, scales_h, scales_w)
- name: upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
grad_output: upsample_trilinear3d_symint(grad, output_size, align_corners, scales_d, scales_h, scales_w)
result: auto_linear
- name: upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor
grad_output: upsample_nearest1d_symint(grad, output_size, scales)
result: auto_linear
- name: _upsample_nearest_exact1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor
grad_output: _upsample_nearest_exact1d_symint(grad, output_size, scales)
result: auto_linear
- name: upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor
grad_output: upsample_nearest2d_symint(grad, output_size, scales_h, scales_w)
result: auto_linear
- name: _upsample_nearest_exact2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor
grad_output: _upsample_nearest_exact2d_symint(grad, output_size, scales_h, scales_w)
result: auto_linear
- name: upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
grad_output: upsample_nearest3d_symint(grad, output_size, scales_d, scales_h, scales_w)
result: auto_linear
- name: _upsample_nearest_exact3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
grad_output: _upsample_nearest_exact3d_symint(grad, output_size, scales_d, scales_h, scales_w)
result: auto_linear
- name: upsample_linear1d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, bool align_corners, float[]? scale_factors) -> Tensor
grad_output: upsample_linear1d_symint(grad, output_size, align_corners, scale_factors)
result: auto_linear
- name: upsample_bilinear2d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, bool align_corners, float[]? scale_factors) -> Tensor
grad_output: upsample_bilinear2d_symint(grad, output_size, align_corners, scale_factors)
result: auto_linear
- name: _upsample_bilinear2d_aa_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, bool align_corners, float[]? scale_factors) -> Tensor
grad_output: _upsample_bilinear2d_aa_symint(grad, output_size, align_corners, scale_factors)
result: auto_linear
- name: upsample_trilinear3d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, bool align_corners, float[]? scale_factors) -> Tensor
grad_output: upsample_trilinear3d_symint(grad, output_size, align_corners, scale_factors)
result: auto_linear
- name: upsample_bicubic2d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, bool align_corners, float[]? scale_factors) -> Tensor
grad_output: upsample_bicubic2d_symint(grad, output_size, align_corners, scale_factors)
result: auto_linear
- name: _upsample_bicubic2d_aa_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, bool align_corners, float[]? scale_factors) -> Tensor
grad_output: _upsample_bicubic2d_aa_symint(grad, output_size, align_corners, scale_factors)
- name: upsample_nearest1d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, float[]? scale_factors) -> Tensor
grad_output: upsample_nearest1d_symint(grad, output_size, scale_factors)
result: auto_linear
- name: _upsample_nearest_exact1d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, float[]? scale_factors) -> Tensor
grad_output: _upsample_nearest_exact1d_symint(grad, output_size, scale_factors)
result: auto_linear
- name: upsample_nearest2d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, float[]? scale_factors) -> Tensor
grad_output: upsample_nearest2d_symint(grad, output_size, scale_factors)
result: auto_linear
- name: _upsample_nearest_exact2d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, float[]? scale_factors) -> Tensor
grad_output: _upsample_nearest_exact2d_symint(grad, output_size, scale_factors)
result: auto_linear
- name: upsample_nearest3d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, float[]? scale_factors) -> Tensor
grad_output: upsample_nearest3d_symint(grad, output_size, scale_factors)
result: auto_linear
- name: _upsample_nearest_exact3d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, float[]? scale_factors) -> Tensor
grad_output: _upsample_nearest_exact3d_symint(grad, output_size, scale_factors)
result: auto_linear
- name: sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor
grad_output: sigmoid_backward(grad, output.conj())
output: grad.conj() * grad_output * (-2 * output.conj() + 1)
result: sigmoid_backward(grad_output_t, output_p) + output_t.conj() * grad_output_p * (-2 * output_p.conj() + 1)
- name: tanh_backward(Tensor grad_output, Tensor output) -> Tensor
grad_output: tanh_backward(grad, output.conj())
output: grad.conj() * (-2 * output.conj() * grad_output)
result: tanh_backward(grad_output_t, output_p) + output_t.conj() * (-2 * output_p.conj() * grad_output_p)
# cudnn
- name: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)
log_probs: _cudnn_ctc_loss_backward(grad, result0, result1, zero_infinity)
- name: _cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)
log_probs: _cudnn_ctc_loss_backward(grad, result0, result1, zero_infinity)
- name: cudnn_convolution_transpose(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
self, weight: "_cudnn_convolution_backward(self, grad, weight, padding, output_padding, stride, dilation, true, groups, {grad_input_mask[0], grad_input_mask[1]})"
- name: _mps_convolution_transpose(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
self, weight: "grad.defined() ? mps_convolution_transpose_backward(self, grad, weight, padding, output_padding, stride, dilation, groups, grad_input_mask) : std::tuple<Tensor, Tensor>()"
- name: cudnn_convolution(Tensor self, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
self, weight: "_cudnn_convolution_backward(self, grad, weight, padding, std::vector<int64_t>(padding.size(), 0), stride, dilation, false, groups, {grad_input_mask[0], grad_input_mask[1]})"
- name: cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output
self, grid: "grad.defined() ? cudnn_grid_sampler_backward(self, grid, grad) : std::tuple<Tensor, Tensor>()"
- name: cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid
theta: cudnn_affine_grid_generator_backward(grad, N, C, H, W)
# NB: Why is the backward formula here so complicated? cuDNN cannot be used to compute
# the backward pass in evaluation mode, because the backward math in evaluation mode
# is different (since the forward math is different), and cuDNN does not support
# it. In any case, you shouldn't be using this batch norm in evaluation mode,
# because it should be merged into the previous convolution (left for future
# work).
# NB2: The quotes around the gradient are needed to appease YAML parsing rules.
- name: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? (training ? cudnn_batch_norm_backward(input, grad.contiguous(input.suggest_memory_format()), weight, running_mean, running_var, result1, result2, epsilon, retain_variables ? result3.clone() : result3) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon)
# HACK: save_mean and save_var are going to be passed in as
# requires_grad variables (even though we'll never backprop through
# them), so we need to prevent the unpacking from triggering an error.
- name: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor)
save_mean: not_implemented("cudnn_batch_norm_backward save_mean")
save_var: not_implemented("cudnn_batch_norm_backward save_var")
reserveSpace: not_implemented("cudnn_batch_norm_backward reserveSpace")
input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask)
# nnpack
- name: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, int[2] padding, int[2] stride=1) -> Tensor
# NNPACK does not support strided convolutions in the backward path, so we use the closest available function that does.
input, weight, bias: "grad.defined() ? convolution_backward(grad, input, weight, bias->sizes(), stride, padding, std::vector<int64_t>(padding.size(), 1), false, std::vector<int64_t>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
# LSTM MPS
- name: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
output_differentiability: [True, True, True, False, False]
input, hx, params: "lstm_mps_backward(grads[0], grads[1], grads[2], result3, result4, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)"
- name: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
# Only the first three of the _cudnn_rnn outputs can have gradients.
# _cudnn_rnn outputs: (output, hy, cy, reserve, weight_buf)
- name: _cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
dropout_state: non_differentiable
output_differentiability: [True, True, True, False, False]
input, hx, cx, weight: "_cudnn_rnn_backward_symint(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)"
- name: _cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[])
dropout_state: non_differentiable
input: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
weight: not_implemented_list("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
hx: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
cx: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
output: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
grad_output: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
grad_hy: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
grad_cy: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
# miopen
- name: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, true, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
- name: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
- name: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
- name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
- name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor)
save_mean: not_implemented("miopen_batch_norm_backward save_mean")
save_var: not_implemented("miopen_batch_norm_backward save_var")
input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask)
- name: miopen_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
dropout_state: non_differentiable
output_differentiability: [True, True, True, False, False]
input, hx, cx, weight: "miopen_rnn_backward(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)"
- name: miopen_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[])
dropout_state: non_differentiable
# mkldnn
- name: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor
self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ std::vector<int64_t>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
- name: mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor
self, weight, bias: mkldnn_linear_backward(self, grad, weight, grad_input_mask)
- name: mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
self: mkldnn_max_pool2d_backward(grad, result, self, kernel_size, stride, padding, dilation, ceil_mode)
- name: mkldnn_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor
self: mkldnn_max_pool3d_backward(grad, result, self, kernel_size, stride, padding, dilation, ceil_mode)
- name: mkldnn_adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor
self: mkldnn_adaptive_avg_pool2d_backward(grad, self)
- name: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor
self: grad.reshape(self.sizes())
# Nested Tensor
- name: _nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
list: "grad.defined()? at::unbind(grad) : std::vector<Tensor>(list.size())"
- name: _nested_tensor_from_mask(Tensor t, Tensor mask, bool mask_check=True) -> Tensor
t: grad.to_padded_tensor(0, t.sizes())
mask: non_differentiable
- name: _nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor
padded: _nested_from_padded_backward(grad, padded, fuse_transform_0213)
cpu_nested_shape_example: non_differentiable
- name: to_padded_tensor(Tensor self, float padding, int[]? output_size=None) -> Tensor
self: at::_nested_from_padded(grad, self._nested_tensor_size())
padding: non_differentiable
- name: _nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, int[] offsets) -> Tensor(a)
self: grad.values()
nested_size: non_differentiable
nested_strides: non_differentiable
# fft
- name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
  self: fft_r2c_backward(grad, dim, normalization, onesided, self.size(dim.back()))
  result: auto_linear
- name: _fft_c2r(Tensor self, int[] dim, int normalization, int last_dim_size) -> Tensor
  self: fft_c2r_backward(grad, dim, normalization)
  result: auto_linear
- name: _fft_c2c(Tensor self, int[] dim, int normalization, bool forward) -> Tensor
  self: _fft_c2c(grad, dim, normalization, !forward)
  result: auto_linear
- name: unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]
  self: unbind_backward(grads, dim)
  result: auto_linear
- name: stack(Tensor[] tensors, int dim=0) -> Tensor
  tensors: stack_tensors_backward(grad, dim, to_args_scalartypes(tensors))
  result: stack_jvp(tensors, dim)
# fused RNN kernels
# Only the first two of the _thnn_fused_lstm_cell outputs can have gradients.
# _thnn_fused_lstm_cell outputs: (hy, cy, workspace)
- name: _thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor)
  output_differentiability: [True, True, False]
  input_gates, hidden_gates, cx, input_bias, hidden_bias: "GradMode::is_enabled() ? _thnn_differentiable_lstm_cell_backward(grads[0], grads[1], input_gates, hidden_gates, input_bias, hidden_bias, cx, result1) : _thnn_fused_lstm_cell_backward(grads[0], grads[1], cx, result1, result2, input_bias.defined())"
- name: _thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor)
  input_gates, hidden_gates, hx, input_bias, hidden_bias: "grad.defined() ? (GradMode::is_enabled() ? _thnn_differentiable_gru_cell_backward(grad, input_gates, hidden_gates, hx, input_bias, hidden_bias) : _thnn_fused_gru_cell_backward(grad, result1, input_bias.defined())) : std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>()"
# PackedSequence helpers
- name: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)
  input: _pack_padded_sequence_backward(grad, input.sizes(), result1, batch_first)
# TH wrappers
- name: eq.Scalar(Tensor self, Scalar other) -> Tensor
  output_differentiability: [False]
- name: eq.Tensor(Tensor self, Tensor other) -> Tensor
  output_differentiability: [False]
- name: ge.Scalar(Tensor self, Scalar other) -> Tensor
  output_differentiability: [False]
- name: ge.Tensor(Tensor self, Tensor other) -> Tensor
  output_differentiability: [False]
- name: gt.Scalar(Tensor self, Scalar other) -> Tensor
  output_differentiability: [False]
- name: gt.Tensor(Tensor self, Tensor other) -> Tensor
  output_differentiability: [False]
- name: le.Scalar(Tensor self, Scalar other) -> Tensor
  output_differentiability: [False]
- name: le.Tensor(Tensor self, Tensor other) -> Tensor
  output_differentiability: [False]
- name: lt.Scalar(Tensor self, Scalar other) -> Tensor
  output_differentiability: [False]
- name: lt.Tensor(Tensor self, Tensor other) -> Tensor
  output_differentiability: [False]
- name: ne.Scalar(Tensor self, Scalar other) -> Tensor
  output_differentiability: [False]
- name: ne.Tensor(Tensor self, Tensor other) -> Tensor
  output_differentiability: [False]
- name: multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor
  output_differentiability: [False]
- name: nonzero(Tensor self) -> Tensor
  output_differentiability: [False]
- name: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
  data: _segment_reduce_backward(grad, result, data, reduce, lengths, offsets, axis, initial)
- name: _pin_memory(Tensor self, Device? device=None) -> Tensor
  self: grad
- name: _new_zeros_with_same_feature_meta(Tensor self, Tensor other, *, int self_num_batch_dims=0) -> Tensor
  self: non_differentiable
  other: non_differentiable
  output_differentiability: [False]
- name: _test_warn_in_autograd(Tensor self) -> Tensor
  self: warn_backwards(grad)
- name: _test_autograd_multiple_dispatch.fullcoverage(Tensor self) -> Tensor
  dispatch:
    Default:
      self: grad.expand(self.sizes()) + 1
      result: auto_linear
    AutogradNestedTensor:
      self: grad.mul(grad)
    AutogradCUDA:
      self: grad.expand(self.sizes()) * 2
- name: _test_autograd_multiple_dispatch.ntonly(Tensor self, bool b) -> Tensor
  dispatch:
    AutogradNestedTensor:
      self: grad.mul(grad).add(grad)
- name: _test_autograd_multiple_dispatch_view(Tensor(a) self) -> Tensor(a)
  dispatch:
    Default:
      self: grad.reshape_as(self)
    AutogradCUDA:
      self: grad.reshape_as(self) + 1
- name: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  output_differentiability: [False]
- name: scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor
  self, src: scatter_reduce_backward(grad, self, dim, index, src, reduce, include_self, result)
  index: non_differentiable
  result: scatter_reduce_jvp(self_p, self_t, dim, index, src_p, src_t, reduce, include_self, result)
- name: special_airy_ai(Tensor x) -> Tensor
  x: non_differentiable
- name: special_bessel_j0(Tensor self) -> Tensor
  self: non_differentiable
- name: special_bessel_j1(Tensor self) -> Tensor
  self: non_differentiable
- name: special_bessel_y0(Tensor self) -> Tensor
  self: non_differentiable
- name: special_bessel_y1(Tensor self) -> Tensor
  self: non_differentiable
- name: special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor
  x: non_differentiable
  n: non_differentiable
- name: special_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor
  n: non_differentiable
- name: special_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor
  x: non_differentiable
- name: special_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor
  x: non_differentiable
  n: non_differentiable
- name: special_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor
  n: non_differentiable
- name: special_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor
  x: non_differentiable
- name: special_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor
  x: non_differentiable
  n: non_differentiable
- name: special_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor
  n: non_differentiable
- name: special_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor
  x: non_differentiable
- name: special_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor
  x: non_differentiable
  n: non_differentiable
- name: special_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor
  n: non_differentiable
- name: special_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor
  x: non_differentiable
- name: special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor
  x: non_differentiable
  n: non_differentiable
- name: special_hermite_polynomial_h.x_scalar(Scalar x, Tensor n) -> Tensor
  n: non_differentiable
- name: special_hermite_polynomial_h.n_scalar(Tensor x, Scalar n) -> Tensor
  x: non_differentiable
- name: special_hermite_polynomial_he(Tensor x, Tensor n) -> Tensor
  x: non_differentiable
  n: non_differentiable
- name: special_hermite_polynomial_he.x_scalar(Scalar x, Tensor n) -> Tensor
  n: non_differentiable
- name: special_hermite_polynomial_he.n_scalar(Tensor x, Scalar n) -> Tensor
  x: non_differentiable
- name: special_laguerre_polynomial_l(Tensor x, Tensor n) -> Tensor
  x: non_differentiable
  n: non_differentiable
- name: special_laguerre_polynomial_l.x_scalar(Scalar x, Tensor n) -> Tensor
  n: non_differentiable
- name: special_laguerre_polynomial_l.n_scalar(Tensor x, Scalar n) -> Tensor
  x: non_differentiable
- name: special_legendre_polynomial_p(Tensor x, Tensor n) -> Tensor
  x: non_differentiable
  n: non_differentiable
- name: special_legendre_polynomial_p.x_scalar(Scalar x, Tensor n) -> Tensor
  n: non_differentiable
- name: special_legendre_polynomial_p.n_scalar(Tensor x, Scalar n) -> Tensor
  x: non_differentiable
- name: special_modified_bessel_i0(Tensor self) -> Tensor
  self: non_differentiable
- name: special_modified_bessel_i1(Tensor self) -> Tensor
  self: non_differentiable
- name: special_modified_bessel_k0(Tensor self) -> Tensor
  self: non_differentiable
- name: special_modified_bessel_k1(Tensor self) -> Tensor
  self: non_differentiable
- name: special_scaled_modified_bessel_k0(Tensor x) -> Tensor
  x: non_differentiable
- name: special_scaled_modified_bessel_k1(Tensor x) -> Tensor
  x: non_differentiable
- name: special_shifted_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor
  x: non_differentiable
  n: non_differentiable
- name: special_shifted_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor
  n: non_differentiable
- name: special_shifted_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor
  x: non_differentiable
- name: special_shifted_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor
  x: non_differentiable
  n: non_differentiable
- name: special_shifted_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor
  n: non_differentiable
- name: special_shifted_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor
  x: non_differentiable
- name: special_shifted_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor
  x: non_differentiable
  n: non_differentiable
- name: special_shifted_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor
  n: non_differentiable
- name: special_shifted_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor
  x: non_differentiable
- name: special_shifted_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor
  x: non_differentiable
  n: non_differentiable
- name: special_shifted_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor
  n: non_differentiable
- name: special_shifted_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor
  x: non_differentiable
- name: special_spherical_bessel_j0(Tensor x) -> Tensor
  x: non_differentiable