# Owner(s): ["module: dynamo"]
import copy
import re
import unittest
from textwrap import dedent
from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._dynamo.testing import (
CompileCounter,
CompileCounterWithBackend,
expectedFailureDynamic,
rand_strided,
)
from torch._functorch.aot_autograd import _aot_export_function, create_functional_call
from torch._guards import CompileContext, StorageOverlap, TracingContext
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.profiler import profile
from torch.testing import FileCheck
from torch.testing._internal.common_utils import compare_equal_outs_and_grads
def maybe_dupe_op(x):
y = x + 1
z = x + 2
if x.numel() < 5:
return y, y
else:
return y, z
def is_dynamic_shape_test(test_name):
return test_name.endswith("_dynamic_shapes")
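

# Hypothetical helper (not part of the original suite): an eager-mode reference
# for the "we mutated an autograd leaf under no_grad mode" comments in
# test_mutation and test_call_fn_with_non_const_inputs_aot_unsafe below.
# Eager PyTorch rejects an in-place op on a leaf tensor that requires grad
# unless grad mode is disabled. That is why the compiled versions, which do the
# mutation under torch.set_grad_enabled(False), are expected not to error.
def _inplace_on_leaf_succeeds(grad_enabled):
    import torch

    param = torch.nn.Parameter(torch.randn(4))
    try:
        with torch.set_grad_enabled(grad_enabled):
            param.add_(torch.randn(4))
    except RuntimeError:
        # Eager raises for in-place ops on a grad-requiring leaf in grad mode.
        return False
    return True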
aten = torch.ops.aten
lib = torch.library.Library("custom", "DEF") # noqa: TOR901
lib.define("maybe_dupe_op(Tensor a) -> (Tensor, Tensor)")
lib.impl("maybe_dupe_op", maybe_dupe_op, "CPU")
lib.impl("maybe_dupe_op", maybe_dupe_op, "Meta")
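

# Hypothetical helper (not part of the original suite), sketching the idea behind
# the "Dynamo recompilation guarding invalid grad" note further down, under the
# assumption that Dynamo's tensor guards check requires_grad for any backend:
# flipping requires_grad on an input fails the guard, so Dynamo compiles a
# second frame instead of reusing a graph traced with different autograd
# metadata. CompileCounter just counts compiled frames and runs them eagerly.
def _frames_compiled_across_requires_grad_change():
    import torch
    from torch._dynamo.testing import CompileCounter

    torch._dynamo.reset()
    counter = CompileCounter()

    def add(x, y):
        return x + y

    compiled_add = torch.compile(add, backend=counter)
    x = torch.randn(3, 3, requires_grad=True)
    # First call compiles a frame specialized on y.requires_grad=True.
    compiled_add(x, torch.randn(3, 3, requires_grad=True))
    # requires_grad differs now, so the guard fails and Dynamo recompiles.
    compiled_add(x, torch.randn(3, 3, requires_grad=False))
    return counter.frame_count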
class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
def test_LSTM(self):
# https://github.com/pytorch/torchdynamo/issues/1147
class Repro(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.self_mod_model_lstm_lstm = torch.nn.LSTM(
64, 64, num_layers=2, bidirectional=True
)
def forward(self, permute: torch.Tensor):
self_mod_model_lstm_lstm = self.self_mod_model_lstm_lstm(permute)
return (self_mod_model_lstm_lstm,)
mod = Repro()
aot_mod = torch.compile(mod, backend="aot_eager")
args = [((92, 4, 64), (1, 5888, 92), torch.float32, "cpu", False)]
args = [
rand_strided(sh, st, dt, dev).requires_grad_(rg)
for (sh, st, dt, dev, rg) in args
]
eager_result = mod(*args)
aot_result = aot_mod(*args)
self.assertTrue(torch._dynamo.testing.same(eager_result, aot_result))
def test_mutation(self):
# https://github.com/pytorch/torchdynamo/issues/1301
def fn(param, y):
prev_grad = torch.is_grad_enabled()
try:
torch.set_grad_enabled(False)
param.add_(y)
finally:
torch.set_grad_enabled(prev_grad)
return y
y = torch.randn(4)
x = torch.nn.Parameter(torch.randn(4))
aot_fn = torch.compile(fn, backend="aot_eager")
# This should not error: we mutated an autograd leaf under no_grad mode.
aot_fn(x, y)
def test_mutation1(self):
def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor):
getitem = diagonal_chunked_attention_scores[
(
slice(None, None, None),
slice(None, None, None),
slice(None, 256, None),
slice(None, 257, None),
)
]
_stack0[
(
slice(None, None, None),
slice(None, -1, None),
slice(None, None, None),
slice(256, None, None),
)
] = getitem
view = _stack0.view(1, 12, 1024, 513)
return (view,)
x = torch.randn(torch.Size([12, 4, 256, 513]))
y = torch.randn(torch.Size([12, 3, 512, 513]))
aot_fn = torch.compile(fn, backend="aot_eager")
aot_fn(x, y)
def test_negative_testing_mutation(self):
def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor):
getitem = diagonal_chunked_attention_scores[
(
slice(None, None, None),
slice(None, None, None),
slice(None, 256, None),
slice(None, 257, None),
)
]
_stack0 = torch.sin(_stack0)
_stack0[
(
slice(None, None, None),
slice(None, -1, None),
slice(None, None, None),
slice(256, None, None),
)
] = getitem
view = _stack0.view(1, 12, 1024, 513)
return (view,)
x = torch.randn(torch.Size([12, 4, 256, 513]))
y = torch.randn(torch.Size([12, 3, 512, 513]))
aot_fn = torch.compile(fn, backend="aot_eager")
aot_fn(x, y)
def test_negative_testing(self):
def fn(x, y):
return torch.sin(x).add_(y)
y = torch.randn(4)
x = torch.randn(4)
aot_fn = torch.compile(fn, backend="aot_eager")
aot_fn(x, y)
def test_call_fn_with_non_const_inputs_aot_safe(self):
class ModuleSpecialFwd(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels=3, out_channels=20, kernel_size=(5, 5)
)
def _conv_forward(self, x):
return self.conv._conv_forward(x, self.conv.weight, self.conv.bias)
def forward(self, x):
return self._conv_forward(x)
# Init mod
mod = ModuleSpecialFwd()
rx = torch.randn([3, 10, 10])
# Run it for real
real = mod(rx)
# Run it in export
graph, _ = torch._dynamo.export(mod)(rx)
# Run exported graph with AOT
self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))
aot_fn = torch.compile(graph, backend="aot_eager")
aot_fn(rx)
def test_call_fn_with_non_const_inputs_aot_unsafe(self):
class ModuleSpecialFwd(torch.nn.Module):
def _some_bad_fwd(self, param, y):
prev_grad = torch.is_grad_enabled()
try:
torch.set_grad_enabled(False)
param.add_(y)
finally:
torch.set_grad_enabled(prev_grad)
return y
def forward(self, x, y):
return self._some_bad_fwd(x, y)
# Init mod
mod = ModuleSpecialFwd()
x = torch.nn.Parameter(torch.randn(4))
y = torch.randn([4])
# Run it for real
real = mod(x, y)
# Run it in export
graph, _ = torch._dynamo.export(mod)(x, y)
# Assert equal
self.assertTrue(torch._dynamo.testing.same(real, graph(x, y)))
# Run exported graph with AOT
aot_fn = torch.compile(graph, backend="aot_eager")
# This should not error: we mutated an autograd leaf under no_grad mode.
aot_fn(x, y)
def test_call_fn_with_non_const_inputs_aot_unsafe_control_flow(self):
class ModuleSpecialFwd(torch.nn.Module):
def _some_bad_fwd(self, param, y):
if y[0][0] < 3:
return y + param
return param * y
def forward(self, x, y):
a = x * y
a = self._some_bad_fwd(a, a)
b = x + y
return a * b
# Init mod
mod = ModuleSpecialFwd()
x = torch.nn.Parameter(torch.randn([2, 2]))
y = torch.randn([2, 2])
# Run it for real
real = mod(x, y)
# Run it through optimize, with our capturing fn
gms = []
counter = CompileCounter()
def capturing_fn(gm, inputs):
nonlocal gms
gms.append(gm)
return counter(gm, inputs)
optimized_mod = torch.compile(mod, backend=capturing_fn)
# Assert equal
self.assertTrue(torch._dynamo.testing.same(real, optimized_mod(x, y)))
        # Uncomment to print the graphs reproduced in the comments below.
# for gm in gms:
# print("GM CODE", gm.code)
self.assertEqual(counter.frame_count, 4)
self.assertEqual(counter.op_count, 7)
# Graph 1
# def forward(self, x : torch.nn.parameter.Parameter, y : torch.Tensor):
# mul = x * y; x = y = None
# return (mul,)
# BREAK
# Graph 2
# def forward(self, y : torch.Tensor):
# getitem = y[0]; y = None
# getitem_1 = getitem[0]; getitem = None
# lt = getitem_1 < 3; getitem_1 = None
# return (lt,)
# BREAK
# Graph 3
# def forward(self, param : torch.Tensor, y : torch.Tensor):
# add = y + param; y = param = None
# return (add,)
# BREAK
# Graph 4
# def forward(self, _stack0 : torch.Tensor, x : torch.nn.parameter.Parameter, y : torch.Tensor):
# add = x + y; x = y = None
# mul = _stack0 * add; _stack0 = add = None
# return (mul,)
# Run fn with AOT
torch._dynamo.reset()
aot_fn = torch.compile(optimized_mod, backend="aot_eager")
aot_fn(x, y)
    # Note: Dynamo recompilation guarding invalid grad
    #
    # This test is the spiritual equivalent of test_invalid_requires_grad_fake in
    # test_autodispatch.py. It invokes aot_autograd in a way that would normally
    # trigger an assertion (which is what test_invalid_requires_grad_fake does).
    # Unlike that test, its point is to prove that we do not hit the assertion,
    # because dynamo recompiles correctly and protects this condition.
    #
    # Subnote: test_invalid_requires_grad_fake uses fake tensors because dynamo
    # sends fake tensors down to aot_autograd.
@patch("torch._functorch.config.debug_assert", True)
def test_requires_grad_fake_via_dynamo_recompiles(self):
class F(torch.nn.Module):
def forward(self, x, y):
return (x + y,)
x = torch.randn(3, 3, requires_grad=True)
y = torch.randn(3, 3, requires_grad=True)
z = torch.randn(3, 3, requires_grad=False)
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
failure_reason = None
def guard_fail_fn(failure):
nonlocal failure_reason
failure_reason = failure[0]
fxy = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
compare_equal_outs_and_grads(self, F(), fxy, (x, y))
compare_equal_outs_and_grads(self, F(), fxy, (x, z))
self.assertIn(
"""tensor 'L['y']' requires_grad mismatch. expected requires_grad=1""",
failure_reason,
)
# Reset failure reason
failure_reason = None
self.assertEqual(cc.frame_count, 2)
torch._dynamo.reset() # for new backend
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
fxz = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
compare_equal_outs_and_grads(self, F(), fxz, (x, z))
compare_equal_outs_and_grads(self, F(), fxz, (x, z))
self.assertEqual(cc.frame_count, 1)
self.assertTrue(failure_reason is None)
def test_double_backward_errors(self):
# Remove this test after we get double backward to actually work
for grad_output in (torch.tensor(1.0, requires_grad=True), None):
x = torch.tensor(1.0, requires_grad=True)
err = "torch.compile with aot_autograd does not currently support double backward"
# The following cases should be equivalent:
# (1) double backward entirely inside compiled function
def f1(x):
y = x.sin().exp()
(gx,) = torch.autograd.grad(
y, x, create_graph=True, grad_outputs=grad_output
)
torch.autograd.grad(gx, x)
return gx
compiled_f1 = torch.compile(backend="aot_eager")(f1)
f1(x)
with self.assertRaisesRegex(RuntimeError, err):
compiled_f1(x)
# (2) the second half of double backward outside compiled function
def f2(x):
y = x.sin().exp()
(gx,) = torch.autograd.grad(
y, x, create_graph=True, grad_outputs=grad_output
)
return gx
compiled_f2 = torch.compile(backend="aot_eager")(f2)
gx = compiled_f2(x)
with self.assertRaisesRegex(RuntimeError, err):
torch.autograd.grad(gx, x)
# (3) double backward entirely outside compiled function
def f3(x):
y = x.sin().exp()
return y
compiled_f3 = torch.compile(backend="aot_eager")(f3)
y = compiled_f3(x)
(gx,) = torch.autograd.grad(
y, x, create_graph=True, grad_outputs=grad_output
)
with self.assertRaisesRegex(RuntimeError, err):
torch.autograd.grad(gx, x)
# create_graph=False
def f4(x):
y = x.sin().exp()
return y
compiled_f4 = torch.compile(backend="aot_eager")(f4)
x = torch.tensor(1.0, requires_grad=True)
y = compiled_f4(x)
(gx,) = torch.autograd.grad(y, x, create_graph=False, grad_outputs=grad_output)
@patch("torch._functorch.config.debug_assert", True)
def test_arg_dupe_via_dynamo_recompiles(self):
class F(torch.nn.Module):
def forward(self, x, y):
x = x.trunc_()
y = y.trunc_()
return (x + y,)
x = torch.randn(3, 3, requires_grad=True)
x1, x2, x3, x4 = x.clone(), x.clone(), x.clone(), x.clone()
y = torch.randn(3, 3, requires_grad=True)
y1, y2, y4 = y.clone(), y.clone(), y.clone()
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
failure_reason = None
def guard_fail_fn(failure):
nonlocal failure_reason
failure_reason = failure[0]
fxy = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
        # Note: to prevent a recompilation between the two calls, we clone x and y
        # for each use. fxy mutates its inputs, so reusing the same tensors would
        # make dynamo recompile.
fxy(x1, y1)
fxy(x2, y2)
self.assertTrue(failure_reason is None)
# Reset failure reason
failure_reason = None
self.assertEqual(cc.frame_count, 1)
torch._dynamo.reset() # for new backend
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
fxx = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
fxx(x3, x3)
fxx(x4, y4)
self.assertEqual(cc.frame_count, 2)
self.assertIn("""L['x'] is L['y']""", failure_reason)
@patch("torch._functorch.config.debug_assert", True)
def test_arg_dupe_via_dynamo_recompiles_many_args_param_non_tensor_arg(self):
class F(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.mean = torch.nn.Parameter(torch.randn(3, 3))
def forward(self, a, b, e, f):
a.trunc_()
b.trunc_()
return (a + b + self.mean) * e * f
a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 3, requires_grad=True)
a1, a2 = a.clone(), a.clone()
b1, b2 = b.clone(), b.clone()
failure_reason = None
def guard_fail_fn(failure):
nonlocal failure_reason
failure_reason = failure[0]
self.assertTrue(failure_reason is None)
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f(a1, a1, 2, 2)
f(a2, b2, 2, 2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
failure_reason,
)
torch._dynamo.reset()
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
c = torch.randn(3, 3, requires_grad=True)
d = torch.randn(3, 3, requires_grad=True)
c3, c4 = c.clone(), c.clone()
d3, d4 = d.clone(), d.clone()
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f(c3, c3, 3, 3)
f(c4, d4, 3, 3)
self.assertEqual(cc.frame_count, 2)
self.assertIn("""L['a'] is L['b']""", failure_reason)
@patch("torch._functorch.config.debug_assert", True)
def test_arg_dupe_via_dynamo_recompiles_many_with_global(self):
z = None
class F(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.mean = torch.nn.Parameter(torch.randn(3, 3))
def forward(self, a, b, e, f):
a.trunc_()
b.trunc_()
return (a + b + z + self.mean) * e * f
a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 3, requires_grad=True)
z = a
a1, a2 = a.clone(), a.clone()
b1, b2 = b.clone(), b.clone()
failure_reason = None
def guard_fail_fn(failure):
nonlocal failure_reason
failure_reason = failure[0]
self.assertTrue(failure_reason is None)
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f(a1, a1, 2, 2)
f(a2, b2, 2, 2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
failure_reason,
)
@patch("torch._functorch.config.debug_assert", True)
def test_arg_dupe_via_dynamo_recompiles_many_args_param_non_tensor_arg_list(self):
class F(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.mean = torch.nn.Parameter(torch.randn(3, 3))
def forward(self, e, f, a, b):
a.trunc_()
b.trunc_()
return (a + b + self.mean) * e[0] * f[0]
a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 3, requires_grad=True)
a1, a2 = a.clone(), a.clone()
b1, b2 = b.clone(), b.clone()
failure_reason = None
def guard_fail_fn(failure):
nonlocal failure_reason
failure_reason = failure[0]
self.assertTrue(failure_reason is None)
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f([3, 2, 1], [4, 5, 6], a1, a1)
f([3, 2, 1], [4, 5, 6], a2, b2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
failure_reason,
)
torch._dynamo.reset()
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
c = torch.randn(3, 3, requires_grad=True)
d = torch.randn(3, 3, requires_grad=True)
c3, c4 = c.clone(), c.clone()
d3, d4 = d.clone(), d.clone()
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f([3, 2, 1], [4, 5, 6], c3, c3)
f([3, 2, 1], [4, 5, 6], c4, d4)
self.assertEqual(cc.frame_count, 2)
@patch("torch._functorch.config.debug_assert", True)
def test_arg_dupe_via_dynamo_recompiles_many_args_param(self):
class F(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.mean = torch.nn.Parameter(torch.randn(3, 3))
def forward(self, a, b):
a.trunc_()
b.trunc_()
return a + b + self.mean
a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 3, requires_grad=True)
a1, a2 = a.clone(), a.clone()
b1, b2 = b.clone(), b.clone()
failure_reason = None
def guard_fail_fn(failure):
nonlocal failure_reason
failure_reason = failure[0]
self.assertTrue(failure_reason is None)
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f(a1, a1)
f(a2, b2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
failure_reason,
)
torch._dynamo.reset()
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
c = torch.randn(3, 3, requires_grad=True)
d = torch.randn(3, 3, requires_grad=True)
c3, c4 = c.clone(), c.clone()
d3, d4 = d.clone(), d.clone()
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f(c3, c3)
f(c4, d4)
self.assertEqual(cc.frame_count, 2)
self.assertIn("""L['a'] is L['b']""", failure_reason)
@patch("torch._functorch.config.debug_assert", True)
def test_arg_dupe_via_dynamo_recompiles_many_args(self):
class F(torch.nn.Module):
def forward(self, a, b, c, d):
a.trunc_()
b.trunc_()
c.trunc_()
d.trunc_()
return (a + b + c + d,)
a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 3, requires_grad=True)
a1, a2, a3, a4 = a.clone(), a.clone(), a.clone(), a.clone()
b1, b2, b3, b4 = b.clone(), b.clone(), b.clone(), b.clone()
failure_reason = None
def guard_fail_fn(failure):
nonlocal failure_reason
failure_reason = failure[0]
self.assertTrue(failure_reason is None)
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f(a1, a1, a1, a1)
f(a2, b2, b2, b2)
self.assertEqual(cc.frame_count, 2)
self.assertIn(
"""L['a'] is L['b']""",
failure_reason,
)
torch._dynamo.reset()
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
c = torch.randn(3, 3, requires_grad=True)
d = torch.randn(3, 3, requires_grad=True)
c3, c4 = c.clone(), c.clone()
d3, d4 = d.clone(), d.clone()
f = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
f(a3, b3, c3, c3)
f(a4, b4, c4, d4)
self.assertEqual(cc.frame_count, 2)
self.assertIn("""L['c'] is L['d']""", failure_reason)
def test_alias_inputs(self):
def fn():
a = torch.tensor([1])
a = a[0:1]
b = a.squeeze()
a[0] = 0
if a[0] < 1e5:
pass
a[0] = 2
return b
ref_output = fn()
aot_fn = torch.compile(fn, backend="aot_eager")
actual_output = aot_fn()
self.assertEqual(ref_output, actual_output)
def test_grad_inputs_alias_inputs(self):
class Test(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x)
return y
@staticmethod
def backward(ctx, grad):
(x,) = ctx.saved_tensors
return x, grad
def fn(x, y):
return Test.apply(x, y)
x = torch.ones(1, requires_grad=True)
y = torch.ones(1, requires_grad=True)
compiled_fn = torch.compile(fn, backend="aot_eager")
out = compiled_fn(x, y)
out.sum().backward()
@expectedFailureDynamic # https://github.com/pytorch/pytorch/issues/103539
@torch._dynamo.config.patch(automatic_dynamic_shapes=False)
@patch("torch._functorch.config.debug_assert", True)
def test_multiple_aot_autograd_calls_dupe_args(self):
        # aot_module_simplified expects submodules to always return a tuple or
        # list, so this wrapper normalizes single-tensor outputs into a tuple.
class WrapperModule(torch.nn.Module):
def __init__(self, mod):
super().__init__()
self.mod = mod
def forward(self, *args):
out = self.mod(*args)
if isinstance(out, (list, tuple)):
return out
return (out,)
def compile_submod(input_mod, args):
from functorch.compile import nop
from torch._functorch.aot_autograd import aot_module_simplified
            # Named differently from the outer WrapperModule to avoid shadowing it.
            class CompiledSubmod(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.original = input_mod
                    self.submod = aot_module_simplified(input_mod, args, nop)

                def forward(self, *args):
                    return self.submod(*args)

            return CompiledSubmod()
def test_compile(fx_g, example_inps):
split_gm = torch.fx.passes.split_module.split_module(
fx_g, None, lambda node: 1 if "mul" in str(node) else 0
)
submod_1_inps = split_gm.submod_0(*example_inps)
split_gm.submod_0 = compile_submod(
WrapperModule(split_gm.submod_0), example_inps
)
split_gm.submod_1 = compile_submod(
WrapperModule(split_gm.submod_1), submod_1_inps
)
return split_gm
@torch.compile(backend=test_compile)
def f(a):
b, c = torch.ops.custom.maybe_dupe_op(a)
return (b.mul_(c),)
f(torch.ones(4))
f(torch.ones(6))
def test_nn_parameter_construction(self):
# https://github.com/pytorch/pytorch/issues/99569
def fn(x):
y = x.sin()
z = torch.nn.Parameter(torch.ones(1))
return y + z
x = torch.rand((4, 4))
opt_fn = torch.compile(fn, backend="aot_eager")
self.assertTrue(torch._dynamo.testing.same(fn(x), opt_fn(x)))
def test_aot_sequence_nr(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = torch.nn.Conv2d(
in_channels=16,
out_channels=16,
kernel_size=(1, 1),
stride=1,
padding="same",
bias=True,
)
self.bn1 = torch.nn.BatchNorm2d(num_features=16)
self.relu1 = torch.nn.ReLU()
self.fc1 = torch.nn.Linear(in_features=1638400, out_features=1)
self.loss_fn = torch.nn.L1Loss()
def forward(self, x, target):
y = x
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = x + y
x = torch.flatten(x)
x = self.fc1(x)
output = self.loss_fn(x, target)
return (output,)
mod = Model()
mod.train()
x = torch.rand(100, 16, 32, 32, requires_grad=True)
target = torch.rand(1)
# Use dynamo export to get the fx graph module
g_mod, _ = torch._dynamo.export(mod, x, target)
def _prepare_model_args():
named_parameters = dict(g_mod.named_parameters(remove_duplicate=False))
named_buffers = dict(g_mod.named_buffers(remove_duplicate=False))
params_and_buffers = {
**dict(named_parameters),
**dict(named_buffers),
}
params_and_buffers_flat, params_spec = pytree.tree_flatten(
params_and_buffers
)
params_len = len(params_and_buffers_flat)
functional_call = create_functional_call(g_mod, params_spec, params_len)
return params_and_buffers_flat, functional_call
full_args, fn_to_trace = _prepare_model_args()
param_and_buf_len = len(full_args)
full_args.extend([x, target])
        # _aot_export_function takes the forward graph (wrapped by
        # functional_call above) and returns the joint fwd/bwd graph as a
        # GraphModule.
with torch.enable_grad(), fx_traceback.preserve_node_meta():
fx_g, _, _, _ = _aot_export_function(
fn_to_trace,
full_args,
decompositions=None,
num_params_buffers=param_and_buf_len,
no_tangents=True,
)
        # Walk all the nodes in the FX graph and record each op in a table.
min_seq_nr = -1
seq_table = "SeqNr|OrigAten|SrcFn|FwdSrcFn\n"
for node in fx_g.graph.nodes:
if "call_" in node.op and "getitem" not in str(node.target):
seq_nr = node.meta.get("seq_nr", -1)
if seq_nr < 0:
continue
if min_seq_nr < 0:
min_seq_nr = seq_nr
source_fn_stack = node.meta.get("source_fn_stack", [])
orig_aten = node.meta.get("original_aten", "")
mod_name = ""
if len(source_fn_stack) > 0:
mod_name = source_fn_stack[-1][0]
# Make all seq_nr relative so it starts at 0
seq_nr = seq_nr - min_seq_nr
# For backward nodes, also test that metadata from the corresponding
# forward node is copied over.
fwd_source_fn_stack = node.meta.get("fwd_source_fn_stack", [])
fwd_mod_name = ""
if len(fwd_source_fn_stack):
fwd_mod_name = fwd_source_fn_stack[-1][0]
seq_table = (
seq_table + f"{seq_nr}|{orig_aten}|{mod_name}|{fwd_mod_name}\n"
)
self.maxDiff = None
self.assertExpectedInline(
seq_table,
dedent(
"""\
SeqNr|OrigAten|SrcFn|FwdSrcFn
0|aten.convolution.default|l__self___conv1|
0|aten.add.Tensor|l__self___bn1|
1|aten._native_batch_norm_legit_functional.default|l__self___bn1|
2|aten.relu.default|l__self___relu1|
2|aten.detach.default|l__self___relu1|
2|aten.detach.default|l__self___relu1|
3|aten.add.Tensor|add|
4|aten.view.default|flatten|
5|aten.view.default|l__self___fc1|
6|aten.t.default|l__self___fc1|
7|aten.addmm.default|l__self___fc1|
8|aten.view.default|l__self___fc1|
9|aten.sub.Tensor|l__self___loss_fn|
10|aten.abs.default|l__self___loss_fn|
11|aten.mean.default|l__self___loss_fn|
11|aten.ones_like.default||l__self___loss_fn
11|aten.expand.default||l__self___loss_fn
11|aten.div.Scalar||l__self___loss_fn
10|aten.sgn.default||l__self___loss_fn
10|aten.mul.Tensor||l__self___loss_fn
8|aten.view.default||l__self___fc1
7|aten.t.default||l__self___fc1
7|aten.mm.default||l__self___fc1
7|aten.t.default||l__self___fc1
7|aten.mm.default||l__self___fc1
7|aten.t.default||l__self___fc1
7|aten.sum.dim_IntList||l__self___fc1
7|aten.view.default||l__self___fc1
6|aten.t.default||l__self___fc1
5|aten.view.default||l__self___fc1
4|aten.view.default||
2|aten.detach.default||l__self___relu1
2|aten.detach.default||l__self___relu1
2|aten.threshold_backward.default||l__self___relu1
1|aten.native_batch_norm_backward.default||l__self___bn1
0|aten.convolution_backward.default||l__self___conv1
11|aten.add.Tensor||l__self___loss_fn
"""
),
)
def test_split_with_sizes_aot_autograd_cleans_up_traceback_meta(self):
from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks
def fn(result, split_sizes):
rs = torch.ops.aten.split_with_sizes(result, split_sizes.tolist())
return rs
example_inputs = (
torch.randn(32, requires_grad=True),
torch.tensor((7, 16, 9)),
)
outs = fn(*example_inputs)
setup_stacktrace_preservation_hooks([out.grad_fn for out in outs])
with fx_traceback.preserve_node_meta():
(outs[0].sum() + outs[1].sum() + outs[2].sum()).backward()
self.assertNotIn("grad_fn_seq_nr", fx_traceback.current_meta)
self.assertNotIn("in_grad_fn", fx_traceback.current_meta)
    # https://github.com/pytorch/pytorch/issues/110121
    def test_aot_export_joint_simple_repro(self):
        class Mod(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.linear = torch.nn.Linear(5, 7)

            def forward(self, x):
                return self.linear(x)

        def mini_backend(gm, sample_inputs):
            from torch._functorch.aot_autograd import aot_export_joint_simple

            fake_mode = torch._dynamo.utils.detect_fake_mode(sample_inputs)

            with patch.object(fake_mode, "allow_non_fake_inputs", True), fake_mode:
                return aot_export_joint_simple(gm, sample_inputs, trace_joint=False)

        sample_inputs = [torch.rand((3, 4, 5))]
        model = Mod()
        m_compiled = torch.compile(model, backend=mini_backend)

        out_ref = model(*sample_inputs)
        out_test = m_compiled(*sample_inputs)
        self.assertEqual(out_ref, out_test)
    # Set donated_buffer=False because the test differentiates with create_graph=True.
    @torch._functorch.config.patch("donated_buffer", False)
    def test_eager_sequence_nr(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    in_channels=16,
                    out_channels=16,
                    kernel_size=(1, 1),
                    stride=1,
                    padding="same",
                    bias=True,
                )
                self.bn1 = torch.nn.BatchNorm2d(num_features=16)
                self.relu1 = torch.nn.ReLU()
                self.fc1 = torch.nn.Linear(in_features=1638400, out_features=1)
                self.loss_fn = torch.nn.L1Loss()

            def forward(self, x, target):
                y = x
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu1(x)
                x = x + y
                x = torch.flatten(x)
                x = self.fc1(x)
                output = self.loss_fn(x, target)

                return (output,)

        def grad_with_create_graph(mod, x, target):
            y = mod(x, target)
            # Set create_graph=True to ensure that the sequence_nr
            # for backward ops continues to count down.
            (gx,) = torch.autograd.grad(
                y[0], x, create_graph=True, grad_outputs=grad_output
            )
            return gx

        x = torch.rand(100, 16, 32, 32, requires_grad=True)
        target = torch.rand(1)
        mod = Model()
        args = [mod, x, target]
        grad_output = torch.tensor(1.0, requires_grad=True)
        compiled_f1 = torch.compile(backend="aot_eager")(grad_with_create_graph)
        model_instance = compiled_f1
        with profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            record_shapes=True,
        ) as kineto_prof:
            res = model_instance(*args)
        bwd_set = set()
        prof_str = "SeqNr|Thread|FwdThread|Name\n"
        for event in kineto_prof.events():
            if event.sequence_nr >= 0:
                prof_str = (
                    prof_str + f"{event.sequence_nr}|{event.thread}"
                    f"|{event.fwd_thread}|{event.name}|\n"
                )
                if re.search(r"Backward[01]", event.name):
                    bwd_set.add(event.sequence_nr)
        self.assertTrue(len(bwd_set), 13)
    def test_aot_grad_mode_mutation(self):
        for compiler in ["aot_eager", "inductor"]:

            def f(x):
                y = x * x
                torch.set_grad_enabled(False)
                return y.clone(), y

            f_compiled = torch.compile(f, backend=compiler, fullgraph=True)

            torch.set_grad_enabled(True)
            x = torch.ones(3, requires_grad=True) * 3
            y_ref = f(x)
            self.assertEqual(torch.is_grad_enabled(), False)
            torch.set_grad_enabled(True)
            y = f_compiled(x)
            self.assertEqual(torch.is_grad_enabled(), False)
            torch.set_grad_enabled(True)
            self.assertEqual(y_ref, y)

            self.assertIsNone(y_ref[0].grad_fn)
            self.assertIsNone(y[0].grad_fn)

            self.assertIsNotNone(y_ref[1].grad_fn)
            self.assertIsNotNone(y[1].grad_fn)

            # Check that, given the same tangent, eager and compiled compute the
            # same gradient for the input. The tangent for `y[0]`, which has
            # requires_grad=False, is irrelevant.
            self.assertEqual(
                sum(y_ref[1].grad_fn(torch.tensor([-1.0, 2.0, 0.0]))),
                sum(
                    x
                    for x in y[1].grad_fn.apply(None, torch.tensor([-1.0, 2.0, 0.0]))
                    if x is not None
                ),
            )
    def test_aot_autograd_raises_invalid_leaf_set(self):
        @torch.compile
        def f(x):
            x.set_(torch.ones(2))

        # We still want to make sure that this raises
        x = torch.ones(2, requires_grad=True)
        with self.assertRaisesRegex(
            RuntimeError, "is being used in an in-place operation"
        ):
            f(x)
    def test_aot_autograd_expand_mutation_functionalizes(self):
        def fn(x):
            y = x.expand(3, *x.shape)
            y[0, 0].add_(5)
            return y

        opt_fn = torch.compile(fn, backend="aot_eager")

        x = torch.arange(6)
        x_opt = x.detach().clone()
        self.assertEqual(fn(x), opt_fn(x_opt))
        self.assertEqual(x, x_opt)
    def test_aot_autograd_expand_mutation_backwards(self):
        def fn(x, z):
            y = x.expand(3, *x.shape)
            y[1, 1].mul_(5)
            ret = y * z
            return ret

        opt_fn = torch.compile(fn, backend="aot_eager")

        x = torch.arange(6, dtype=torch.float)
        z = x.detach().clone()
        x_opt = x.detach().clone()
        z_opt = x.detach().clone()

        z.requires_grad = True
        z_opt.requires_grad = True

        res = fn(x, z)
        opt_res = opt_fn(x_opt, z_opt)

        self.assertEqual(res, opt_res)

        res.sum().backward()
        opt_res.sum().backward()

        self.assertEqual(x, x_opt)
        self.assertEqual(z.grad, z_opt.grad)
    def test_data_ptr_access_copy(self):
        import torch._functorch.config as _config

        with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
            with FakeTensorMode():
                x = torch.randn(3)
                y = copy.copy(x)
        self.assertEqual(y.shape, x.shape)
    def test_data_ptr_access_fails_in_forward(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define("mylib::foo", "(Tensor x) -> Tensor", lib=lib)

            @torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib)
            def _(x):
                x.data_ptr()
                return x.clone()

            x = torch.randn(3)

            def data_ptr_graph_input(x):
                r0 = torch.ops.mylib.foo(x)
                return r0

            def data_ptr_graph_intermediate(x):
                y = x.clone()
                r0 = torch.ops.mylib.foo(y)
                return r0

            tests = [data_ptr_graph_input, data_ptr_graph_intermediate]

            def ctx():
                return self.assertRaisesRegex(
                    RuntimeError, "Cannot access data pointer"
                )

            for f in tests:
                with ctx():
                    make_fx(f, tracing_mode="fake")(x)
                with ctx():
                    make_fx(f, tracing_mode="symbolic")(x)
                with ctx():
                    torch.compile(f, backend="eager", fullgraph=True)(x)
    def test_data_ptr_access_fails_in_backward(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define("mylib::foo", "(Tensor x) -> Tensor", lib=lib)

            backward_called = False

            class Foo(torch.autograd.Function):
                @staticmethod
                def forward(ctx, x):
                    return x.clone()

                @staticmethod
                def backward(ctx, grad):
                    nonlocal backward_called
                    backward_called = True
                    grad.data_ptr()
                    return grad.clone()

            @torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib)
            def _(x):
                return Foo.apply(x)

            def f(x):
                return torch.ops.mylib.foo(x)

            x = torch.randn(3, requires_grad=True)
            with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer"):
                y = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
            self.assertTrue(backward_called)
    # We don't know how to catch multiple mutations to the same memory location
    @unittest.expectedFailure
    def test_aot_autograd_expand_mutation_error(self):
        def fn(x):
            y = x.expand(3, *x.shape)
            y[0:3, 0].add_(5)
            return y

        opt_fn = torch.compile(fn, backend="aot_eager")

        x = torch.arange(6)
        x_opt = x.detach().clone()
        with self.assertRaises(Exception):
            fn(x)
        with self.assertRaises(Exception):
            opt_fn(x_opt)
    @torch._functorch.config.patch(donated_buffer=True)
    def test_donated_buffer1(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        @torch.compile()
        def relu(x):
            return torch.nn.functional.relu(x)

        with self.assertLogs(logger_name, level="INFO") as captured:
            relu(torch.rand([3, 3], requires_grad=True)).sum().backward()

        if is_dynamic_shape_test(self._testMethodName):
            # an extra symint exists
            expected_msg = "bw_donated_idxs=[1]"
        else:
            expected_msg = "bw_donated_idxs=[0]"

        # le is a donated buffer from relu
        FileCheck().check(expected_msg).run("\n".join(captured.output))
    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer2(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        # Only g is compiled; f runs eagerly and passes activation into g's graph
        # as an input.
        @torch.compile()
        def g(activation, param2):
            return torch.matmul(activation, param2)

        def f(inp, param1, param2):
            activation = inp + param1
            return g(activation, param2)

        inp = torch.ones(4, 4)
        param1 = torch.ones(4, 4, requires_grad=True)
        param2 = torch.ones(4, 4, requires_grad=True)

        with self.assertLogs(logger_name, level="INFO") as captured:
            f(inp, param1, param2).sum().backward()

        FileCheck().check("bw_donated_idxs=[]").run("\n".join(captured.output))
    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer3(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        # Only g is compiled; f runs eagerly and passes activation into g's graph
        # as an input.
        @torch.compile()
        def g(activation, param2):
            return torch.matmul(activation, param2)

        def f(inp, param1, param2):
            # exp saves its output (the activation) for backward
            activation = torch.exp(inp + param1)
            return g(activation, param2)

        inp = torch.ones(4, 4)
        param1 = torch.ones(4, 4, requires_grad=True)
        param2 = torch.ones(4, 4, requires_grad=True)

        with self.assertLogs(logger_name, level="INFO") as captured:
            f(inp, param1, param2).sum().backward()

        FileCheck().check("bw_donated_idxs=[]").run("\n".join(captured.output))
    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer4(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([2, 2]))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.relu(x) + self.param

        mod = Mod()
        mod = torch.compile(mod)

        inp = torch.ones([2, 2], requires_grad=True)

        with self.assertLogs(logger_name, level="INFO") as captured:
            mod(inp).sum().backward()

        # Forward graph:
        # %primals_1 : [num_users=1] = placeholder[target=primals_1]
        # %primals_2 : [num_users=1] = placeholder[target=primals_2]
        # %relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%primals_2,), kwargs = {})
        # %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%relu, %primals_1), kwargs = {})
        # %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu, 0), kwargs = {})
        # return [add, le]
        #
        # `le` is a donated buffer
        FileCheck().check("bw_donated_idxs=[0]").run("\n".join(captured.output))
    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer5(self):
        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        @torch.compile()
        def f(x, z):
            y = x.view(2, 3)
            z = torch.nn.functional.relu(z)
            return torch.mm(y, x) + z

        inp = [
            torch.rand([3, 2], requires_grad=True),
            torch.rand([2, 2], requires_grad=True),
        ]

        with self.assertLogs(logger_name, level="INFO") as captured:
            f(*inp).sum().backward()

        # Forward graph:
        # %primals_1 : [num_users=3] = placeholder[target=primals_1]
        # %primals_2 : [num_users=1] = placeholder[target=primals_2]
        # %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%primals_1, [2, 3]), kwargs = {})
        # %relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%primals_2,), kwargs = {})
        # %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view, %primals_1), kwargs = {})
        # %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mm, %relu), kwargs = {})
        # %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu, 0), kwargs = {})
        # return [add, primals_1, le]
        #
        # `le` is a donated buffer but primals_1 is not.
        FileCheck().check("bw_donated_idxs=[1]").run("\n".join(captured.output))
    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer6(self):
        if is_dynamic_shape_test(self._testMethodName):
            # parameters should not be dynamic shape
            # torch._dynamo.exc.Unsupported: Parameter not python_constant:
            #     SymNodeVariable() is not a constant
            return

        logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"

        def fn(x):
            p = torch.nn.Parameter(x + 123)
            return p, p.sin()

        opt = torch.compile(fn, fullgraph=True)
        x = torch.randn(16)

        with self.assertLogs(logger_name, level="INFO") as captured:
            p, r = opt(x)
            r.sum().backward()

        FileCheck().check("bw_donated_idxs=[]").run("\n".join(captured.output))
    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer_with_retain_or_create_graph1(self):
        # Gives non-empty bw_donated_idxs
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([3, 3]))

            def forward(self, x):
                return torch.nn.functional.relu(x) + self.param

        inp = torch.randn(3, 3, requires_grad=True)

        mod = torch.compile(Mod())
        for _ in range(5):
            mod(inp).sum().backward()
    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer_with_retain_or_create_graph2(self):
        # Gives non-empty bw_donated_idxs
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([3, 3]))

            def forward(self, x):
                return torch.nn.functional.relu(x) + self.param

        inp = torch.randn(3, 3, requires_grad=True)

        mod = torch.compile(Mod())
        out = mod(inp).sum()
        for _ in range(5):
            out.backward(retain_graph=True)
        out.backward()
    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer_with_retain_or_create_graph3(self):
        # Gives non-empty bw_donated_idxs
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([3, 3]))

            def forward(self, x):
                return torch.nn.functional.relu(x) + self.param

        inp = torch.randn(3, 3, requires_grad=True)

        mod = torch.compile(Mod())
        mod(inp).sum().backward(create_graph=True)
        out = mod(inp).sum()
        for _ in range(5):
            out.backward(retain_graph=True)
        out.backward()
    @torch._functorch.config.patch("donated_buffer", True)
    def test_donated_buffer_with_retain_or_create_graph4(self):
        # Gives non-empty bw_donated_idxs
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([3, 3]))

            def forward(self, x):
                return torch.nn.functional.relu(x) + self.param

        inp = torch.randn(3, 3, requires_grad=True)

        mod = torch.compile(Mod())
        mod(inp).sum().backward()
        out = mod(inp).sum()
        with self.assertRaisesRegex(
            RuntimeError,
            r"This backward function was compiled with non-empty donated "
            r"buffers which requires create_graph=False and retain_graph=False. "
            r"Please keep backward\(create_graph=False, retain_graph=False\) "
            r"across all backward\(\) function calls, or set "
            r"torch._functorch.config.donated_buffer=False to disable "
            r"donated buffer.",
        ):
            out.backward(retain_graph=True)
    def _get_guard_failure_on_overlapping_view_inputs(self, f, argsfn1, argsfn2):
        # Compile and run f twice, using the arguments generated by argsfn1 and argsfn2.
        #
        # This function expects the second argument set to trigger a recompilation,
        # and returns the guard failure that caused it.

        guard_failure = []

        def guard_fail_fn(failure):
            nonlocal guard_failure
            guard_failure.append(failure[0])

        input = torch.ones(20)
        opt_input = input.clone().detach()

        opt_f = torch._dynamo.optimize(
            "aot_eager", dynamic=True, guard_fail_fn=guard_fail_fn
        )(f)

        out0 = f(*argsfn1(input))
        opt_out0 = opt_f(*argsfn1(opt_input))
        self.assertEqual(out0, opt_out0)

        out1 = f(*argsfn2(input))
        opt_out1 = opt_f(*argsfn2(opt_input))
        self.assertEqual(out1, opt_out1)

        # Check that there is exactly one guard failure, and that it is due to
        # the overlapping state not matching.
        self.assertEqual(len(guard_failure), 1)
        return guard_failure[0]
    def test_inputs_overlapping_with_mutation_recompile(self):
        # Check that the overlap guard actually fails when we run the second time with
        # args that have no storage overlap.

        def f(*args):
            for a in args:
                a.add_(1)
            return args[0]

        def overlapping_args(x):
            return x[:5], x[7:13], x[9:]

        def non_overlapping_args(x):
            return x[:5], x[7:13], x[13:15]

        guard_failure = self._get_guard_failure_on_overlapping_view_inputs(
            f, overlapping_args, non_overlapping_args
        )
        self.assertExpectedInline(
            guard_failure,
            """0/0: check_overlapping(overlapping=[L['args'][1], L['args'][2]], non_overlapping=[L['args'][0]])""",
        )
    def test_different_inputs_overlapping_set_with_mutation(self):
        # Check that the overlap guard actually fails when we run the second time with
        # arguments whose overlapping set is a superset of the set of arguments used
        # the first time.

        def f(a, b, c, d):
            a.mul_(2)
            return a + b + c + d

        def a_b_overlapping_args(x):
            return x[:5], x[4:9], x[10:15], x[15:]

        def a_b_c_overlapping_args(x):
            return x[:5], x[4:9], x[8:13], x[15:]

        guard_failure = self._get_guard_failure_on_overlapping_view_inputs(
            f, a_b_overlapping_args, a_b_c_overlapping_args
        )
        self.assertExpectedInline(
            guard_failure,
            """0/0: check_overlapping(overlapping=[L['a'], L['b']], non_overlapping=[L['c'], L['d']])""",
        )
    def _test_no_storage_overlap_guards(self, f, argsfn):
        # Compile f with the aot_eager backend and run it with the argument set
        # returned by argsfn. Meanwhile, keep track of the aotautograd_guards to
        # make sure that no StorageOverlap guard was added.

        class Compiler:
            def __init__(self):
                self.counter = CompileCounterWithBackend("aot_eager")

            def __call__(self, *args, **kwargs):
                # Instead of checking here, we need to check afterwards, since the
                # StorageOverlap guard is only added later.
                self.guards = TracingContext.get().guards_context.aotautograd_guards
                return self.counter(*args, **kwargs)

        compiler = Compiler()

        input = torch.arange(20)
        opt_input = input.clone().detach()

        out = f(*argsfn(input))
        opt_out = torch._dynamo.optimize(compiler, dynamic=True)(f)(*argsfn(opt_input))
        self.assertEqual(out, opt_out)

        self.assertEqual(compiler.counter.frame_count, 1)

        # Check that none of the AOTAutograd guards are StorageOverlap guards.
        for g in compiler.guards:
            self.assertNotIsInstance(g, StorageOverlap)
    def test_no_storage_overlap_guards_no_mutation(self):
        def f(a, b):
            return a + b

        def overlapping_args(input):
            return input[:10], input[5:15]

        self._test_no_storage_overlap_guards(f, overlapping_args)
    def test_no_storage_overlap_guards_no_aliasing(self):
        def f(a, b):
            a.add_(1)
            b.add_(1)
            return a

        def non_overlapping_args(input):
            return input[:10], torch.arange(20)[5:15]

        self._test_no_storage_overlap_guards(f, non_overlapping_args)
    def test_inputs_overlapping_with_mutation_stress(self):
        # Stress test for StorageOverlap guard.
        #
        # Create 100 non-overlapping tensor views, and an extra one that overlaps with
        # the first 5 of them. Then, make sure that none of the produced ShapeEnv
        # guards came from the overlapping computation.

        def f(*args):
            for a in args:
                a.add_(1)
            return args[0]

        def overlapping_args(input):
            return (
                # 100 non-overlapping tensors of size 10.
                *input.split(10),
                # A tensor (elements 4 to 43) that overlaps with the first 5 tensors above.
                input[4:44],
            )

        class Compiler:
            def __init__(self):
                self.counter = CompileCounterWithBackend("aot_eager")

            def __call__(self, *args, **kwargs):
                self.compile_context = CompileContext.get()
                return self.counter(*args, **kwargs)

        compiler = Compiler()
        opt_f = torch._dynamo.optimize(compiler, dynamic=True)(f)

        input = torch.arange(1_000)
        opt_input = input.clone().detach()

        out0 = f(*overlapping_args(input))
        opt_out0 = opt_f(*overlapping_args(opt_input))
        self.assertEqual(out0, opt_out0)

        # Check that none of the produced ShapeEnv guards came from the
        # compute_overlapping_inputs function.
        overlapping_computation_fn = "compute_overlapping_inputs"
        shape_env_guards = compiler.compile_context.shape_env_guards

        for g in shape_env_guards:
            self.assertNotIn(overlapping_computation_fn, g)

        # Check that we have fewer than 1000 ShapeEnv guards.
        #
        # Note: this is an arbitrary number, so we might have to change it in the future.
        # However, at the time this change was introduced, it went down from 15154 to 403.
        self.assertLess(len(shape_env_guards), 1000)
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()