# mypy: allow-untyped-defs
"""
This module defines runtime wrappers which, based on previous analysis, attempt to:
1. process the inputs and outputs
2. apply mutations
3. handle functionalized randomness
4. deduplicate inputs and consolidate views into their bases (see input_output_analysis)
"""
import builtins
import collections
import itertools
import pprint
from contextlib import nullcontext
from dataclasses import dataclass, field
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.utils.dlpack
from torch import Tensor
from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._guards import (
compile_context,
CompileContext,
detect_fake_mode,
DuplicateInputs,
tracing,
TracingContext,
)
from torch._prims_common import CUDARngStateHelper
from torch._subclasses import FakeTensor
from torch.fx.experimental._backward_state import BackwardState
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from .. import config
from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata
from .functional_utils import gen_alias_from_base
from .input_output_analysis import (
compute_overlapping_inputs,
create_synthetic_base_metadata,
remove_dupe_metadata,
)
from .logging_utils import describe_input, format_guard_bug_msg, track_graph_compiling
from .schemas import (
AOTConfig,
InputAliasInfo,
MutationType,
OutputType,
PlainTensorMeta,
SubclassCreationMeta,
SubclassMeta,
TensorAlias,
ViewAndMutationMeta,
)
from .subclass_utils import (
requires_subclass_dispatch,
runtime_unwrap_tensor_subclasses,
wrap_tensor_subclasses,
)
from .traced_function_transforms import aot_dispatch_subclass
from .utils import (
call_func_at_runtime_with_args,
make_boxed_func,
normalize_as_list,
partial_flatten_asdict,
strict_zip,
)
zip = strict_zip
class CompilerWrapper:
"""
A wrapper around the inputs and outputs to the compiler_fn. We separate these into two parts:
1. The prologue, which edits the inputs to the compiler_fn (flat_fn, flat_args, etc.)
2. The epilogue, which edits the outputs of the compiler_fn (compiled_fn, real arguments)
Each wrapper below should be implemented as a CompilerWrapper, so that we can facilitate
caching on the compiled output, and re-wrapping the output via epilogues.
Extra metadata that is needed to compute pre or post compile can be passed in via attributes.
"""
def pre_compile(
self,
flat_fn,
flat_args: List[Tensor],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
"""
Process the inputs to the compiler_fn. You can pass in extra metadata via kwargs.
Args:
flat_fn: The function to compile
flat_args: Metadata from example inputs of the function to compile
aot_config: AOTConfig passed in at compile time
fw_metadata: ViewAndMutationMeta generated from flat_fn and flat_args
"""
return flat_fn, flat_args, fw_metadata
def post_compile(self, compiled_fn, aot_config, *, runtime_metadata) -> Callable:
"""
Given an output of the compiler, wrap it with information received from prologue.
Args:
compiled_fn: Callable after calling compiler_fn
aot_config: AOTConfig after calling prologue
runtime_metadata: ViewAndMutationMeta after calling all wrappers' pre_compile steps.
Example:
def wrapped_compiled_fn(args):
    # do something with args, aot_config, fw_metadata
    return compiled_fn(args)
return wrapped_compiled_fn
"""
return compiled_fn
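
# A minimal sketch of how a list of CompilerWrappers could be composed.
# Illustration only: `apply_wrappers` and `compiler_fn` are hypothetical names, and the
# ordering shown here (pre_compile in list order, post_compile in reverse order) is an
# assumption made for this sketch, not something the class above enforces.
#
#   def apply_wrappers(wrappers, flat_fn, flat_args, aot_config, fw_metadata, compiler_fn):
#       # Prologue: each wrapper may rewrite the function, the example args and the metadata.
#       for w in wrappers:
#           flat_fn, flat_args, fw_metadata = w.pre_compile(
#               flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
#           )
#       compiled_fn = compiler_fn(flat_fn, flat_args)
#       # Epilogue: the last wrapper applied on the way in is the first one to wrap the output.
#       for w in reversed(wrappers):
#           compiled_fn = w.post_compile(
#               compiled_fn, aot_config, runtime_metadata=fw_metadata
#           )
#       return compiled_fn
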
# The wrapper created by this function handles all of the runtime aliasing and mutation "epilogue" logic
# that needs to run after the compiled function.
#
# This function accepts a trace_joint flag, indicating whether or not we're generating the runtime
# epilogue for a forward-only inference graph, or for an autograd.Function.apply function.
# This is because there are some minor differences in how we treat these cases at runtime:
# - resize_() is currently handled in the inference case, but not fully handled in the autograd case.
# - the autograd case inserts TensorAlias wrapper objects for outputs that alias inputs
@dataclass
class RuntimeWrapper(CompilerWrapper):
indices_of_inps_to_detach: List[int]
trace_joint: bool
disable_amp: bool
def post_compile(
self,
compiled_fn,
aot_config: AOTConfig,
*,
runtime_metadata: ViewAndMutationMeta,
):
return _create_runtime_wrapper(
compiled_fn,
runtime_metadata=runtime_metadata,
indices_of_inps_to_detach=self.indices_of_inps_to_detach,
trace_joint=self.trace_joint,
keep_input_mutations=aot_config.keep_inference_input_mutations,
disable_amp=self.disable_amp,
)
class NoopAliasHandler:
def __init__(self, info, runtime_metadata, trace_joint):
pass
def __call__(self, orig_inputs, fw_outs, out):
return out
def _unwrap_tensoralias(x):
assert isinstance(x, TensorAlias)
return x.alias
def _identity(x):
return x
class AliasOfInputHandler:
def __init__(self, info, runtime_metadata, trace_joint):
self.base_idx = info.base_idx
self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
self.requires_grad = info.requires_grad
self.functional_tensor = info.functional_tensor
self.replay_views = config.view_replay_for_aliased_outputs
def __call__(self, orig_inputs, fw_outs, out):
aliased_base_tensor = orig_inputs[self.base_idx]
return gen_alias_from_base(
aliased_base_tensor,
self.unwrap_out(out),
self.requires_grad,
self.functional_tensor,
replay_views=self.replay_views,
)
class IsInputHandler:
def __init__(self, info, runtime_metadata, trace_joint):
self.base_idx = info.base_idx
self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
def __call__(self, orig_inputs, fw_outs, out):
aliased_base_tensor = orig_inputs[self.base_idx]
return aliased_base_tensor
class AliasOfIntermediateHandler:
def __init__(self, info, runtime_metadata, trace_joint):
if info.output_type in (
OutputType.alias_of_intermediate,
OutputType.alias_of_intermediate_save_as_output,
):
num_user_outputs = len(runtime_metadata.output_info)
self.base_idx = info.base_idx + num_user_outputs
else:
self.base_idx = info.base_idx
self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
self.requires_grad = info.requires_grad
self.functional_tensor = info.functional_tensor
self.replay_views = config.view_replay_for_aliased_outputs
def __call__(self, orig_inputs, fw_outs, out):
aliased_base_tensor = fw_outs[self.base_idx]
return gen_alias_from_base(
aliased_base_tensor,
self.unwrap_out(out),
self.requires_grad,
self.functional_tensor,
replay_views=self.replay_views,
)
_HANDLER_MAP = {
OutputType.non_alias: NoopAliasHandler,
OutputType.unsafe_view_alias: NoopAliasHandler,
OutputType.custom_function_view: NoopAliasHandler,
OutputType.alias_of_input: AliasOfInputHandler,
OutputType.is_input: IsInputHandler,
OutputType.alias_of_intermediate: AliasOfIntermediateHandler,
OutputType.alias_of_intermediate_save_as_output: AliasOfIntermediateHandler,
OutputType.alias_of_intermediate_base_is_user_output: AliasOfIntermediateHandler,
}
def make_output_handler(info, runtime_metadata, trace_joint):
handler_type = _HANDLER_MAP[info.output_type]
return handler_type(info, runtime_metadata, trace_joint)
def _create_runtime_wrapper(
compiled_fn,
*,
runtime_metadata: ViewAndMutationMeta,
indices_of_inps_to_detach: List[int],
trace_joint: bool,
keep_input_mutations: bool,
disable_amp: bool,
):
if not hasattr(compiled_fn, "_boxed_call"):
compiled_fn = make_boxed_func(compiled_fn)
# Note [Inputs needed in runtime epilogue after list clearing]
# In Python functions, you can't free the input arguments of a function within the scope of that function. A workaround is to
# wrap the input arguments in a list, and clear the list from within the function.
# Here, this is implemented as `call_func_at_runtime_with_args(..., steal_args=True)`.
#
# This is needed for Compiled Autograd since some of the inputs (activations) should be freed early.
# However, we cannot blindly clear the entire list, because AOTAutograd may need access to some of the graph inputs
# **after** the compiled function has finished running. There are two main cases:
# (1) Input mutations: If there are input mutations that we must run outside of the graph, we need access to the input.
# (2) Output aliasing: Outputs that alias graph inputs generally must be regenerated outside of the `autograd.Function`,
# and doing so requires access to the corresponding input after the compiled artifact has run.
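#
# A minimal sketch of the list-clearing workaround described above.
# Illustration only: `Activation` and `boxed_fn` are hypothetical names, and the prompt
# freeing shown here relies on CPython reference counting.
#
#   import weakref
#
#   class Activation:  # stand-in for a large tensor
#       pass
#
#   def boxed_fn(args):
#       (x,) = args
#       args.clear()          # the caller's list no longer references x
#       ref = weakref.ref(x)
#       del x                 # last reference dropped, so x is freed inside the callee
#       return ref() is None
#
#   boxed_fn([Activation()])
#
# Any input that the epilogue still needs has to be stashed before the call,
# which is what epilogue_args_idx records.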
epilogue_args_idx = []
epilogue_args_idx.extend(runtime_metadata.mutated_inp_runtime_indices)
for info in runtime_metadata.output_info:
if (
info.output_type == OutputType.alias_of_input
or info.output_type == OutputType.is_input
):
assert isinstance(info.base_idx, int)
epilogue_args_idx.append(info.base_idx)
if config.unlift_effect_tokens:
assert len(runtime_metadata.tokens) == 0
replay_views = config.view_replay_for_aliased_outputs
if runtime_metadata.num_outputs_aliased > 0:
output_handlers = tuple(
make_output_handler(info, runtime_metadata, trace_joint)
for info in runtime_metadata.output_info
)
def runtime_wrapper(args: List[Any]):
# stash a ref to each input tensor we plan to use after the compiled function
orig_inputs = {i: args[i] for i in epilogue_args_idx}
if keep_input_mutations:
mutated_args = (
args[i]
for i in runtime_metadata.mutated_graph_handled_indices_seen_by_autograd
)
torch.autograd.graph.increment_version(mutated_args)
if trace_joint:
args_ = list(args)
# See Note [Detaching inputs that never need gradients]
for idx in indices_of_inps_to_detach:
if isinstance(args_[idx], torch.Tensor):
args_[idx] = args_[idx].detach()
# trace_joint can run inside a user-specified `with no_grad()` region
# if a nested `with enable_grad()` forces some outputs to require gradients.
# Therefore, we unconditionally turn on enable_grad() for compiled_fn execution.
with torch.autograd._force_original_view_tracking(
True
), torch.enable_grad():
all_outs = call_func_at_runtime_with_args(
compiled_fn, args_, disable_amp=disable_amp, steal_args=True
)
else:
# When we have an inference graph, we run with grad disabled.
# It's possible to get an inference graph with inputs that require grad,
# in which case we want to make sure autograd is disabled
# (since e.g., inductor will generate aten.addmm.out calls which autograd will complain about)
# NOTE: We use _set_grad_enabled directly to reduce runtime overhead
grad_enabled = torch.is_grad_enabled()
try:
if grad_enabled:
torch._C._set_grad_enabled(False)
all_outs = call_func_at_runtime_with_args(
compiled_fn, args, disable_amp=disable_amp, steal_args=True
)
finally:
if grad_enabled:
torch._C._set_grad_enabled(True)
del args
num_mutated_runtime_inps = runtime_metadata.num_mutated_inp_runtime_indices
num_intermediate_bases = runtime_metadata.num_intermediate_bases
assert (
len(all_outs)
== num_mutated_runtime_inps
+ runtime_metadata.num_outputs
+ num_intermediate_bases
)
# Step 3: After running the compiled fw, apply updates to mutated inputs
num_mutations_to_apply = runtime_metadata.num_mutated_inp_runtime_indices
if num_mutations_to_apply > 0:
updated_inputs = all_outs[:num_mutations_to_apply]
fw_outs = all_outs[num_mutations_to_apply:]
for i, inpt_idx in enumerate(runtime_metadata.mutated_inp_runtime_indices):
meta = runtime_metadata.input_info[inpt_idx]
if not meta.mutates_data and not meta.mutates_metadata:
continue
original_inpt = orig_inputs[inpt_idx]
updated_inpt = updated_inputs[i]
if meta.mutates_storage_metadata:
# See Note [set_() Input Mutations in AOTAutograd]
# mutates_storage_metadata means our input saw a x.set_(y) call.
# What if x **also** saw a data and/or a metadata mutation?
# (1) If the [meta]data mutation occurred after the set_(),
# then there is no need to copy_() the data.
# When we perform x.set_(x_updated), we are guaranteed that
# x_updated already has the final version of the data/metadata
# (2) If a data mutation occurred before the set_(),
# this case seems very difficult to support.
# TODO: discuss on the PR and decide if we want to try to
# either support it, or detect and ban it.
if trace_joint:
assert isinstance(updated_inpt, TensorAlias)
updated_inpt = updated_inpt.alias
with torch.no_grad():
original_inpt.set_(updated_inpt)
continue
if meta.mutates_metadata and not meta.mutates_data:
if trace_joint:
assert isinstance(updated_inpt, TensorAlias)
updated_inpt = updated_inpt.alias
# We need to grab the size/stride/storage_offset from the compiled forward,
# and use that to mutate the metadata of the input
original_inpt.as_strided_(
updated_inpt.size(),
updated_inpt.stride(),
updated_inpt.storage_offset(),
)
else:
if meta.mutates_data and meta.mutates_metadata:
original_inpt.as_strided_(
updated_inpt.size(),
updated_inpt.stride(),
updated_inpt.storage_offset(),
)
else:
assert meta.mutates_data
if meta.is_leaf and original_inpt.requires_grad:
# We can hit this situation in this case:
#   def f(x):
#       x.detach().mul_(2)
#       return x + 1
# AOTAutograd will see a mutation in the above case, and try to
# apply a copy_() here, in the epilogue.
# But if x required gradients, and is a leaf, then autograd
# will yell at us for trying to mutate it.
# However, it's only possible to end up in this scenario (like the above)
# if all of the mutations to the leaf input were non-autograd-tracking mutations
# (aka mutations under no_grad(), or on detached views).
# In that case, we fully want to hide the mutation from autograd, so detaching is ok.
original_inpt.detach().copy_(updated_inpt)
else:
original_inpt.copy_(updated_inpt)
else:
fw_outs = all_outs
# Step 4: Manually regenerate any outputs that are aliased to inputs, instead of
# compiling them.
if runtime_metadata.num_outputs_aliased > 0:
# The compiled forward also returned intermediate bases. We don't want to return them to the user.
expect_num_outputs = (
len(output_handlers) + runtime_metadata.num_intermediate_bases
)
assert len(fw_outs) == expect_num_outputs
ret_outs = [
handler(orig_inputs, fw_outs, out)
for out, handler in builtins.zip(fw_outs, output_handlers)
]
else:
ret_outs = fw_outs
if runtime_metadata.dynamic_outputs:
for t, o in zip(ret_outs, runtime_metadata.output_info):
if o.dynamic_dims is None:
continue
if hasattr(t, "_dynamo_weak_dynamic_indices"):
t._dynamo_weak_dynamic_indices |= o.dynamic_dims
else:
t._dynamo_weak_dynamic_indices = o.dynamic_dims.copy()
if runtime_metadata.grad_enabled_mutation is not None:
torch._C._set_grad_enabled(runtime_metadata.grad_enabled_mutation)
return ret_outs
return runtime_wrapper
@dataclass
class FunctionalizedRngRuntimeWrapper(CompilerWrapper):
# TODO: I would love to get rid of this argument, but it's
# wrapped pretty tightly around our aot_dispatch_autograd logic.
# Specifically, tensors_saved_for_backwards_slice's value is both used for calculating indices
# for setting placeholder strides (which is done before runtime, before this wrapper runs)
# and for saving tensors for backward (which is done during runtime, after this wrapper runs).
# So in aot_dispatch_autograd, this wrapper can't edit the set of outs without making one
# of those two indices incorrect.
return_new_outs: bool = True
def pre_compile(
self,
flat_fn,
flat_args,
aot_config,
*,
fw_metadata,
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
if config.functionalize_rng_ops:
# Update example inputs for the fw_compiler
fake_mode = detect_fake_mode()
seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)
flat_args.extend([seed, offset])
# We are not clearing flat_args here because
# 1) There is a check in the debug compiler at the end
# 2) It does not matter as these are fake tensors
return flat_fn, flat_args, fw_metadata
def post_compile(
self,
compiled_fn,
aot_config: AOTConfig,
*,
runtime_metadata: ViewAndMutationMeta,
):
@wraps(compiled_fn)
def wrapper(runtime_args: List[Any]):
if runtime_metadata.is_rng_op_functionalized:
# Add the seed and offset to args
seed, offset = CUDARngStateHelper.get_torch_state_as_tuple()
runtime_args.extend([seed, offset])
out = compiled_fn(runtime_args)
out = self._functionalized_rng_runtime_epilogue(
runtime_metadata,
out,
# TODO: this won't be right for the backward when we convert the call_compiled_backward to use the wrapper
runtime_metadata.num_forward_returns,
)
return out
return compiled_fn(runtime_args)
return wrapper
# Calling convention: If we are running functionalized RNG, then outs consists
# of (user_outs, rng_offset)
def _functionalized_rng_runtime_epilogue(
self,
metadata: ViewAndMutationMeta,
outs,
offset_index,
):
if metadata.is_rng_op_functionalized:
assert metadata.num_outputs_rng_offset == 1
new_rng_offset = outs[offset_index]
CUDARngStateHelper.set_new_offset(new_rng_offset)
if self.return_new_outs:
user_outs = outs[:offset_index] + outs[offset_index + 1 :]
return user_outs
else:
return outs
return outs
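
# A minimal sketch of the calling convention handled by _functionalized_rng_runtime_epilogue,
# using plain strings in place of tensors. Illustration only: the values are hypothetical.
#
#   outs = ["user_out0", "user_out1", "new_rng_offset", "trailing_out"]
#   offset_index = 2  # position of the RNG offset in outs
#   new_rng_offset = outs[offset_index]
#   # With return_new_outs=True the offset is dropped and everything else is kept in order:
#   user_outs = outs[:offset_index] + outs[offset_index + 1 :]
#   # user_outs == ["user_out0", "user_out1", "trailing_out"]
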
@dataclass
class FakifiedOutWrapper(CompilerWrapper):
out_metas: List[torch.Tensor] = field(default_factory=list)
# TracingContext.fwd_output_strides
# Generated from actually doing compile
fwd_output_strides: Optional[List[List[int]]] = None
needs_post_compile: bool = True
def pre_compile(
self,
fw_module, # Must be fw_module from aot_dispatch_*_graph
flat_args,
aot_config,
*,
fw_metadata,
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
tracing_context = torch._guards.TracingContext.try_get()
if tracing_context and tracing_context.fakify_first_call:
self.out_metas = [
n.meta["val"] for n in (list(fw_module.graph.nodes)[-1].args[0])
]
else:
self.needs_post_compile = False
return fw_module, flat_args, fw_metadata
def _compute_output_meta_with_inductor_strides(self):
out = self.out_metas
fwd_output_strides = self.fwd_output_strides
if not fwd_output_strides:
return out
from torch.fx.experimental.symbolic_shapes import statically_known_true
for i in range(len(out)):
if not isinstance(out[i], Tensor):
continue
if all(
statically_known_true(s1 == s2)
for s1, s2 in zip(out[i].stride(), fwd_output_strides[i])
):
continue
out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i])
return out
# To be called post compile
def set_fwd_output_strides(self, fwd_output_strides):
self.fwd_output_strides = fwd_output_strides
def post_compile(
self,
compiled_fn,
aot_config: AOTConfig,
*,
runtime_metadata: ViewAndMutationMeta,
):
if self.needs_post_compile:
assert self.fwd_output_strides is not None
fakified_out = self._compute_output_meta_with_inductor_strides()
@wraps(compiled_fn)
def wrapper(runtime_args):
nonlocal fakified_out
if fakified_out is not None:
out = fakified_out
fakified_out = None
return out
return compiled_fn(runtime_args)
return wrapper
# If we don't need to fakify, we can just return the original compiled function
return compiled_fn
# This wrapper handles the AOTDispatch runtime logic for tensor subclasses.
# At runtime, we have a compiled function that knows how to operate on the domain of DenseTensor -> DenseTensor,
# but the user might have passed us some tensor subclass inputs (or expect some subclass tensor outputs).
# This function handles the wrapping and unwrapping of tensor subclasses at runtime.
@dataclass
class AOTDispatchSubclassWrapper(CompilerWrapper):
trace_joint: bool
fw_only: Optional[Callable] # Not cached, only used in pre_compile
maybe_subclass_meta: Optional[SubclassMeta]
num_fw_outs_saved_for_bw: Optional[int]
def pre_compile(
self,
flat_fn,
flat_args: List[Tensor],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
):
(new_flat_fn, new_flat_args, subclass_meta) = aot_dispatch_subclass(
flat_fn,
flat_args,
is_joint_structure=self.trace_joint,
meta=fw_metadata,
fw_only=self.fw_only, # type: ignore[arg-type]
)
self.maybe_subclass_meta = subclass_meta
return new_flat_fn, new_flat_args, fw_metadata
def post_compile(
self,
compiled_fn,
_aot_config: AOTConfig,
*,
runtime_metadata: ViewAndMutationMeta,
):
if self.maybe_subclass_meta is None:
return compiled_fn
subclass_metas = runtime_metadata.subclass_fw_graph_out_meta
@wraps(compiled_fn)
def inner_fn(args: List[Any]):
unwrapped_args = runtime_unwrap_tensor_subclasses(
args,
subclass_metas=runtime_metadata.subclass_inp_meta,
append_symints=True,
)
args.clear()
# expectation: runtime_fn is a boxed fn
unwrapped_outs = compiled_fn(unwrapped_args)
wrapped_outs = wrap_tensor_subclasses(
unwrapped_outs,
subclass_metas=subclass_metas,
num_fw_outs_saved_for_bw=self.num_fw_outs_saved_for_bw,
is_runtime=True,
included_subclass_symints=True,
)
return wrapped_outs
# box it
inner_fn._boxed_call = True # type: ignore[attr-defined]
return inner_fn
@dataclass
class EffectTokensWrapper(CompilerWrapper):
def post_compile(
self,
compiled_fn,
_aot_config,
*,
runtime_metadata: ViewAndMutationMeta,
):
num_tokens = len(runtime_metadata.tokens)
@wraps(compiled_fn)
def inner_fn(args: List[Any]):
if num_tokens > 0:
# Pass in forward effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
old_args = args
args = [*([None] * num_tokens), *args]
old_args.clear()
outs = compiled_fn(args)
# Inductor cache DummyModule can return None
if outs is None:
return None
# Toss out the effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
return outs[num_tokens:] if num_tokens != 0 else outs
# box it
inner_fn._boxed_call = True # type: ignore[attr-defined]
return inner_fn
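
# A minimal sketch of the token handling in EffectTokensWrapper, with a stand-in for the
# compiled function. Illustration only: `fake_compiled_fn` is a hypothetical name.
#
#   num_tokens = 2
#
#   def fake_compiled_fn(args):
#       tokens, real_args = args[:num_tokens], args[num_tokens:]
#       return [*tokens, *(a * 2 for a in real_args)]  # tokens come back first
#
#   args = [1, 2, 3]
#   outs = fake_compiled_fn([*([None] * num_tokens), *args])
#   # outs[num_tokens:] == [2, 4, 6]
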
# MOTIVATION:
#
# When tracing functions for future execution, one must be careful not to pass
# in the same input tensor multiple times (e.g., f(x, x)), as this can result
# in graphs that are ONLY valid if you later pass a new tensor in exactly the
# same way (e.g., f(y, y)). (NB: we really mean duplicate; two distinct
# tensors that alias each other is a different situation that is covered by
# aot_dispatch_deduplicated_autograd). Here are two examples:
#
# (1) Suppose you have a function:
#
#   def f(x, y):
#       return x + y
#
# If you make_fx(f)(x, x), you will trace out:
#
#   def f(x, y):
#       return y + y
#
# Oops!
#
# (2) For most tensors x and y, you can compute f's gradient with respect to
# these two inputs by saying torch.autograd.grad(f(x, y), (x, y)). However,
# if x is y, you will trace out a program that gets incorrect gradients:
#
# >>> x = torch.randn(1, requires_grad=True)
# >>> torch.autograd.grad(x + x, (x, x))
# (tensor([2.]), tensor([2.]))
#
# In other words, the gradient is double-counted. Deduplicating the arguments
# gives you an appropriate gradient:
#
# >>> y = torch.randn(1, requires_grad=True)
# >>> torch.autograd.grad(x + y, (x, y))
# (tensor([1.]), tensor([1.]))
#
# HOW TO DEDUPLICATE:
#
# There are a few strategies, in order of preference:
#
# 1. For every duplicate argument to the function, detach it into
# a separate leaf tensor, so that it is no longer duplicated.
#
# PRO: The resulting compiled graph works for any configuration
# of duplicated arguments.
#
# CON: It does not (naively) work if you mutate the metadata of inputs:
#
# def f(x, y):
# x.transpose_(0, 1)
# y.transpose_(0, 2)
#
# x = torch.randn(2, 3, 4)
# f(x, x)
#
# The ordering of the transposes inside f dictates whether or not
# you get [4, 2, 3] or [3, 4, 2]. This means that you cannot precompute
# what metadata mutations should get applied to each input; you need to
# assume they aren't duplicates (what we do today) or preserve
# the original metadata mutations exactly in order, so that they work
# for any duplicate configuration.
#
# CON: It does not (naively) work if you mutate the data of inputs.
# In particular, leaf tensors that require grad cannot be mutated in place,
# which makes it impossible to differentiate with respect to the original
# base.
#
# 2. For every duplicate argument to the function, remove it, so it is
# no longer part of the "true" signature:
#
# PRO: Implemented naively, it still works for metadata/data mutation.
#
# CON: The resulting compiled graph is duplicate-specialized: it only
# works if future calls duplicate arguments in exactly the same way.
# Horribly, Dynamo doesn't guard on this at the moment. But even if
# it did, you could still end up recompiling once for each distinct
# pattern of duplicated arguments.
#
# Our strategy is to do (1) if we can, and do (2) otherwise, erroring if
# Dynamo's guards are not enough. In practice, this seems to cover
# everything.
#
@dataclass
class AOTDedupeWrapper(CompilerWrapper):
keep_arg_mask: List[bool] = field(default_factory=list)
add_dupe_map: List[int] = field(default_factory=list)
old_input_metadata: List[InputAliasInfo] = field(default_factory=list)
needs_post_compile: bool = True
# NB: Hot path, avoid set lookups here
# TODO: Can avoid the zip here too, probably
def remove_dupe_args(self, args):
return [t for t, keep in zip(args, self.keep_arg_mask) if keep]
def add_dupe_args(self, args):
return [args[i] for i in self.add_dupe_map]
def pre_compile(
self,
flat_fn,
flat_args: List[Tensor],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
# Use information about whether or not flat_fn mutates its arguments
# to handle dupe args
# Strategy 1: For any input that is not mutated, we can leafify it if we
# need to remove a duplicate.
leaf_flat_args = []
args_set = set()
ok = True
for i, a in enumerate(flat_args):
if not isinstance(a, torch.Tensor):
leaf_flat_args.append(a)
elif a not in args_set:
args_set.add(a)
leaf_flat_args.append(a)
elif (
not fw_metadata.input_info[i].mutates_data
and not fw_metadata.input_info[i].mutates_metadata
):
leaf_flat_args.append(a.detach().requires_grad_(a.requires_grad))
else:
ok = False
break
if ok:
self.needs_post_compile = False
return flat_fn, leaf_flat_args, fw_metadata
if requires_subclass_dispatch(leaf_flat_args, fw_metadata):
raise RuntimeError(
"""\
Encountered duplicate inputs that are mutated in the graph, but at least one input/output
to the graph is a tensor subclass. This is not supported today. You can try to
remove the aliasing yourself as a workaround, or otherwise file an issue on github."""
)
# export path: ban duplicate inputs for now, add later if requested.
if aot_config.is_export:
raise RuntimeError(
f"""\
Encountered duplicated inputs that are mutated in the graph you are trying to export.
This functionality is currently not supported. If needed, please file a github issue.
fw_metadata={str(fw_metadata)}
"""
)
# Strategy 2: Duplicate specialize.
#
# In Haskell types, suppose you have:
#
# add_dupe_args :: DedupedArgs -> Args
# remove_dupe_args :: Args -> DedupedArgs
#
# compiler_fn
# :: (DedupedArgs -> R) -> DedupedArgs -> AOTConfig -> (DedupedArgs -> R)
# deduped_compiler_fn
# :: (Args -> R) -> Args -> AOTConfig -> (Args -> R)
#
# Then the code below can be written in point-free style as:
#
# deduped_compiler_fn f a c =
# compiler_fn (f . add_dupe_args) (remove_dupe_args a) c . remove_dupe_args
#
# Suppose you have:
#
# [a, b, a, c]
#
# We want:
#
# remove_dupe_args([a, b, a, c]) == [a, b, c]
# add_dupe_args([a, b, c]) == [a, b, a, c]
#
# This is done via (respectively):
#
# seen_args = {a: 0, b: 1, c: 2}
# enumerate(add_dupe_map) = [ # how to get args from the deduped list
# (0, 0),
# (1, 1),
# (2, 0),
# (3, 2),
# ]
# keep_arg_mask = [True, True, False, True]
seen_args: Dict[Tensor, int] = {}
# Implicitly map duped arg position (list index) to de-duped arg position
keep_arg_mask: List[bool] = []
add_dupe_map: List[int] = []
duped_arg_len = len(flat_args)
j = 0 # index into deduped_flat_args
for t in flat_args:
if isinstance(t, torch.Tensor):
if t in seen_args:
keep_arg_mask.append(False)
add_dupe_map.append(seen_args[t])
continue
seen_args[t] = j
keep_arg_mask.append(True)
add_dupe_map.append(j)
j += 1
assert (
len(add_dupe_map) == duped_arg_len
), f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}"
self.keep_arg_mask = keep_arg_mask
self.add_dupe_map = add_dupe_map
deduped_flat_args = self.remove_dupe_args(flat_args)
# Update our input metadata to remove duped input metadata.
updated_fw_metadata = remove_dupe_metadata(
fw_metadata, keep_arg_mask, add_dupe_map
)
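# NB: the walrus must be parenthesized. `:=` binds more loosely than `and`, so without
# the parentheses tracing_context would be bound to the result of the whole `and` expression.
if (
(tracing_context := TracingContext.try_get())
and aot_config.aot_autograd_arg_pos_to_source
):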
# TODO(voz): This structure is 1:1, we could consider an alternate structure like
# kept_pos:[dupe_arg_pos], however, add_dupe_map is 1:1 so we would need a new structure there,
# which feels like needless complexity for a tiny bit of efficiency at this point.
for dupe_arg_pos, (kept_pos, keep_arg) in enumerate(
zip(add_dupe_map, keep_arg_mask)
):
if not keep_arg:
dupe_arg_source = aot_config.aot_autograd_arg_pos_to_source[
dupe_arg_pos
]
kept_arg_source = aot_config.aot_autograd_arg_pos_to_source[
kept_pos
]
tracing_context.guards_context.aotautograd_guards.append( # type: ignore[attr-defined]
DuplicateInputs(kept_arg_source, dupe_arg_source)
)
@wraps(flat_fn)
def wrapped_flat_fn(*args):
return flat_fn(*self.add_dupe_args(args))
if config.debug_assert:
ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
wrapped_flat_fn,
static_input_indices=aot_config.static_input_indices,
keep_input_mutations=fw_metadata.keep_input_mutations,
is_train=fw_metadata.is_train,
)(*deduped_flat_args)
assert (
ref_fw_metadata == updated_fw_metadata
), f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}"
return wrapped_flat_fn, deduped_flat_args, updated_fw_metadata
def post_compile(
self,
compiled_fn,
aot_config: AOTConfig,
*,
runtime_metadata: ViewAndMutationMeta,
):
if not self.needs_post_compile:
return compiled_fn
@wraps(compiled_fn)
def wrapped_compiled_fn(args: List[Any]):
deduped_args = self.remove_dupe_args(args)
args.clear()
return compiled_fn(deduped_args)
wrapped_compiled_fn._boxed_call = True # type: ignore[attr-defined]
# This can be uncommented when we properly guard for duplicates,
# but right now we must not do it.
# if not config.debug_assert:
# return wrapped_compiled_fn
@wraps(wrapped_compiled_fn)
def debugged_compiled_fn(args):
# Test that the computed remove/add arg functions are an inverse
new_args = self.add_dupe_args(self.remove_dupe_args(args))
seen: Dict[Any, None] = {}
for i, (x, y) in enumerate(zip(new_args, args)):
seen[y] = None
assert x is y, format_guard_bug_msg(
aot_config,
f"{describe_input(i, aot_config)} would be a duplicate of "
f"{describe_input(self.add_dupe_map[i], aot_config)}",
)
# This is only an error if there is metadata mutation on both of
# the duped arguments; in this case, we need to know what order
# the metadata mutation applies in. You'll get the correct result
# otherwise, because a graph that assumes distinct inputs works if
# you dupe the inputs (the gradient contributions from each input
# will get summed up appropriately.)
#
# TODO: work out how to set up this assert correctly
"""
assert len(seen) == unique_args, format_guard_bug_msg(aot_config,
f"there would be {unique_args} distinct arguments"
)
"""
return wrapped_compiled_fn(args)
debugged_compiled_fn._boxed_call = True # type: ignore[attr-defined]
return debugged_compiled_fn
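# Illustrative sketch (not part of AOTAutograd): how keep_arg_mask and add_dupe_map from
# "Strategy 2" in AOTDedupeWrapper.pre_compile relate to each other, using plain Python
# objects in place of tensors. The _sketch_* names are hypothetical. One simplifying
# assumption: every argument is deduplicated by identity, whereas the wrapper above only
# deduplicates torch.Tensor arguments and always keeps non-tensor arguments.
from typing import Any, Dict, List, Tuple


def _sketch_build_dupe_maps(flat_args: List[Any]) -> Tuple[List[bool], List[int]]:
    seen: Dict[int, int] = {}  # id(arg) -> position in the deduped argument list
    keep_arg_mask: List[bool] = []
    add_dupe_map: List[int] = []
    for arg in flat_args:
        if id(arg) in seen:
            # Duplicate: drop it, and remember which kept argument it maps to.
            keep_arg_mask.append(False)
            add_dupe_map.append(seen[id(arg)])
            continue
        seen[id(arg)] = len(seen)
        keep_arg_mask.append(True)
        add_dupe_map.append(seen[id(arg)])
    return keep_arg_mask, add_dupe_map


def _sketch_remove_dupe_args(args: List[Any], keep_arg_mask: List[bool]) -> List[Any]:
    return [a for a, keep in zip(args, keep_arg_mask) if keep]


def _sketch_add_dupe_args(deduped_args: List[Any], add_dupe_map: List[int]) -> List[Any]:
    return [deduped_args[i] for i in add_dupe_map]


# Usage: for arguments (a, b, a, c) the mask keeps positions 0, 1 and 3, the map is
# [0, 1, 0, 2], and _sketch_add_dupe_args applied to the deduped list restores the
# original argument list, which is the inverse property debugged_compiled_fn checks.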
# This layer handles the situation where you have two inputs that alias each other,
# and one of the inputs is mutated.
# We need to take special care to ensure that the mutation is applied to the other aliases in the graph.
#
# pre-condition: AOTDedupeWrapper has already run.
# (This function will in theory work if there are duplicate args.
# However, the synthetic base code path is a bit sub-optimal, and running with dupe'd inputs
# would cause us to hit that path more frequently).
@dataclass
class AOTSyntheticBaseWrapper(CompilerWrapper):
# Currently, the only reason we need to plumb this bool is because
# the synthetic base code prohibits more cases in the autograd case than the inference case.
trace_joint: bool # TODO: refactor trace_joint
needs_post_compile: bool = True
aliased_arg_idx_with_metadata_mutations: List[int] = field(default_factory=list)
def pre_compile(
self,
flat_fn,
flat_args: List[Any],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
is_inference = not self.trace_joint
flat_args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
aot_config,
flat_args,
fw_metadata.input_info,
is_inference=is_inference,
)
# Happy path: we don't need synthetic bases
if synthetic_base_info is None:
self.needs_post_compile = False
return flat_fn, flat_args, fw_metadata
# subclass path: synthetic bases are not supported when tensor subclasses are involved.
if requires_subclass_dispatch(flat_args, fw_metadata):
raise RuntimeError(
"""\
Encountered aliased inputs that are mutated in the graph, but at least one input/output
to the graph is a tensor subclass. This is not supported today. You can try to
remove the aliasing yourself as a workaround, or otherwise file an issue on github."""
)
# export path: ban synthetic bases for now, add later if requested.
if aot_config.is_export:
raise RuntimeError(
f"""\
Encountered aliased inputs that are mutated in the graph you are trying to export.
This functionality is currently not supported. If needed, please file a github issue.
synthetic_base_info={str(synthetic_base_info)}
fw_metadata={str(fw_metadata)}
"""
)
assert len(fw_metadata.input_info) == len(synthetic_base_info)
# Update our forward metadata to take synthetic bases into account
(
fw_metadata_updated,
aliased_arg_idx_with_metadata_mutations,
) = create_synthetic_base_metadata(
fw_metadata, synthetic_base_info, flat_args, flat_args_with_synthetic_bases
)
# Save old input args for post-compile
self.old_input_info = fw_metadata.input_info
self.aliased_arg_idx_with_metadata_mutations = (
aliased_arg_idx_with_metadata_mutations
)
num_aliased_args_with_metadata_mutations = len(
aliased_arg_idx_with_metadata_mutations
)
replay_views = config.view_replay_for_aliased_outputs
def _unpack_synthetic_bases(primals: Tuple[Any, ...]) -> List[Any]:
f_args_inner = []
for inner_idx_or_tuple in synthetic_base_info:
if isinstance(inner_idx_or_tuple, int):
f_args_inner.append(primals[inner_idx_or_tuple])
else:
inner_base_idx, view_tensor = inner_idx_or_tuple
base = primals[inner_base_idx]
view_arg = gen_alias_from_base(
base,
view_tensor,
view_tensor.requires_grad,
replay_views=replay_views,
)
f_args_inner.append(view_arg)
return f_args_inner
@wraps(flat_fn)
def wrapped_flat_fn(*args):
unpacked_args = _unpack_synthetic_bases(args)
# This is a bit subtle. The goal of this entire function (aot_dispatch_synthetic_bases)
# is to relieve the downstream logic from having to reason about mutations on inputs that alias
# each other, by replacing aliased inputs with a synthetic base.
# One area where this breaks down a bit however is if one of those aliased inputs
# experienced a metadata mutation.
# We are now obligated to reapply the metadata mutation directly to the user's input;
# it isn't enough to apply mutations back to the synthetic base in the downstream logic.
#
# The way we handle this is by pretending that those aliased inputs that experience metadata mutations
# are additional outputs in the user's forward function.
# The downstream logic will just treat these as "user outputs that alias inputs".
# However, we will manually grab them at runtime here, use them to reapply the metadata mutation
# to the user inputs, and not return them to the user.
aliased_args_with_metadata_mutations = [
x
for i, x in enumerate(unpacked_args)
if i in self.aliased_arg_idx_with_metadata_mutations
]
if len(aliased_args_with_metadata_mutations) > 0:
return *(flat_fn(*unpacked_args)), *aliased_args_with_metadata_mutations
else:
return flat_fn(*unpacked_args)
if config.debug_assert:
ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
wrapped_flat_fn,
static_input_indices=aot_config.static_input_indices,
keep_input_mutations=fw_metadata.keep_input_mutations,
is_train=fw_metadata.is_train,
)(*flat_args_with_synthetic_bases)
assert ref_fw_metadata == fw_metadata_updated, (
f"ref_metadata={pprint.pformat(partial_flatten_asdict(ref_fw_metadata))}, "
f"\nactual_metadata={pprint.pformat(partial_flatten_asdict(fw_metadata_updated))}"
)
return (
wrapped_flat_fn,
flat_args_with_synthetic_bases,
fw_metadata_updated,
)
def post_compile(
self,
compiled_fn,
aot_config: AOTConfig,
*,
runtime_metadata: ViewAndMutationMeta,
):
if not self.needs_post_compile:
return compiled_fn
is_inference = not self.trace_joint
@wraps(compiled_fn)
def wrapped_compiled_fn(args):
args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
aot_config, args, self.old_input_info, is_inference=is_inference
)
assert synthetic_base_info is not None
aliased_args_w_metadata_mutations = [
args[i] for i in self.aliased_arg_idx_with_metadata_mutations
]
num_aliased_args_with_metadata_mutations = len(
aliased_args_w_metadata_mutations
)
args.clear()
outs = compiled_fn(args_with_synthetic_bases)
if num_aliased_args_with_metadata_mutations > 0:
# This code does not handle **all** input metadata mutations.
# Instead, it only handles metadata mutations on inputs that were converted into synthetic bases
# (which only happens if at least one aliased input experienced a data mutation).
# e.g:
# def f(a, b):
# a.mul_(2)
# b.transpose_(1, 0)
# f(x.view(2, 2), x.view(2, 2))
mutated_metadata_inps = outs[-num_aliased_args_with_metadata_mutations:]
user_outs = outs[:-num_aliased_args_with_metadata_mutations]
for inp, mutated_inp in zip(
aliased_args_w_metadata_mutations, mutated_metadata_inps
):
inp.as_strided_(
mutated_inp.size(),
mutated_inp.stride(),
mutated_inp.storage_offset(),
)
return user_outs
return outs
return wrapped_compiled_fn
# Note [Handling mutations on an input that aliases other inputs]
# The easiest example to show-case this edge case is here:
#
# def f(a, b):
# a.mul_(2)
# out = a + b
# return out
# b = torch.ones(...)
# a = b.view(-1)
# f(a, b)
#
# In this situation, if a and b happened to be aliased, we need to trace something different!
# Suppose we had b = a.view(-1)
# (In this case, that means that `b._base is a`)
#
# We need to ensure that the aliasing relationship between a and b is preserved.
# We do that by detecting the specific situation above (mutate an input that aliases another input),
# and when we do that, we create a synthetic base argument. Then inside of the traced forward,
# we regenerate a and b off of that base.
# The complete example of the transformed function looks like this:
#
# // The traced forward takes in a synthetic base, and regenerates the aliased inputs as views
# // We could consider getting view-replay support here to minimize as_strided_scatter ops in the graph
# def traced_forward(base):
# a = base.as_strided(...)
# b = base.as_strided(...)
# a_updated = a.mul(2)
# base_updated = torch.as_strided_scatter(base, a_updated, ...)
# b_updated = base_updated.as_strided(...)
# out = a_updated + b_updated
# return a_updated, out
#
# def compiled_fn(a, b):
# // we detect that a is the "differentiable base" here
# base = a
# // In other situations, we might do either:
# // (1) a and b are both views off of some larger differentiable base
# // assert a._base is b._base and a._base is not None
# // base = a._base
# // (2) a and b both don't require gradients. Create a base from the storage
# // assert a._base is None and b._base is None
# // base = torch.Tensor(a.storage())
# a_updated, out = traced_forward(base)
# a.copy_(a_updated)
# return out
#
# This function:
# (1) Merges input views into a synthetic base argument, when any of those input views are mutated
# (2) Returns metadata telling the autograd.Function how to modify their arguments properly,
# to respect the new calling convention.
#
# The calling convention is as follows.
# Any inputs that were originally views of one another get yanked, and replaced with a synthetic base.
# The argument list ordering goes [base1, ..., baseN], [arg1, ..., argN],
# Where the ordering of the bases is determined from the ordering of the original view args.
# baseA will come before baseB if the earliest original argument coming from baseA
# showed up earlier in the argument list than the earliest original argument coming from baseB.
#
# Example, given some tensors a, b, c, d
# call site:
# f(a, c.view(-1), b.view(-1), b, c, d)
# Modified argument list:
# c_base comes first because the first c view came earlier in arg list than the first b view
# a and d still show up in the modified arg list, but b and c don't- they're regenerated from their bases
# b_base = torch.Tensor(b.storage())
# c_base = torch.Tensor(c.storage())
# f(c_base, b_base, a, d)
def merge_view_inputs(
aot_config: AOTConfig,
fwd_inputs: List[Any],
mutated_input_info: List[InputAliasInfo],
*,
# The autograd case currently has more restrictions than the inference case.
is_inference: bool,
) -> Tuple[List[Any], Optional[List[Union[int, Tuple[int, torch.Tensor]]]]]:
def _are_differentiable_views(view1, view2):
if view1 is view2:
return True
if view1._base is None and view2._base is None:
return False
if view1._base is view2._base or view1._base is view2 or view1 is view2._base:
return True
return False
def _same_dtype_views(view1, view2):
if view1.dtype != view2.dtype:
return False
if view1._base is not None and view1.dtype != view1._base.dtype:
return False
if view2._base is not None and view2.dtype != view2._base.dtype:
return False
return True
assert len(fwd_inputs) == len(mutated_input_info)
if not [info for info in mutated_input_info if info.mutates_data]:
# Return early when there are no mutations.
return fwd_inputs, None
storage_ref_to_idx: Dict[StorageWeakRef, List[int]] = collections.defaultdict(list)
base_args = []
other_args = []
for i, inpt in enumerate(fwd_inputs):
if isinstance(inpt, Tensor):
storage_ref = StorageWeakRef(inpt.untyped_storage())
storage_ref_to_idx[storage_ref].append(i)
else:
other_args.append(inpt)
# Note [Synthetic Base Info Metadata]
# This list contains metadata that tells you what the i'th argument in the inner calling convention should be.
# It's either:
# - another int (corresponding to the index in the argument list of the element from the outer calling convention)
# - idx, view_tensor, where we can generate the new output with view_tensor._view_func(old_args[idx])
# idx corresponds to which synthetic base from the outer calling context to view
inner_calling_convention_meta: Dict[int, Union[int, Tuple[int, torch.Tensor]]] = {}
for aliased_input_indices in storage_ref_to_idx.values():
if len(aliased_input_indices) <= 1 or not any(
# We only care about mutations that affect all aliases,
# so metadata mutations on an input don't require us to do synthetic base handling.
mutated_input_info[inpt_idx].mutates_data
for inpt_idx in aliased_input_indices
):
other_args.extend(
fwd_inputs[curr_idx] for curr_idx in aliased_input_indices
)
continue
# Here, we attempt to do a more complicated check to detect false aliasing
# (e.g. if all the tensors have the same storage, but don't actually overlap)
# In theory, we could have a large group of tensors that all share storages, where only *some* of them
# have overlapping memory.
# I don't bother with that case for now: here, we only bail out earlier if we detect that **every** pair
# of tensors in the current group that shares a storage is non-overlapping.
aliased_input_indices_no_false_sharing = compute_overlapping_inputs(
aot_config, fwd_inputs, aliased_input_indices
)
if len(aliased_input_indices_no_false_sharing) <= 1:
other_args.extend(
fwd_inputs[curr_idx] for curr_idx in aliased_input_indices
)
continue
# We detected an input that was mutated, AND aliases with another input.
# we need to replace this set of aliased inputs with a single synthetic base.
# For now, I'm banning a bunch of cases. We expect dynamo to properly detect these cases
# and error out. We can fix them later.
# These checks are transitive, so we don't need to check every pair.
for idx1, idx2 in zip(
aliased_input_indices, aliased_input_indices[1:], strict=False
):
view1 = fwd_inputs[idx1]
view2 = fwd_inputs[idx2]
# The "inputs that are aliased but have different differentiable bases" case
# is more complicated and hopefully pretty rare. Not currently handled.
if not is_inference:
assert _are_differentiable_views(
view1, view2
), "aot_autograd() does not yet handle non-differentiable view input mutations."
# Regenerating views when reinterpreting complex / real tensors seems non-trivial,
# not handling for now
assert _same_dtype_views(
view1, view2
), "aot_autograd() does not yet handle input mutations on views with different dtypes."
non_none_bases = [
fwd_inputs[i]._base
for i in aliased_input_indices
if fwd_inputs[i]._base is not None
]
aliases_with_none_bases = [
fwd_inputs[i] for i in aliased_input_indices if fwd_inputs[i]._base is None
]
if len(non_none_bases) == 0:
# Case where none of the aliases have a ._base
# we generate a synthetic base without gradients, and generate views off of it
# We hit this case when we have input tensors to the graph that share a storage,
# but do not have a ._base field.
# Wondering when we hit this case?
# The _base field simply says that autograd knows about the aliasing relationship,
# but sometimes we create tensors which are aliased out of the same storage but guaranteed
# to be disjoint. In these cases, we will skip setting up the _base relationship
# for performance reasons (because the fact that the tensors share the same storage
# is unobservable unless you (1) do naughty things with resize_/as_strided
# or (2) look at the storage--as we are doing here.)
# One particular example of this is optimizer steps on the LSTM module:
# LSTM parameters are packed into a contiguous storage for efficiency reasons when
# calling cuDNN kernels, so when these parameters get passed to the optimizer we will
# find they share the same storage, but do not have _base set since they are all disjoint.
#
# NOTE: There is one case where this is unsafe:
# torch.Tensor(storage) will ALWAYS create a 1D tensor, which is not necessarily
# the same shape as the "actual" base that the tensor came from.
# For the most part this is fine, because we always use as_strided()
# to generate the original aliased inputs again.
# If we were to use view-replay though, this could cause the aliased views
# to have incorrect sizes.
example_idx = aliased_input_indices[0]
example_alias = fwd_inputs[example_idx]
# Note that this function is re-used at both trace time and runtime.
# At trace time, we're under a FakeMode so synthetic_base becomes a FakeTensor.
synthetic_base = torch.empty(
(0,), dtype=example_alias.dtype, device=example_alias.device
)
# We don't actually have a convenient way of going from storage -> tensor,
# So using set_() here (we suffer some minor overhead, but this case is rare).
synthetic_base.set_(example_alias.untyped_storage())
else:
# Case where all of the aliases require gradients, and have the same _base.
synthetic_base = non_none_bases[0]
for other_base in non_none_bases[1:]:
assert (
other_base is synthetic_base
), "aot_autograd() does not yet handle non-differentiable view input mutations."
for alias in aliases_with_none_bases:
assert (
alias is synthetic_base
), "aot_autograd() does not yet handle non-differentiable view input mutations."
base_args.append(synthetic_base)
for curr_view_idx in aliased_input_indices:
curr_view = fwd_inputs[curr_view_idx]
base_idx = len(base_args) - 1
# We store just enough info here so that we can regenerate the view later.
# Regeneration: curr_view._view_func(args[base_idx])
inner_calling_convention_meta[curr_view_idx] = (base_idx, curr_view)
if len(base_args) == 0:
assert len(other_args) == len(fwd_inputs)
# If no synthetic bases are necessary, just return the original inputs.
return fwd_inputs, None
else:
from torch.fx.experimental.symbolic_shapes import SymIntEqByExpr
def make_hashable(arg):
if isinstance(arg, torch.SymInt):
# Since only nested SymInt objects can be hashed, we wrap them with
# SymIntEqByExpr, which is a hashable wrapper of SymInts.
return SymIntEqByExpr(arg)
return arg
# Otherwise, return:
# (1) The new args according to the updated calling convention: (synthetic_bases, other_args)
# (2) Metadata telling functionalization how to generate the inner argument list given the outer calling convention.
# We post-process it into a list, where meta[i] tells you info about the i'th argument in the inner calling convention.
args_to_functionalization = base_args + other_args
arg_to_old_idx_map = {
make_hashable(arg): i for (i, arg) in enumerate(fwd_inputs)
}
for i, other_arg in enumerate(other_args):
new_idx = len(base_args) + i
old_idx = arg_to_old_idx_map[make_hashable(other_arg)]
inner_calling_convention_meta[old_idx] = new_idx
# post process into a list
post_processed_calling_convention_meta: List[
Union[int, Tuple[int, torch.Tensor]]
] = [-1 for _ in range(len(inner_calling_convention_meta))]
for k, v in inner_calling_convention_meta.items():
post_processed_calling_convention_meta[k] = v
# Quick assert: every argument in the inner calling convention should be accounted for.
for x in post_processed_calling_convention_meta:
assert x != -1
return args_to_functionalization, post_processed_calling_convention_meta
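# Illustrative sketch (not part of AOTAutograd): the argument ordering produced by the
# synthetic-base calling convention described above merge_view_inputs, using strings in
# place of tensors. _sketch_merge_calling_convention and its parameters are hypothetical.
# Simplifying assumptions: argument names are unique, every group of arguments that share
# a storage is replaced by a base (the real function also requires a data mutation and
# genuinely overlapping memory), and a base is represented by the string "<storage>_base".
from typing import Dict, List, Tuple, Union


def _sketch_merge_calling_convention(
    arg_names: List[str], storage_of: Dict[str, str]
) -> Tuple[List[str], List[Union[int, Tuple[int, str]]]]:
    # Group argument positions by storage, in order of first appearance.
    groups: Dict[str, List[int]] = {}
    for i, name in enumerate(arg_names):
        groups.setdefault(storage_of[name], []).append(i)

    bases: List[str] = []
    others: List[str] = []
    meta: Dict[int, Union[int, Tuple[int, str]]] = {}
    for storage, idxs in groups.items():
        if len(idxs) <= 1:
            others.extend(arg_names[i] for i in idxs)
            continue
        bases.append(f"{storage}_base")
        for i in idxs:
            # (index of the base in the new argument list, view to regenerate from it)
            meta[i] = (len(bases) - 1, arg_names[i])

    # Non-aliased arguments follow all bases and keep their relative order.
    for pos, name in enumerate(others):
        meta[arg_names.index(name)] = len(bases) + pos
    return bases + others, [meta[i] for i in range(len(arg_names))]


# Usage: for the call f(a, c.view(-1), b.view(-1), b, c, d) from the comment above, the
# new argument list is (c_base, b_base, a, d), and the per-argument metadata records
# either a new position (for a and d) or a (base index, view) pair (for the b and c views).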
@dataclass
class AutogradLazyBackwardCompileInfo:
bw_module: Callable
placeholder_list: List[Any]
saved_context: Optional[TracingContext]
saved_compile_context: Optional[CompileContext]
# This is wrapped in a class just for namespacing purposes
# No need to make it into an actual CompilerWrapper because it doesn't fit the abstraction as cleanly
class AOTDispatchAutograd:
@staticmethod
def process_runtime_tangent(x, meta: Union[PlainTensorMeta, SubclassCreationMeta]):
if not isinstance(x, torch.Tensor):
return x, [x]
if isinstance(x, FakeTensor):
if not x.is_contiguous(memory_format=meta.memory_format):
x = x.contiguous(memory_format=meta.memory_format)
return x, [x]
expected_type: Optional[type] = torch.Tensor
expected_meta = None
if isinstance(meta, SubclassCreationMeta):
expected_type = meta.original_subclass_type
expected_meta = meta.meta
runtime_type = type(x)
runtime_meta = None
runtime_subclass_keys: Sequence[str] = []
if is_traceable_wrapper_subclass(x):
runtime_subclass_keys, runtime_meta = x.__tensor_flatten__()
def maybe_coerce(x):
same_type: bool = expected_type == runtime_type
same_meta: bool = expected_meta == runtime_meta
if same_type and same_meta:
return x
if not hasattr(x, "__coerce_same_metadata_as_tangent__"):
return None
if same_type:
# Backward Compatibility, as some Subclass impls can have original 1-arg function.
return x.__coerce_same_metadata_as_tangent__(expected_meta)
return x.__coerce_same_metadata_as_tangent__(expected_meta, expected_type)
# Coerce to expected type and metadata
orig_x = x
x = maybe_coerce(x)
if x is None:
raise RuntimeError(
f"""
During the backward, we encountered a tensor subclass where we guessed its
metadata incorrectly.
Expected metadata: {str(expected_meta)}, expected type: {str(expected_type)}
Runtime metadata: {str(runtime_meta)}, runtime type: {str(runtime_type)}
shape: {str(orig_x.shape)}
To fix this, your tensor subclass must implement the dunder method __coerce_same_metadata_as_tangent__.
"""
)
# Coerce to expected memory format
if not x.is_contiguous(memory_format=meta.memory_format):
x = x.contiguous(memory_format=meta.memory_format)
if not is_traceable_wrapper_subclass(x):
return x, [x]
assert isinstance(meta, SubclassCreationMeta)
if orig_x is not x:
runtime_subclass_keys = x.__tensor_flatten__()[0]
assert len(meta.attrs) == len(runtime_subclass_keys)
leaves = []
for i, (attr, attr_meta) in enumerate(meta.attrs.items()):
elem = getattr(x, attr)
new_elem, elem_leaves = AOTDispatchAutograd.process_runtime_tangent(
elem, attr_meta
)
if new_elem is not elem:
setattr(x, attr, new_elem)
leaves.extend(elem_leaves)
return x, leaves
@staticmethod
def post_compile(
compiled_fw_func, # fw_module after compilation + wrappers
compiled_bw_func, # bw_module after compilation + wrappers
maybe_subclass_meta: Optional[SubclassMeta],
num_symints_saved_for_bw_: int,
backward_state_indices: List[int],
disable_amp: bool,
indices_of_inps_to_detach: List[int],
lazy_backward_info: Optional[AutogradLazyBackwardCompileInfo],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta, # runtime metadata
try_save_cache_entry: Optional[Callable], # Save cache entry after compilation
):
class CompiledFunction(torch.autograd.Function):
compiled_fw = compiled_fw_func
compiled_bw = compiled_bw_func
metadata: ViewAndMutationMeta = fw_metadata # type: ignore[assignment]
maybe_subclass_metadata: Optional[SubclassMeta] = maybe_subclass_meta
num_symints_saved_for_bw = num_symints_saved_for_bw_
_compiled_autograd_should_lift = False
_aot_id = aot_config.aot_id
_lazy_backward_info = lazy_backward_info
@staticmethod
def _compiled_autograd_key(ctx):
return (ctx._autograd_function_id, *ctx.symints)
@staticmethod
def forward(ctx, *deduped_flat_tensor_args):
args = deduped_flat_tensor_args
if backward_state_indices:
bw_state = args[backward_state_indices[0]]
assert isinstance(bw_state, BackwardState)
ctx._compiled_autograd_backward_state = bw_state
# There is a pretty complicated calling convention around what the compiled fw returns.
# The full list of outputs and their relative order is:
# (*tokens, *mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints)
# - Note that in the synthetic bases case, mutated_inputs will correspond to an updated version
# of the original view, and not the synthetic base
# - Note that donated buffer logic requires (*saved_tensors, *saved_symints) showing up last
# in the fw output order.
fw_outs = call_func_at_runtime_with_args(
CompiledFunction.compiled_fw,
args,
disable_amp=disable_amp,
)
num_outputs = CompiledFunction.metadata.num_outputs
num_outputs_aliased = CompiledFunction.metadata.num_outputs_aliased
num_mutated_runtime_inps = (
CompiledFunction.metadata.num_mutated_inp_runtime_indices
)
num_forward_returns = CompiledFunction.metadata.num_forward_returns
# Partitioners must put symint arguments at the end separate from tensor arguments
tensors_saved_for_backwards = fw_outs[
CompiledFunction.metadata.tensors_saved_for_backwards_slice
]
assert all(
isinstance(x, torch.Tensor) for x in tensors_saved_for_backwards
)
# See Note [Detaching saved tensors in AOTAutograd]
ctx.save_for_backward(
*(
x.detach() if x._is_view() else x
for x in tensors_saved_for_backwards
)
)
symint_outs = fw_outs[
CompiledFunction.metadata.symints_saved_for_backwards_slice
]
assert all(
isinstance(x, (int, float, torch.SymInt, torch.SymFloat))
for x in symint_outs
), str([type(x) for x in symint_outs])
ctx.symints = symint_outs
raw_returns = fw_outs[0:num_forward_returns]
# Wrap all autograd.Function.forward() outputs that are aliases
# so that autograd.Function doesn't treat them as tensors
if num_mutated_runtime_inps > 0:
for i, idx in enumerate(
CompiledFunction.metadata.mutated_inp_runtime_indices
):
# We could make this faster by only looping over inputs with metadata-only mutations
# (instead of looping over inputs with either data or metadata mutations), but there shouldn't be many.
info = CompiledFunction.metadata.input_info[idx]
if info.mutates_metadata and not info.mutates_data:
raw_return_idx = i
raw_returns[raw_return_idx] = TensorAlias(
raw_returns[raw_return_idx]
)
if config.debug_assert:
user_mutated_inputs_raw = raw_returns[
0:num_mutated_runtime_inps
]
mut_inp_infos = [
x
for x in CompiledFunction.metadata.input_info
if x.mutates_data or x.mutates_metadata
]
assert len(user_mutated_inputs_raw) == len(mut_inp_infos)
if CompiledFunction.metadata.num_unsafe_view_outputs > 0:
for idx in CompiledFunction.metadata.unsafe_view_out_indices:
raw_return_idx = num_mutated_runtime_inps + idx
o = raw_returns[raw_return_idx]
raw_returns[raw_return_idx] = torch.ops.aten._unsafe_view(
o, o.shape
)
if num_outputs_aliased > 0:
for idx in CompiledFunction.metadata.aliased_out_indices:
raw_return_idx = num_mutated_runtime_inps + idx
raw_returns[raw_return_idx] = TensorAlias(
raw_returns[raw_return_idx]
)
if config.debug_assert:
intermediates_raw = raw_returns[
num_mutated_runtime_inps + num_outputs :
]
assert not any(
isinstance(x, TensorAlias) for x in intermediates_raw
)
# invariant: intermediate bases always require gradients, so we don't have to
# consider marking them as non-differentiable.
raw_returns_not_including_intermediate_bases = raw_returns[
: num_mutated_runtime_inps + num_outputs
]
raw_returns_meta = [
x
for x in CompiledFunction.metadata.input_info
if x.mutation_type == MutationType.MUTATED_OUT_GRAPH
] + CompiledFunction.metadata.output_info
fw_outs_not_requiring_grad = [
x
for (i, x) in enumerate(
raw_returns_not_including_intermediate_bases
)
if isinstance(x, torch.Tensor)
and not raw_returns_meta[i].requires_grad
]
ctx.mark_non_differentiable(*fw_outs_not_requiring_grad)
ctx._materialize_non_diff_grads = False
return tuple(raw_returns)
@staticmethod
def backward(ctx, *flat_args):
all_args = CompiledFunction._backward_prologue(ctx, *flat_args)
def impl_fn(double_ctx=None):
out = CompiledFunction._backward_impl(ctx, all_args)
return CompiledFunction._backward_epilogue(ctx, out)
needs_grad = torch.is_grad_enabled() and any(
t.requires_grad for t in all_args if isinstance(t, torch.Tensor)
)
if needs_grad:
# double backward
return CompiledFunction._double_backward(ctx, impl_fn, all_args)
else:
return impl_fn()
@staticmethod
def _double_backward(ctx, impl_fn, all_args):
# Ensure that the graph is connected, and error if double backward is performed.
# See comment for why once_differentiable is not sufficient:
# https://github.com/pytorch/pytorch/pull/92348/files#r1072962107
class CompiledFunctionBackward(torch.autograd.Function):
# CompiledFunctionBackward is not yet supported in dynamo skipfiles
_compiled_autograd_should_lift = False
_aot_id = aot_config.aot_id
@staticmethod
def forward(double_ctx, *unused_args):
return impl_fn(double_ctx)
@staticmethod
def backward(double_ctx, *args):
raise RuntimeError(
"torch.compile with aot_autograd does not currently support double backward"
)
CompiledFunctionBackward._compiled_autograd_key = ( # type: ignore[method-assign]
CompiledFunction._compiled_autograd_key
)
return CompiledFunctionBackward.apply(*all_args)
@staticmethod
def _raise_if_functorch_active():
# Not ideal, but this prevents the user from seeing a nasty traceback. See #138422
stack = torch._C._functorch.peek_interpreter_stack()
torch._check(
stack is None,
lambda: (
"It looks like you're trying to call a compiled backward function within vmap/grad/vjp, "
"which isn't supported. Try wrapping vmap inside torch.compile, or skip compiling the "
"backward function."
),
)
@staticmethod
def _backward_prologue(ctx, *flat_args):
# Calling convention: we expect a grad_out passed to the backward:
# - for every output of the fw that does *not* alias an input or graph intermediate
# - for every updated_input generated by the fw that does *not* alias an input (aka only data-mutations)
# - for every graph intermediate that we need to use to generate an output later.
# The other outputs in the autograd.Function.forward that do *not* show up in the backward include:
# - outputs that alias inputs or graph intermediates
# - updated inputs due to metadata-only mutations.
# We need to return them in the forward, but ensure that they all do not get gradients in the backward,
# and we filter them out here before passing the remaining grad_outputs into the compiled backward.
CompiledFunction._raise_if_functorch_active()
num_intermediate_bases = (
CompiledFunction.metadata.num_intermediate_bases
)
num_mutated_runtime_inps = (
CompiledFunction.metadata.num_mutated_inp_runtime_indices
)
expected_grad_outs = (
CompiledFunction.metadata.num_outputs
+ num_mutated_runtime_inps
+ num_intermediate_bases
)
deterministic = CompiledFunction.metadata.deterministic
global_deterministic = torch.are_deterministic_algorithms_enabled()
if deterministic is not None:
torch._check(
not (not deterministic and global_deterministic),
lambda: (
"This compiled backward function is being run with "
"torch.use_deterministic_algorithms(True), "
"but it was previously generated during the forward function while "
"torch.use_deterministic_algorithms(False) was set."
),
)
assert len(flat_args) == expected_grad_outs
out_info = CompiledFunction.metadata.output_info
inp_tangents, out_tangents, intermediate_base_tangents = (
flat_args[:num_mutated_runtime_inps],
flat_args[
num_mutated_runtime_inps : num_mutated_runtime_inps
+ CompiledFunction.metadata.num_outputs
],
flat_args[
num_mutated_runtime_inps
+ CompiledFunction.metadata.num_outputs :
],
)
# input_info contains info on *every* input,
# but in backward() we are only given grad outputs for every mutated input.
# We then need to filter out the grad outputs that correspond to metadata-only mutations or don't require grad.
input_info = CompiledFunction.metadata.input_info
inp_tangents_filtered = [
x
for x, info_idx in zip(
inp_tangents,
CompiledFunction.metadata.mutated_inp_runtime_indices,
)
if input_info[info_idx].mutates_data
and input_info[info_idx].requires_grad
]
# We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates
out_tangents_filtered = [
x
for x, info in zip(out_tangents, out_info)
if info.output_type
in [
OutputType.non_alias,
OutputType.unsafe_view_alias,
OutputType.custom_function_view,
]
and issubclass(info.raw_type, torch.Tensor)
and info.requires_grad
]
# intermediate bases always require gradients, and always participate in the backward graph.
flat_bw_args_with_grads = [
*inp_tangents_filtered,
*out_tangents_filtered,
*intermediate_base_tangents,
]
num_flat_bw_args_with_grads = len(flat_bw_args_with_grads)
# sanity asserts
# metadata_only_inps = [
# x for x, info_idx in zip(inp_tangents, mutated_inp_indices)
# if not input_info[info_idx].mutates_data
# ]
# aliased_outputs = [
# x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias]
# assert all(x is None for x in metadata_only_inps)
# assert all(x is None for x in aliased_outputs)
# TODO: replace this with FunctionalizedRngRuntimeWrapper
rng_args = []
if CompiledFunction.metadata.is_rng_op_functionalized:
# Add the seed and offset to args
rng_args = CUDARngStateHelper.get_torch_state_as_tuple()
bw_tokens = [None] * CompiledFunction.metadata.num_backward_tokens
# - note: donated buffer logic requires (*ctx.symints, *ctx.saved_tensors) showing up first
# in the bw input order.
# Every access of ctx.saved_tensors triggers saved_tensors_hooks calls, and some tests
# count those calls, so read it once into a local variable.
ctx_saved_tensors = ctx.saved_tensors
num_ctx_saved_tensors = len(ctx_saved_tensors)
all_args = [
*ctx.symints,
*ctx_saved_tensors,
*flat_bw_args_with_grads,
*bw_tokens,
*rng_args,
]
del ctx_saved_tensors
# Note: [AOTAutograd Backward Guards]
# During AOTDispatch, we eagerly create and trace out a joint fw-bw graph.
# Doing so requires us to "guess" about some of the metadata of our grad_outputs.
#
# In particular: if an output to the forward is a plain tensor or a subclass,
# its corresponding grad_output in the backward **may or may not** be
# a plain tensor or a subclass. The main cases are:
# (1) If an output is a plain tensor, its grad_out will also be a plain tensor,
# *unless* the output is used in some subclass compute later in the forward graph,
# which will cause its grad_output to become a subclass
# (2) If an output is a subclass, its grad_out will also be a subclass,
# *unless* the output of the forward did not actually participate in the gradient computation,
# in which case autograd will insert a plain tensor of zeros for the grad_output.
# We could avoid this case with `torch.autograd.Function.set_materialize_grads`,
# although this is not turned on today in AOTAutograd and would require more work.
#
# Today, we make a guess on subclass-ness based on the above examples,
# and hard-error in the backward if we guessed wrong.
#
# In the future, we should add backward guards that would allow us to
# properly handle this case instead of erroring: we would need to retrace the backward graph,
# since we might produce an entirely different trace if our grad_outputs are subclass or not.
del flat_bw_args_with_grads
tangents_start_idx = (
len(all_args)
- num_flat_bw_args_with_grads
- len(rng_args)
- len(bw_tokens)
)
assert tangents_start_idx == len(ctx.symints) + num_ctx_saved_tensors
tangents_end_idx = len(all_args) - len(rng_args) - len(bw_tokens)
# TODO: figure out how to refactor the backward properly
# so I can use aot_dispatch_subclass_wrapper() here.
if CompiledFunction.maybe_subclass_metadata is not None:
tangents = all_args[tangents_start_idx:tangents_end_idx]
if len(tangents) != len(
CompiledFunction.metadata.subclass_tangent_meta
):
raise RuntimeError(
"The grad inputs should be same number as forward output tangents"
)
flat_processed_tangents = list(
itertools.chain.from_iterable(
(
AOTDispatchAutograd.process_runtime_tangent(
t,
m,
)[1]
)
for t, m in zip(
tangents,
CompiledFunction.metadata.subclass_tangent_meta,
)
)
)
all_args = (
runtime_unwrap_tensor_subclasses(
all_args[:tangents_start_idx],
# SymInts that are inputs to the backward graph are
# already included in the "all_args" list.
# Any symints coming from tensor subclasses should always
# come from primals, and so they will show up as extra
# arguments to the forward graph, and they will be saved
# as activation in the backward graph.
append_symints=False,
)
+ flat_processed_tangents
+ runtime_unwrap_tensor_subclasses(
all_args[tangents_end_idx:],
append_symints=False,
)
)
else:
all_args = [
(
AOTDispatchAutograd.process_runtime_tangent(
t,
CompiledFunction.metadata.subclass_tangent_meta[
i - tangents_start_idx
],
)[0]
if (tangents_start_idx <= i < tangents_end_idx)
else t
)
for i, t in enumerate(all_args)
]
# Backward with forward inputs mutations is not supported in double backward.
if (
torch.is_grad_enabled()
and CompiledFunction.metadata.indices_of_inputs_that_requires_grad_with_mutations_in_bw
):
raise RuntimeError(
"aot_autograd does not support input mutations with requires_grad in backward for create_graph=True"
)
return all_args
@staticmethod
def _backward_impl(ctx, all_args):
if ctx._is_compiled_autograd_tracing():
if lazy_backward_info is None:
raise RuntimeError(
"""This compiled backward function was saved by AOTAutogradCache, which does not support
compiled autograd. Please turn off AOTAutogradCache using `TORCHINDUCTOR_AUTOGRAD_CACHE=0`."""
)
bw_module = lazy_backward_info.bw_module
# For compiled autograd, run raw FX graph so that it can be inlined into the larger graph
symints = ctx._get_compiled_autograd_symints()
assert len(symints) == len(ctx.symints)
all_args[: len(symints)] = symints
if backward_state_indices:
assert ctx._compiled_autograd_backward_state.proxy is not None
all_args.append(ctx._compiled_autograd_backward_state)
context = torch._C._DisableAutocast if disable_amp else nullcontext
with context():
return normalize_as_list(bw_module(*all_args))
assert (
not backward_state_indices
), "BackwardState requires CompiledAutograd"
ctx.maybe_clear_saved_tensors()
saved_tensors_use_once = (
not torch._C._autograd._get_current_graph_task_keep_graph()
)
if CompiledFunction.compiled_bw is None:
assert lazy_backward_info is not None
if not saved_tensors_use_once:
fw_metadata.bw_donated_idxs = []
# Update bw_donated_idxs if using lazy_backward_info from `aot_dispatch_autograd`
if (
hasattr(lazy_backward_info, "saved_context")
and hasattr(lazy_backward_info.saved_context, "fw_metadata")
and hasattr(
lazy_backward_info.saved_context.fw_metadata, # type: ignore[union-attr]
"bw_donated_idxs",
)
):
lazy_backward_info.saved_context.fw_metadata.bw_donated_idxs = ( # type: ignore[union-attr]
[]
)
bw_module = lazy_backward_info.bw_module
placeholder_list = lazy_backward_info.placeholder_list
saved_context = lazy_backward_info.saved_context
saved_compile_context = lazy_backward_info.saved_compile_context
context = torch._C._DisableAutocast if disable_amp else nullcontext
metrics_context = get_metrics_context()
with tracing(saved_context), compile_context(
saved_compile_context
), context(), track_graph_compiling(
aot_config, "backward"
), metrics_context, dynamo_timed(
"backward._backward_impl",
phase_name="entire_backward_compile",
log_pt2_compile_event=True,
dynamo_compile_column_us="backward_cumulative_compile_time_us",
):
metrics_context.update_outer({"is_forward": False})
CompiledFunction.compiled_bw = aot_config.bw_compiler(
bw_module, placeholder_list
)
# Maybe save cache entry
if try_save_cache_entry is not None:
try_save_cache_entry(
CompiledFunction.compiled_bw,
fw_metadata,
aot_config,
)
if (
torch._functorch.config.donated_buffer
and not saved_tensors_use_once
and fw_metadata.bw_donated_idxs != []
):
torch._check(
False,
lambda: (
"This backward function was compiled with non-empty donated "
"buffers which requires create_graph=False and retain_graph=False. "
"Please keep backward(create_graph=False, retain_graph=False) "
"across all backward() function calls, or set "
"torch._functorch.config.donated_buffer=False to disable "
"donated buffer."
),
)
out = call_func_at_runtime_with_args(
CompiledFunction.compiled_bw,
all_args,
steal_args=True,
disable_amp=disable_amp,
)
return out
@staticmethod
def _backward_epilogue(ctx, out):
# Toss out the backward output tokens
num_bw_tokens = CompiledFunction.metadata.num_backward_tokens
if num_bw_tokens > 0:
out = out[:-num_bw_tokens]
# TODO: replace this with FunctionalizedRngRuntimeWrapper.post_compile
out = FunctionalizedRngRuntimeWrapper()._functionalized_rng_runtime_epilogue(
CompiledFunction.metadata, out, offset_index=len(out) - 1
)
out = tuple(out)
# TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here.
if CompiledFunction.maybe_subclass_metadata is not None:
assert (
CompiledFunction.maybe_subclass_metadata.grad_input_metas
is not None
)
outs_wrapped = wrap_tensor_subclasses(
out,
subclass_metas=CompiledFunction.maybe_subclass_metadata.grad_input_metas,
included_subclass_symints=True,
is_runtime=True,
)
return outs_wrapped
return out
compiled_function = RuntimeWrapper(
indices_of_inps_to_detach=indices_of_inps_to_detach,
trace_joint=True,
disable_amp=disable_amp,
).post_compile(
CompiledFunction.apply,
aot_config,
runtime_metadata=fw_metadata,
)
return compiled_function
@dataclass
class DebugAssertWrapper(CompilerWrapper):
flat_requires_grad: List[Optional[bool]] = field(default_factory=list)
def post_compile(
self,
compiled_fn,
aot_config: AOTConfig,
*,
runtime_metadata: ViewAndMutationMeta,
):
@wraps(compiled_fn)
def debug_compiled_function(args: List[Any]):
# TODO: Check aliasing relationships
# TODO: Check strides for metadata mutation
# (NB: ideally, this logic is factored out of this function and
# you move these debug checks there)
# Check requires grad. Bad case is when we compiled with
# requires_grad = False, but input requires_grad = True
# (vice versa is OK; we compute a gradient and then throw
# it away when it hits the input.)
for i, a in enumerate(args):
can_require_grad = self.flat_requires_grad[i]
if can_require_grad is None:
assert not isinstance(a, Tensor)
elif not can_require_grad:
assert not a.requires_grad, format_guard_bug_msg(
aot_config,
f"{describe_input(i, aot_config)} would not require grad",
)
return compiled_fn(args)
return debug_compiled_function
def pre_compile(
wrappers: List[CompilerWrapper],
flat_fn: Callable,
flat_args: List[Any],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
"""
Runs a sequence of wrappers on the given function and arguments.
Mutates wrappers in place.
"""
for wrapper in wrappers:
flat_fn, flat_args, fw_metadata = wrapper.pre_compile(
flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
)
return flat_fn, flat_args, fw_metadata
def post_compile(
wrappers: List[CompilerWrapper],
compiled_fn: Callable,
aot_config: AOTConfig,
*,
runtime_metadata: ViewAndMutationMeta,
) -> Tuple[Callable, ViewAndMutationMeta]:
"""
Runs a sequence of wrappers on the given function, applying them in reverse order.
Should be called after pre_compile().
"""
for wrapper in reversed(wrappers):
compiled_fn = wrapper.post_compile(
compiled_fn, aot_config, runtime_metadata=runtime_metadata
)
return compiled_fn, runtime_metadata
def make_runtime_safe(
fw_metadata: ViewAndMutationMeta,
maybe_subclass_meta: Optional[SubclassMeta],
):
"""
Calls make_runtime_safe on all ViewAndMutationMetas.
Modifies both arguments. Allows ViewAndMutationMetas to
be safely cached in AOTAutogradCache.
"""
fw_metadata.make_runtime_safe()
if maybe_subclass_meta is not None:
maybe_subclass_meta.fw_metadata.make_runtime_safe()
if maybe_subclass_meta.grad_input_metas:
for meta in maybe_subclass_meta.grad_input_metas:
if isinstance(meta, SubclassCreationMeta):
meta.make_runtime_safe()