1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365
|
# Owner(s): ["module: ProxyTensor"]
from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import unittest
import warnings
import torch.nn.utils._stateless as stateless
import operator
from collections.abc import Iterable
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
from torch._subclasses.fake_tensor import DynamicOutputShapeException
from torch._decomp import decomposition_table
from torch.testing._internal.common_device_type import ops
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule, has_proxy
from torch.utils._pytree import tree_map
from torch import nn
import re
import types
import functools
import itertools
aten = torch.ops.aten

# sympy is optional; tests decorated with skipIfNoSympy are skipped without it.
try:
    import sympy  # noqa: F401
except ImportError:
    HAS_SYMPY = False
else:
    HAS_SYMPY = True

skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")
HAS_CUDA = torch.cuda.is_available()
def process_failures(path='pytest_failures'):
    """
    Takes file containing failures like
    FAILED test/test_proxy_tensor.py::TestProxyTensorOpInfoCPU::test_make_fx_symbolic_exhaustive___getitem___cpu_float32 - RuntimeError: aten.size.default - couldn't find symbolic meta function/decomposition # noqa: B950
    and processes them into a list of opinfo xfails

    Prints a ``symbolic_tensor_failures = {...}`` literal to stdout, with one
    ``xfail(op_name, variant_name)`` entry per failure line, annotated with
    the failure reason.

    Args:
        path: file containing the pytest failure lines. Defaults to
            ``'pytest_failures'`` in the current working directory.

    Raises:
        ValueError: if a non-blank line does not match the expected
            ``..._exhaustive_<name>_cpu...: <reason>`` format.
    """
    symbolic_trace_match = re.compile(r'exhaustive_(.*)_cpu.*: (.*)')

    # Use a context manager so the file handle is always closed.
    with open(path) as f:
        lines = [line.strip() for line in f]

    failures = []
    for line in lines:
        if not line:
            # Blank/trailing lines used to make re.search return None and
            # crash with an opaque AttributeError on `.groups()`.
            continue
        match = symbolic_trace_match.search(line)
        if match is None:
            raise ValueError(f"Unrecognized failure line: {line!r}")
        failures.append(match.groups())

    def create_normalized_name(op):
        # Test names flatten "name.variant" into "name_variant".
        if op.variant_test_name == '':
            s = op.name
        else:
            s = f"{op.name}.{op.variant_test_name}"
        return s.replace('.', '_')

    remap_opinfo = {create_normalized_name(op): (op.name, op.variant_test_name) for op in op_db}
    print("symbolic_tensor_failures = {")
    for failure, reason in failures:
        print(f" xfail{remap_opinfo[failure]}, # {reason}")
    print("}")
def copy_func(f):
    """Return a distinct function object equivalent to the plain function ``f``.

    Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard).
    The copy shares ``f``'s code, globals, positional defaults and closure,
    receives ``f``'s metadata via ``functools.update_wrapper`` and gets
    ``f``'s keyword-only defaults. Attributes set on the copy do not
    appear on ``f``.
    """
    clone = types.FunctionType(
        f.__code__,
        f.__globals__,
        name=f.__name__,
        argdefs=f.__defaults__,
        closure=f.__closure__,
    )
    clone = functools.update_wrapper(clone, f)
    # FunctionType() has no parameter for keyword-only defaults.
    clone.__kwdefaults__ = f.__kwdefaults__
    return clone
# Copied from functorch
def xfail(op_name, variant_name='', *, device_type=None, dtypes=None):
    """Build a spec marking an OpInfo test as an expected failure.

    Returns ``(op_name, variant_name, device_type, dtypes, True)``; the
    trailing flag tells ``skipOps`` to use ``unittest.expectedFailure``.
    """
    expected_failure = True
    return op_name, variant_name, device_type, dtypes, expected_failure


def skip(op_name, variant_name='', *, device_type=None, dtypes=None):
    """Build a spec marking an OpInfo test to be skipped outright.

    Returns ``(op_name, variant_name, device_type, dtypes, False)``; the
    trailing flag tells ``skipOps`` to use ``unittest.skip``.
    """
    expected_failure = False
    return op_name, variant_name, device_type, dtypes, expected_failure
def skipOps(test_case_name, base_test_name, to_skip):
all_opinfos = op_db
for xfail in to_skip:
op_name, variant_name, device_type, dtypes, expected_failure = xfail
matching_opinfos = [o for o in all_opinfos
if o.name == op_name and o.variant_test_name == variant_name]
assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {xfail}"
for opinfo in matching_opinfos:
decorators = list(opinfo.decorators)
if expected_failure:
decorator = DecorateInfo(unittest.expectedFailure,
test_case_name, base_test_name,
device_type=device_type, dtypes=dtypes)
decorators.append(decorator)
else:
decorator = DecorateInfo(unittest.skip("Skipped!"),
test_case_name, base_test_name,
device_type=device_type, dtypes=dtypes)
decorators.append(decorator)
opinfo.decorators = tuple(decorators)
# This decorator doesn't modify fn in any way
def wrapped(fn):
return fn
return wrapped
# torchvision is optional; tests that need it check USE_TORCHVISION.
USE_TORCHVISION = False
try:
    import torchvision
except ImportError:
    warnings.warn("Couldn't import torchvision. Some of our tests use it, try "
                  "to install it with commands from pytorch.org, post-fixed with "
                  "`--no-deps` to avoid overwriting the pytorch installation",
                  UserWarning)
else:
    USE_TORCHVISION = True
def _create_new_input(x):
if not isinstance(x, torch.Tensor):
return x
if x.dtype != torch.float:
return x + 1
if x.is_leaf:
return torch.rand_like(x, requires_grad=x.requires_grad)
else:
return torch.rand_like(x)
"""
Delays a cos being executed on the unwraptensor until its used. Simulates a CommTensor used
"""
class UnwrapTensor(torch.Tensor):
@staticmethod
def __new__(cls, tensor: torch.Tensor):
r = torch.Tensor._make_wrapper_subclass(
cls,
tensor.size(),
dtype=tensor.dtype,
device=tensor.device,
layout=tensor.layout,
requires_grad=tensor.requires_grad,
)
r._tensor = tensor
return r
def __repr__(self):
# TODO: consider all_gather the local tensors for better debugging
return f"UnwrapTensor({self._tensor})"
__torch_function__ = _disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def unwrap(e):
ret = e
if isinstance(e, UnwrapTensor):
ret = e._tensor.cos()
return ret
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
return func(*args, **kwargs)
class TestGenericProxyTensor(TestCase):
    # Shared make_fx tests. This class never sets ``tracing_mode`` itself;
    # the subclasses set it ("real", "fake" or "symbolic") and the tests read
    # it via ``self.tracing_mode``.

    # WARNING: if any of your inputs are index tensors, DO NOT use this
    # function
    # (_create_new_input shifts non-float tensors by +1, which can push index
    # values out of range.)
    def _test(self, f, inps):
        # Trace f, then compare the traced graph against eager f on *fresh*
        # inputs, so values accidentally baked into the graph are caught.
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps)
        new_inps = tree_map(_create_new_input, inps)
        r1 = fx_f(*new_inps)
        r2 = f(*new_inps)
        self.assertEqual(r1, r2)

    def test_make_fx_simple(self):
        def f(x):
            return torch.sin(x)
        self._test(f, (torch.randn(3),))

    def test_scalar_device(self, device='cpu'):
        # Mixes a tensor on `device` with a 0-dim CPU tensor.
        def f(a, b):
            return a + b
        self._test(f, [torch.randn(3, device=device), torch.tensor(5)])

    def test_isolated_graphmodule(self):
        def is_any_sum(gm):
            return any(node.target == torch.ops.aten.sum.default for node in gm.graph.nodes)

        def is_any_digamma(gm):
            return any(node.target == torch.ops.aten.digamma.default for node in gm.graph.nodes)

        def is_any_sigmoid(gm):
            return any(node.target == torch.ops.aten.sigmoid.default for node in gm.graph.nodes)

        def inner(x):
            return torch.sum(x)

        def f(x):
            gm = get_isolated_graphmodule(inner, (x,), {})
            self.assertTrue(is_any_sum(gm))
            return x + torch.randn(x.shape)

        # get_isolated_graphmodule uses make_fx internally that shouldn't be traced
        # by the outer make_fx call
        traced = make_fx(f)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))

        # When factory functions are used, they should not be traced
        # by the outer make_fx call
        def inner_with_factory():
            val = torch.tensor(float(1))
            val.add_(2)
            return torch.full((10, 10), val).sum()

        def f1(x):
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2(x):
            gm = get_isolated_graphmodule(f1, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify nested make_fx calls don't make factory functions to be leaked
        # into the outer graph
        def f2(x):
            gm = make_fx(f1)(x)
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        # Unlike get_isolated_graphmodule, a nested make_fx is itself traced
        # by the outer call, so sigmoid shows up in the outer graph.
        self.assertTrue(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify interaction with non-ProxyTensor modes
        from torch.testing._internal.logging_tensor import LoggingTensorMode

        def f1_logging(x):
            with LoggingTensorMode():
                gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2_logging(x):
            with LoggingTensorMode(), LoggingTensorMode():
                gm = get_isolated_graphmodule(f1_logging, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2_logging)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify interaction with another tensor subclass
        # This case currently doesn't work and should raise an error
        # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068
        from torch.testing._internal.logging_tensor import LoggingTensor

        def f1_logging_tensor(x):
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2_logging_tensor(x):
            x = LoggingTensor(x)
            gm = get_isolated_graphmodule(f1_logging_tensor, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2_logging_tensor)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))  # this fails, sigmoid is traced with LoggingTensor
        self.assertTrue(is_any_digamma(traced))

    def test_proxy_tensor_mode_with_decomp_table_preserves_proxy(self):
        def f(x):
            y = x.new_zeros(x.size())
            y.copy_(x)
            return y

        def _new_zeros_decomp(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
            return torch.zeros(size, dtype=inp.dtype, device=inp.device)

        factory_func_decomp = {torch.ops.aten.new_zeros.default: _new_zeros_decomp}

        # When new_zeros() decomposes into torch.zero(), we expect ProxyTensorMode
        # to still be (re-entrantly) enabled, so that the `torch.zero()` call
        # returns a ProxyTensor.
        out = make_fx(f, decomposition_table=factory_func_decomp)(torch.ones(2))
        self.assertExpectedInline(out.code, """\
def forward(self, x_1):
zeros = torch.ops.aten.zeros.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
copy_ = torch.ops.aten.copy_.default(zeros, x_1); zeros = x_1 = None
return copy_
""")

    def test_make_fx_reentrant_dispatch(self):
        def f(x):
            return torch.ops.aten.norm.Scalar(x, 2.0)

        def norm_decomp(x, p=2.0):
            if p != 2.0:
                raise RuntimeError("can't handle with p != 2")
            return torch.sqrt(torch.sum(torch.square(x)))

        decomp = {torch.ops.aten.norm.Scalar: norm_decomp}

        traced = make_fx(f, decomposition_table=decomp, tracing_mode=self.tracing_mode)(torch.rand(3))

        # Neither norm (decomposed away) nor square (decomposed re-entrantly)
        # should survive as a node target.
        for n in traced.graph.nodes:
            self.assertTrue("square" not in str(n.target))
            self.assertTrue("norm" not in str(n.target))

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_resnet18_backward_trace(self):
        mod = torchvision.models.resnet18()

        # An old version of this test called the module directly. This works
        # for tracing_mode == "real", but for fake tensors, we also have to
        # ensure that the parameters and buffers get wrapped in fake tensors
        # because free fake tensors are not supported. Fortunately stateless
        # does precisely this for us.
        def f(x, params, buffers):
            for p in params.values():
                p.grad = None
            loss = stateless.functional_call(mod, {**params, **buffers}, (x,)).sum()
            # I could have done this with the functional API, but there is
            # plenty of exercising this; I want to show mutating API still
            # works
            loss.backward()
            return [p.grad for p in params.values()]

        inp = torch.randn(3, 3, 250, 250)
        self._test(f, [inp, dict(mod.named_parameters()), dict(mod.named_buffers())])

    def test_varargs(self):
        def f(*args):
            return sum(args)

        self._test(f, [torch.randn(2), torch.randn(2)])

    def test_proxy_tensor(self):
        # Both autograd entry points (functional grad and .backward()) must trace.
        def f_grad(x):
            val = x.cos().cos().sum()
            return torch.autograd.grad(val, x)

        def f_backward(x):
            val = x.cos().cos().sum()
            val.backward()
            return x.grad

        for f in [f_grad, f_backward]:
            self._test(f, [torch.randn(3, requires_grad=True)])

    def test_inplace_metadata(self):
        # An in-place view op must update the traced tensor's shape metadata.
        def f(x):
            x = x.clone()
            x.unsqueeze_(-1)
            assert x.shape[-1] == 1
            return x

        self._test(f, [torch.randn(5)])

    def test_mode_tracing_factory_function(self):
        def f(x):
            return x + torch.randn(x.shape)

        # default behavior should trace factory functions
        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
        self.assertTrue(
            any(
                node.target == aten.randn.default
                for node in traced.graph.nodes
            )
        )

    def test_make_fx_overloads(self):
        # Every call_function node should target a specific OpOverload.
        def f(x):
            return x.cos() + torch.randn(x.shape)

        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))

        self.assertTrue(all([isinstance(node.target, torch._ops.OpOverload)
                             for node in traced.graph.nodes if node.op == 'call_function']))

    def test_tensor_constants(self):
        def f():
            val = torch.tensor(float('inf'))
            return torch.full((100, 100), val)

        self._test(f, [])

    def test_allclose(self):
        # allclose returns a Python bool, so tracing must reject it.
        def f(a, b):
            return torch.allclose(a, b)

        self.assertRaisesRegex(
            RuntimeError, "data-dependent",
            lambda: make_fx(f, tracing_mode=self.tracing_mode)(
                torch.zeros(3), torch.zeros(3)
            )
        )

    def test_constant_proxy_tensor_mut(self):
        def f():
            val = torch.tensor(float(1))
            val.add_(2)
            return torch.full((100, 100), val)

        g = make_fx(f, tracing_mode=self.tracing_mode)()
        self.assertEqual(g(), f())
        # In case we mutated shared state in the g graph!
        self.assertEqual(g(), f())

    def test_constant_unbind(self):
        def f():
            val = torch.tensor([2])
            r, = torch.unbind(val, 0)
            return r.item()

        g = make_fx(f, tracing_mode=self.tracing_mode)()
        self.assertEqual(g(), f())

    def test_constant_blowup(self):
        def f():
            val = torch.tensor([2])
            blowup = val.repeat(1000)
            return blowup.sum().item()

        self.assertRaisesRegex(
            RuntimeError, "data-dependent",
            lambda: make_fx(f, tracing_mode=self.tracing_mode)()
        )

    def test_constant_random(self):
        def f():
            val = torch.tensor([2.0])
            val.normal_()
            return val.item()

        self.assertRaisesRegex(
            RuntimeError, "data-dependent",
            lambda: make_fx(f, tracing_mode=self.tracing_mode)()
        )

    def test_decomposition_interpreter(self):
        def fn(x):
            return torch.nn.functional.silu(x)

        x = torch.rand((4, 4))
        fx_module = make_fx(fn, tracing_mode=self.tracing_mode, decomposition_table=None)(x)

        # Without a decomposition table, silu stays a single node.
        found_silu = False
        for n in fx_module.graph.nodes:
            if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default:
                found_silu = True

        self.assertTrue(found_silu)

        # Re-run the graph through DecompositionInterpreter with only silu's
        # decomposition; the resulting graph must not contain silu.
        new_graph = torch.fx.Graph()
        silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]}
        DecompositionInterpreter(
            fx_module,
            new_graph=new_graph,
            decomposition_table=silu_decomp_table,
        ).run(x)

        decomposed_module = torch.fx.GraphModule(fx_module, new_graph)

        for n in decomposed_module.graph.nodes:
            self.assertTrue(n.target != torch.ops.aten.silu)
            self.assertTrue(n.target != torch.ops.aten.silu.default)

        self.assertEqual(fx_module(x), decomposed_module(x))

    def test_make_fx_model_fwd_bwd(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x).relu()

        model = Foo()

        def f(x, params):
            out = stateless.functional_call(model, params, x).sum()
            out.backward()
            return list(params.values())
        input = torch.randn(3, 5, requires_grad=True)
        params = dict(model.named_parameters())
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params)
        # fx may change the order of parameters in the returned list, so each
        # traced output is accepted if it matches either eager output.
        self.assertTrue(
            torch.allclose(fx_f(input, params)[0], f(input, params)[0])
            or
            torch.allclose(fx_f(input, params)[0], f(input, params)[1])
        )
        self.assertTrue(
            torch.allclose(fx_f(input, params)[1], f(input, params)[0])
            or
            torch.allclose(fx_f(input, params)[1], f(input, params)[1])
        )

    def test_make_fx_model_double_param(self):
        class Emformer(torch.nn.Module):
            def __init__(
                self,
                input_dim: int = 256,
            ) -> None:
                super().__init__()
                self.layer_norm = torch.nn.LayerNorm(input_dim)

            # `mod_self` is the module; `self` is the enclosing TestCase.
            def forward(mod_self, x):  # noqa: B902
                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
                y = mod_self.layer_norm(x)
                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
                z = mod_self.layer_norm(y)
                return z

        gm = make_fx(Emformer())(torch.randn(16, 1, 256))
        # Using the same layer twice must yield exactly two distinct op targets.
        ops = set([n.target for n in gm.graph.nodes if n.op == 'call_function'])
        self.assertEqual(len(ops), 2)

    def test_make_fx_model_fwd_bwd_wgtupdate(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x).relu()

        model = Foo()

        # One forward/backward pass followed by a plain SGD-style update.
        def f(args, params, buffers):
            for p in params.values():
                p.grad = None
            if not isinstance(args, Iterable):
                args = [args]
            params_and_buffers = {**params, **buffers}
            out = stateless.functional_call(model, params_and_buffers, args)
            out.sum().backward()
            return [p - 1e-4 * p.grad for p in params.values()]

        input = torch.randn(3, 5, requires_grad=True)
        params = dict(model.named_parameters())
        buffers = dict(model.named_buffers())
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params, buffers)
        # fx may change the order of parameters in the returned list, so each
        # traced output is accepted if it matches either eager output.
        # also there is a numerical difference in results so changing atol from 1e-08 to 1e-03
        self.assertTrue(
            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03)
            or
            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03)
        )
        self.assertTrue(
            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03)
            or
            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
        )

    def test_trace_subclasses(self):
        # Tracing through UnwrapTensor's __torch_dispatch__ (which inserts a cos).
        def f1(x):
            x = UnwrapTensor(x)
            y = x * 2
            return y

        def f2(x):
            wrapped = UnwrapTensor(x)
            y = x * wrapped
            return y

        inp = [torch.randn(5)]
        self._test(f1, inp)
        self._test(f2, inp)

    def test_partial_decomp(self):
        def f(a, b, c):
            x = torch.addmm(a, b, c)
            y = torch.addmm(a, b, c, beta=2, alpha=1)
            return x + y
        inps = [torch.randn(5, 5), torch.randn(5, 5), torch.randn(5, 5)]
        fx_g = make_fx(f)(*inps)

        # Returning NotImplemented declines to decompose this call, so only
        # the beta=2 addmm is decomposed.
        def addmm(a, b, c, beta=1, alpha=1):
            if beta == 1 and alpha == 1:
                return NotImplemented
            return beta * a + alpha * (b @ c)

        decomposed_fx = make_fx(f, {aten.addmm.default: addmm})(*inps)

        self.assertEqual(fx_g(*inps), decomposed_fx(*inps))
        self.assertEqual(len([n for n in fx_g.graph.nodes if n.target == aten.addmm.default]), 2)
        self.assertEqual(len([n for n in decomposed_fx.graph.nodes if n.target == aten.addmm.default]), 1)

    def test_decomp_of_capture(self):
        # `val` is a captured (non-input) tensor; its t() must also be decomposed.
        val = torch.randn(5)

        def f(x):
            return x.t() + val.t()

        def nop(x):
            return x.cos()

        traced = make_fx(f, decomposition_table={torch.ops.aten.t.default: nop})(torch.randn(5))
        self.assertEqual(len([n for n in traced.graph.nodes if n.target == torch.ops.aten.t.default]), 0)

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_amp_cache(self):
        layer = torch.nn.Conv2d(3, 3, 3).cuda()

        def f(x, w):
            return torch.nn.functional.conv2d(x, w, stride=layer.stride)

        inp = torch.randn(4, 3, 10, 10, device='cuda')
        # Two traces inside the same autocast region must produce the same
        # graph structure.
        with torch.autocast('cuda'):
            out_graph = make_fx(f)(inp, layer.weight).graph
            out_graph2 = make_fx(f)(inp, layer.weight).graph

        self.assertEqual(len(out_graph.nodes), len(out_graph2.nodes))
        for a, b in zip(out_graph.nodes, out_graph2.nodes):
            self.assertEqual(a.op, b.op)

    def test_has_proxy(self):
        foo = torch.randn(5)

        def f(x):
            self.assertFalse(has_proxy(foo))
            self.assertTrue(has_proxy(x))
            y = x.cos()
            self.assertTrue(has_proxy(y))
            return y

        self.assertFalse(has_proxy(torch.randn(5)))
        make_fx(f)(torch.randn(5))

    def test_strides(self):
        # Contiguity queries during tracing must reflect real stride metadata.
        def f(x):
            self.assertTrue(x.is_contiguous())
            self.assertFalse(x.is_contiguous(memory_format=torch.channels_last))
            x = x.permute(0, 3, 1, 2)
            self.assertFalse(x.is_contiguous())
            self.assertTrue(x.is_contiguous(memory_format=torch.channels_last))
            return x
        make_fx(f)(torch.randn(2, 3, 4, 5))

        def f(x):
            self.assertTrue(x.is_contiguous())
            y = x[:, 1]
            self.assertFalse(y.is_contiguous())
            y = x[:, ::2]
            self.assertFalse(y.is_contiguous())
            return x.cos()

        make_fx(f)(torch.randn(2, 3, 4, 5))
# Run the generic make_fx tests with tracing_mode="real".
class TestGenericProxyTensorReal(TestGenericProxyTensor):
    tracing_mode = "real"


# Run the generic make_fx tests with tracing_mode="fake".
class TestGenericProxyTensorFake(TestGenericProxyTensor):
    tracing_mode = "fake"
def xfail_inherited_tests(tests):
    """
    Given a list of test names which are defined by a superclass of the
    class this decorates, mark them as expected failure. This is useful
    if you are doing poor man's parameterized tests by subclassing a generic
    test class.

    The decorator installs a copied, ``unittest.expectedFailure``-wrapped
    method on the decorated class under each name and returns the class.
    """
    def deco(cls):
        for name in tests:
            inherited = getattr(cls, name)
            # NB: expectedFailure operates by mutating the method in question,
            # which is why you have to copy the function first
            marked = unittest.expectedFailure(copy_func(inherited))
            setattr(cls, name, marked)
        return cls
    return deco
# Run the generic make_fx tests with tracing_mode="symbolic". The listed
# inherited tests are marked as expected failures for this mode only.
@skipIfNoSympy
@xfail_inherited_tests([
    "test_inplace_metadata",
    "test_mode_tracing_factory_function",
    "test_make_fx_overloads",
    "test_make_fx_model_fwd_bwd_wgtupdate",
    "test_make_fx_model_fwd_bwd",
    "test_proxy_tensor",
    "test_resnet18_backward_trace",
    "test_trace_subclasses",
])
class TestGenericProxyTensorSymbolic(TestGenericProxyTensor):
    tracing_mode = "symbolic"


# Remove the base class from the module namespace, presumably so the test
# runner does not collect it directly (it defines no `tracing_mode` itself).
del TestGenericProxyTensor
# Tests specific to tracing_mode="real"; currently empty.
class TestRealProxyTensor(TestCase):
    pass
# Tests specific to tracing_mode="fake".
class TestFakeProxyTensor(TestCase):
    def test_issue82547(self):
        x = nn.Parameter(torch.randn(3, 3))

        # `f` closes over `x`; the rebinding of `x` below changes what it sees.
        def f():
            return torch.ops.aten.t.default(x)
        self.assertRaisesRegex(Exception, "non-Fake Tensor", lambda: make_fx(f, tracing_mode="fake")())

        class A(torch.Tensor):
            pass

        # Now `f` captures an instance of a plain torch.Tensor subclass.
        x = A(torch.randn(3, 3))
        self.assertRaisesRegex(TypeError, "no implementation found", lambda: make_fx(f, tracing_mode="fake")())

    def test_use_fake_and_tensor(self):
        # A real constant tensor created inside f must combine with fake inputs.
        def f(x, y):
            z = torch.tensor([2.0, 3.0])
            return x + y + z

        g = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2))
        x, y = torch.randn(2), torch.randn(2)
        self.assertEqual(g(x, y), f(x, y))

    def test_alias(self):
        def f(x):
            return torch.ops.aten.alias(x)

        r = str(make_fx(f, tracing_mode="fake")(torch.randn(2)).code).strip()
        # NB: this should not have a detach call
        self.assertExpectedInline(r, """\
def forward(self, x_1):
alias = torch.ops.aten.alias.default(x_1); x_1 = None
return alias""")
def _get_node(fx_g, cond):
for n in fx_g.graph.nodes:
if cond(n):
return n
raise AssertionError
def _get_free_symbols(shape_env):
vars = tuple(shape_env.var_to_val.keys())
return len([var for var in vars if var not in shape_env.replacements])
def _trace(f, *args):
inps = [torch.randn(arg) for arg in args]
return make_fx(f, tracing_mode="symbolic")(*inps)
# TODO: Need to test the guards themselves specifically as well
@skipIfNoSympy
class TestSymbolicTracing(TestCase):
def _test_dynamic(self, fn, trace_inputs, test_inputs, assert_eq=True):
"""
Tests fn traced with trace_inputs against test_inputs
Also returns shape env
"""
trace_inputs = [torch.randn(shape) for shape in trace_inputs]
traced_f = make_fx(fn, tracing_mode="symbolic")(*trace_inputs)
for input in test_inputs:
input = [torch.randn(shape) for shape in input]
rx, ry = traced_f(*input), fn(*input)
if assert_eq:
self.assertEqual(rx, ry)
return traced_f.shape_env
def test_unary(self):
def f(x):
assert x.shape[0] < 20
return x.cos()
test_inputs = []
test_inputs.append([(2, 5)])
test_inputs.append([(6, 8)])
shape_env = self._test_dynamic(f, [(3, 4)], test_inputs)
self.assertTrue(shape_env.evaluate_guards_for_args(torch.randn(4, 5)))
self.assertFalse(shape_env.evaluate_guards_for_args(torch.randn(25, 5)))
# TODO: There should eventually be guards for contiguity, but they're
# not currently being done yet
assert len(shape_env.guards) == 1, "\n" + shape_env.format_guards()
def test_binary_broadcast(self):
def f(a, b):
c = a * b
return c
test_inputs = []
test_inputs.append([(1, 5), (3, 1)])
test_inputs.append([(1, 4), (4, 1)])
shape_env = self._test_dynamic(f, [(1, 2), (3, 1)], test_inputs)
assert len(shape_env.guards) == 0
def test_multiply_shape(self):
def f(a):
return torch.empty(a.shape[0] * 2)
r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
self.assertExpectedInline(r, """\
def forward(self, a_1):
sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None
mul = sym_size * 2; sym_size = None
empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None
detach = torch.ops.aten.detach.default(empty); empty = None
return detach""")
def test_symint_to_tensor(self):
def f(a):
return a / a.shape[0]
r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
self.assertExpectedInline(r, """\
def forward(self, a_1):
sym_size = torch.ops.aten.sym_size(a_1, 0)
div = torch.ops.aten.div.Tensor(a_1, sym_size); a_1 = sym_size = None
return div""")
r = str(make_fx(f, tracing_mode="symbolic", decomposition_table=decomposition_table)(torch.empty(4)).code).strip()
self.assertExpectedInline(r, """\
def forward(self, a_1):
sym_size = torch.ops.aten.sym_size(a_1, 0)
sym_float = torch.fx.experimental.symbolic_shapes.sym_float(sym_size); sym_size = None
div = torch.ops.prims.div.default(a_1, sym_float); a_1 = sym_float = None
return div""")
def test_cat(self):
def f(a, b):
val = torch.mul(a, b)
out = torch.cat([val, val])
if out.shape[0] * out.shape[1] > 20:
out = out.cos()
return out
test_inputs = []
test_inputs.append([(1, 5), (6, 1)])
test_inputs.append([(1, 4), (3, 1)])
shape_env = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs)
self.assertTrue(shape_env.evaluate_guards_for_args(torch.randn(1, 10), torch.randn(6, 1)))
self.assertFalse(shape_env.evaluate_guards_for_args(torch.randn(1, 2), torch.randn(4, 1)))
assert len(shape_env.guards) == 1
def test_new_empty(self):
def f(a, b):
return a.new_empty(b.shape[0], b.shape[1] * 2)
self._test_dynamic(f, [(2, 4), (4, 5)], [[(2, 3), (5, 7)], [(3, 7), (9, 3)]], assert_eq=False)
def test_size_with_tensor(self):
def f(tensor):
max_size = torch.tensor([800, 1216], dtype=torch.int64)
batch_shape = [2] + list(tensor.shape[:-2]) + list(max_size)
return tensor.new_empty(batch_shape)
a = torch.randn(3, 800, 1199)
self.assertRaisesRegex(
RuntimeError, "data-dependent", lambda: make_fx(f, tracing_mode="symbolic")(a)
)
def test_expand(self):
def f(a):
b = torch.mul(a, a)
c = b.expand(a.shape)
return c
self._test_dynamic(f, [(3,)], [[(3,)], [(4,)], [(2,)]])
self._test_dynamic(f, [(5, 1)], [[(4, 1)], [(3, 1)], [(6, 1)]])
def test_symbolic_meta(self):
def f(a, b):
d = a.new_empty(a.shape[0] + b.shape[0])
return d
fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4))
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default)
meta_d = _get_node(fx_g, lambda x: x.target == operator.add)
self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj() == meta_d.meta['val'].expr)
def test_return_symint(self):
def f(x):
return x.shape[0], x.cos(), x.shape[0] / 5
self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])
def f(x):
return x.shape
self._test_dynamic(f, [(5, 3)], [[(4, 6)]])
def _assert_no_guards(self, fx_g, free_symbols):
assert _get_free_symbols(fx_g.shape_env) == free_symbols, fx_g.shape_env.var_to_val
assert len(fx_g.shape_env.get_nontrivial_guards()) == 0, fx_g.shape_env.format_guards()
def test_guards_equal(self):
    """Shape relationships implied by ops must be captured without nontrivial guards."""
    def mul(a, b):
        return a * b

    # NB: Numbers are carefully chosen to avoid duck shaping from applying
    for shapes, expected_syms in [
        (((5, 6), (5, 6)), 2),
        (((5, 6, 7), (5, 6, 7)), 3),
        (((5, 1), (1, 6)), 2),
    ]:
        self._assert_no_guards(_trace(mul, *shapes), expected_syms)

    def add_and_cat(a, b, c, d):
        a = a + b
        cat = torch.cat([c, d])
        return a + cat

    self._assert_no_guards(_trace(add_and_cat, 7, 7, 4, 3), 2)

    def chained_cats(a, b, c, d, e):
        vals = [a, b, c, d, e]
        x = a
        for cur, nxt in zip(vals, vals[1:]):
            x = torch.cat([x, cur]) + nxt
        return x

    self._assert_no_guards(_trace(chained_cats, 2, 4, 8, 16, 32), 1)

    def view_then_add(a, b):
        return a.view(b.shape[0]) + b.sum()

    for shapes, expected_syms in [
        (((4, 2), 8), 2),
        (((4, 2), (8, 5)), 3),
        (((2, 3, 4), 24), 3),
    ]:
        self._assert_no_guards(_trace(view_then_add, *shapes), expected_syms)
def test_nonidentity_transitive_guards(self):
    """Doubling relations chained across five inputs collapse to one symbol, guard-free."""
    def f(a, b, c, d, e):
        vals = [a, b, c, d, e]
        # Each input (except the last) concatenated with itself: length doubles.
        doubled = [torch.cat([v, v]) for v in vals[:-1]]
        # Pair each doubled tensor with the next input, emitted in reverse order.
        return [x + y for x, y in reversed(list(zip(doubled, vals[1:])))]

    traced = _trace(f, 2, 4, 8, 16, 32)
    self._assert_no_guards(traced, 1)
# OpInfo skip/xfail entries for make_fx tracing. TestProxyTensorOpInfo applies
# this set to the "real" test and unions it into the "fake" and "symbolic" ones.
make_fx_failures = {
# unknown failure reason
xfail('allclose'),
xfail('equal'),
# empty*: outputs are uninitialized memory, so (presumably) eager and traced
# values cannot be meaningfully compared
skip('new_empty'),
skip('empty_like'),
skip('empty'),
# flaky
skip('linalg.lstsq', 'grad_oriented'),
skip('nn.functional.max_unpool1d', '', device_type='cpu'),
skip('nn.functional.max_unpool2d', '', device_type='cpu'),
skip('nn.functional.max_unpool3d', '', device_type='cpu'),
skip('linalg.lstsq'), # flaky, probably just a precision issue
# data-dependent control flow
xfail('cov'),
xfail('istft'),
xfail('nn.functional.gaussian_nll_loss'),
xfail('tensor_split'),
xfail('corrcoef'),
xfail('quantile'),
xfail('nanquantile'),
xfail('narrow'),
# Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
xfail('sparse.sampled_addmm'),
# proxy tensor doesn't support sparse correctly right now
skip('to_sparse'),
# segfaults
skip('block_diag'),
}
# Additional entries for fake-tensor tracing; TestProxyTensorOpInfo unions
# these into both the "fake" and the "symbolic" exhaustive tests.
fake_tensor_failures = {
# FakeTensor fallback doesn't work
xfail('segment_reduce', 'lengths'),
xfail('multinomial'),
xfail('cholesky'),
xfail('cholesky_inverse'),
# ASAN failures due to divide by 0
skip('nn.functional.nll_loss'),
}
# Additional entries used only by the "symbolic" exhaustive test, on top of
# make_fx_failures and fake_tensor_failures. Trailing comments record the
# failure reason; those ending in "..." are truncated error messages.
symbolic_tensor_failures = {
# Needs complex-value support
xfail('polar'),
xfail('linalg.eig'),
xfail('linalg.eigvals'),
# Remaining entries: symbolic-shape gaps (missing symbolic meta functions /
# decompositions, or other errors) as noted per entry.
skip('masked.logsumexp', ''), # Tensors of type TensorImpl do not have numel
xfail('__getitem__', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('masked.amax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('masked.amin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('masked.argmax', ''), # aten.argmax.default - couldn't find symbolic meta function/decomposition
xfail('masked.argmin', ''), # aten.argmin.default - couldn't find symbolic meta function/decomposition
xfail('masked.cumprod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('masked.cumsum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('masked.log_softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('masked.logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposition
xfail('masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, ...
xfail('masked.median', ''), # aten.nanmedian.dim - couldn't find symbolic meta function/decomposition
xfail('masked.norm', ''), # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition
xfail('masked.prod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('masked.softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('masked.softmin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('masked.std', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d...
xfail('masked.sum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d...
xfail('addmv', ''), # aten.addmv.default - couldn't find symbolic meta function/decomposition
xfail('addr', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('aminmax', ''), # aten.aminmax.default - couldn't find symbolic meta function/decomposition
xfail('argmax', ''), # aten.argmax.default - couldn't find symbolic meta function/decomposition
xfail('argmin', ''), # aten.argmin.default - couldn't find symbolic meta function/decomposition
xfail('argsort', ''), # aten.sort.default - couldn't find symbolic meta function/decomposition
xfail('argwhere', ''), # aten.nonzero.default - couldn't find symbolic meta function/decomposition
xfail('as_strided_scatter', ''), # aten.as_strided_scatter.default - couldn't find symbolic meta function/decomposition
xfail('baddbmm', ''), # aten.baddbmm.default - couldn't find symbolic meta function/decomposition
xfail('bernoulli', ''), # aten.bernoulli.default - couldn't find symbolic meta function/decomposition
xfail('bucketize', ''), # aten.bucketize.Tensor - couldn't find symbolic meta function/decomposition
xfail('cartesian_prod', ''), # Tensors of type TensorImpl do not have numel
xfail('cdist', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('cholesky_solve', ''), # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
xfail('chunk', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('column_stack', ''), # Tensors of type TensorImpl do not have numel
xfail('constant_pad_nd', ''), # aten.fill.Scalar - couldn't find symbolic meta function/decomposition
xfail('count_nonzero', ''), # Could not run 'aten::count_nonzero.dim_IntList' with arguments from the 'Meta' ba...
xfail('cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
xfail('cummax', ''), # aten.cummax.default - couldn't find symbolic meta function/decomposition
xfail('cummin', ''), # aten.cummin.default - couldn't find symbolic meta function/decomposition
xfail('cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition
xfail('cumulative_trapezoid', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition
xfail('deg2rad', ''), # aten.deg2rad.default - couldn't find symbolic meta function/decomposition
xfail('diagonal_scatter', ''), # aten.diagonal_scatter.default - couldn't find symbolic meta function/decomposition
xfail('diff', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition
xfail('dist', ''), # aten.dist.default - couldn't find symbolic meta function/decomposition
xfail('dsplit', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition
xfail('einsum', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.fft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.fft', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.fftn', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.fftshift', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.hfft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.hfft', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('fft.hfftn', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.ifft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.ifft', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.ifftn', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.ifftshift', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.ihfft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.ihfft', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.ihfftn', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.irfft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.irfft', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('fft.irfftn', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.rfft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.rfft', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('fft.rfftn', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('unflatten', ''), # RuntimeError: Trying to call aten.size on a tensor with symbolic shapes...
xfail('frexp', ''), # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition
xfail('gather', ''), # aten.gather.default - couldn't find symbolic meta function/decomposition
xfail('geqrf', ''), # aten.geqrf.default - couldn't find symbolic meta function/decomposition
xfail('gradient', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('histc', ''), # Could not run 'aten::histc' with arguments from the 'Meta' backend. This could be because...
xfail('histogram', ''), # Could not run 'aten::histogram.bin_ct' with arguments from the 'Meta' backend. This c...
xfail('histogramdd', ''), # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition
xfail('hsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('i0', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
xfail('index_add', ''), # Float
xfail('index_copy', ''), # Expected a long tensor for index, but got Float
xfail('index_fill', ''), # aten.index_fill.int_Scalar - couldn't find symbolic meta function/decomposition
xfail('index_reduce', ''), # Float
xfail('inner', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('isclose', ''), # The underlying op of 'aten.stride' has no overload name '_schema'
xfail('isin', ''), # aten.isin.Tensor_Tensor - couldn't find symbolic meta function/decomposition
xfail('kron', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('kthvalue', ''), # aten.kthvalue.default - couldn't find symbolic meta function/decomposition
xfail('lerp', ''), # aten.lerp.Scalar - couldn't find symbolic meta function/decomposition
xfail('linalg.cholesky', ''), # aten.linalg_cholesky_ex.default - couldn't find symbolic meta function/decomposition
xfail('linalg.cholesky_ex', ''), # aten.linalg_cholesky_ex.default - couldn't find symbolic meta function/decomposition
xfail('linalg.cond', ''), # Tensors of type TensorImpl do not have numel
xfail('linalg.cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
xfail('linalg.det', ''), # aten._linalg_det.default - couldn't find symbolic meta function/decomposition
xfail('linalg.det', 'singular'), # aten._linalg_det.default - couldn't find symbolic meta function/decomposition
xfail('linalg.eigh', ''), # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition
xfail('linalg.eigvalsh', ''), # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition
xfail('linalg.householder_product', ''), # aten.linalg_householder_product.default - couldn't find symbolic meta funct...
xfail('linalg.inv', ''), # aten.linalg_inv_ex.default - couldn't find symbolic meta function/decomposition
xfail('linalg.inv_ex', ''), # aten.linalg_inv_ex.default - couldn't find symbolic meta function/decomposition
xfail('linalg.ldl_factor', ''), # aten.linalg_ldl_factor_ex.default - couldn't find symbolic meta function/decomposition
xfail('linalg.ldl_factor_ex', ''), # aten.linalg_ldl_factor_ex.default - couldn't find symbolic meta function/decompos...
xfail('linalg.ldl_solve', ''), # aten.linalg_ldl_solve.default - couldn't find symbolic meta function/decomposition
xfail('linalg.lu', ''), # aten.linalg_lu.default - couldn't find symbolic meta function/decomposition
xfail('linalg.lu_factor', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
xfail('linalg.lu_factor_ex', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
xfail('linalg.lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition
xfail('linalg.matrix_power'), # RuntimeError: Trying to call aten.size on a tensor with symbolic shape
xfail('linalg.matrix_norm', ''), # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition
xfail('linalg.matrix_rank', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('linalg.matrix_rank', 'hermitian'), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('linalg.multi_dot', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('linalg.norm', ''), # TensorImpl do not have numel
xfail('linalg.norm', 'subgradients_at_zero'), # TensorImpl do not have numel
xfail('linalg.pinv', ''), # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decomposition
xfail('linalg.pinv', 'singular'), # aten.linalg_cholesky_ex.default - couldn't find symbolic meta function/decomposition
xfail('linalg.pinv', 'hermitian'), # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decompo...
xfail('linalg.qr', ''), # aten.linalg_qr.default - couldn't find symbolic meta function/decomposition
xfail('linalg.slogdet', ''), # aten._linalg_slogdet.default - couldn't find symbolic meta function/decomposition
xfail('linalg.solve', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition
xfail('linalg.solve_ex', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition
xfail('linalg.solve_triangular', ''), # aten.linalg_solve_triangular.default - couldn't find symbolic meta function/de...
xfail('linalg.svd', ''), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition
xfail('linalg.svdvals', ''), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition
xfail('linalg.tensorinv', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('linalg.tensorsolve', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('linalg.vander', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('logaddexp2', ''), # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition
xfail('logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposition
xfail('logcumsumexp', ''), # aten.logcumsumexp.default - couldn't find symbolic meta function/decomposition
xfail('logdet', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('lu', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
xfail('lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition
xfail('lu_unpack', ''), # aten.lu_unpack.default - couldn't find symbolic meta function/decomposition
xfail('masked_fill', ''), # expected predicate to be bool, got torch.float32
xfail('masked_scatter', ''), # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition
xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decomposition
xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decomposition
xfail('max', 'reduction_with_dim'), # aten.max.dim - couldn't find symbolic meta function/decomposition
xfail('median', ''), # Could not run 'aten::median' with arguments from the 'Meta' backend. This could be becau...
xfail('meshgrid', 'list_of_tensors'), # Tensors of type TensorImpl do not have numel
xfail('meshgrid', 'variadic_tensors'), # Tensors of type TensorImpl do not have numel
xfail('min', 'reduction_with_dim'), # aten.min.dim - couldn't find symbolic meta function/decomposition
xfail('mode', ''), # aten.mode.default - couldn't find symbolic meta function/decomposition
xfail('msort', ''), # aten.sort.default - couldn't find symbolic meta function/decomposition
xfail('nanquantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
xfail('narrow', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.adaptive_avg_pool3d', ''), # aten._adaptive_avg_pool3d.default - couldn't find symbolic meta func...
xfail('nn.functional.adaptive_max_pool1d', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbolic meta funct...
xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2) must be tupl...
xfail('nn.functional.avg_pool3d', ''), # aten.avg_pool3d.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.bilinear', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.conv1d', ''), # aten.convolution.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.conv2d', ''), # aten.convolution.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.cosine_embedding_loss', ''), # The underlying op of 'aten.stride' has no overload name '_schema'
xfail('nn.functional.cosine_similarity', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.ctc_loss'), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
xfail('nn.functional.dropout2d', ''), # Tensors of type TensorImpl do not have numel
xfail('nn.functional.dropout3d', ''), # Tensors of type TensorImpl do not have numel
xfail('nn.functional.dropout', ''), # Tensors of type TensorImpl do not have numel
xfail('nn.functional.embedding_bag', ''), # aten._embedding_bag_forward_only.default - couldn't find symbolic meta fun...
xfail('nn.functional.embedding', ''), # argument 'size' must be tuple of ints, but found element of type tor...
xfail('nn.functional.feature_alpha_dropout', 'with_train'), # Tensors of type TensorImpl do not have numel
xfail('nn.functional.fractional_max_pool2d', ''), # argument 'size' must be tuple of ints, but found element of t...
xfail('nn.functional.fractional_max_pool3d', ''), # argument 'size' must be tuple of ints, but found element of t...
xfail('nn.functional.glu', ''), # aten.glu.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.grid_sample', ''), # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos...
xfail('nn.functional.group_norm', ''), # 'torch._C.SymIntNode' and 'int'
xfail('nn.functional.hinge_embedding_loss', ''), # aten.empty_like.default - couldn't find symbolic meta function/deco...
xfail('nn.functional.interpolate', 'area'), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.interpolate', 'bicubic'), # aten.upsample_bicubic2d.vec - couldn't find symbolic meta function/d...
xfail('nn.functional.interpolate', 'bilinear'), # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function...
xfail('nn.functional.interpolate', 'linear'), # aten.upsample_linear1d.vec - couldn't find symbolic meta function/dec...
xfail('nn.functional.interpolate', 'nearest'), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/d...
xfail('nn.functional.interpolate', 'trilinear'), # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi...
xfail('nn.functional.local_response_norm', ''), # Tensors of type TensorImpl do not have numel
xfail('nn.functional.margin_ranking_loss', ''), # The underlying op of 'aten.stride' has no overload name '_schema'
xfail('nn.functional.max_pool1d', ''), # Trying to call aten.size on a tensor with symbolic shapes.
xfail('nn.functional.max_pool3d', ''), # aten.max_pool3d_with_indices.default - couldn't find symbolic meta function/d...
xfail('nn.functional.max_unpool1d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.max_unpool2d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.max_unpool3d', 'grad'), # aten.max_unpool3d.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.multi_margin_loss', ''), # Could not run 'aten::multi_margin_loss' with arguments from the...
xfail('nn.functional.multilabel_margin_loss', ''), # Could not run 'aten::multilabel_margin_loss_forward' with ...
xfail('nn.functional.pad', 'circular'), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.pad', 'constant'), # aten.fill.Scalar - couldn't find symbolic meta function/decomposition
xfail('nn.functional.pad', 'reflect'), # aten.reflection_pad1d.default - couldn't find symbolic meta function/decompo...
xfail('nn.functional.pad', 'replicate'), # aten.replication_pad1d.default - couldn't find symbolic meta function/deco...
xfail('nn.functional.pdist', ''), # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend...
xfail('nn.functional.pixel_shuffle', ''), # aten.pixel_shuffle.default - couldn't find symbolic meta function/decompos...
xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.smooth_l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.unfold', ''), # aten.im2col.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.upsample_bilinear', ''), # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function/de...
xfail('nn.functional.upsample_nearest', ''), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco...
xfail('norm', 'nuc'), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition
xfail('normal', ''), # aten.normal.Tensor_Tensor - couldn't find symbolic meta function/decomposition
xfail('normal', 'number_mean'), # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
xfail('ormqr', ''), # aten.ormqr.default - couldn't find symbolic meta function/decomposition
xfail('outer', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('pca_lowrank', ''), # aten.mm.default - couldn't find symbolic meta function/decomposition
xfail('pinverse', ''), # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decomposition
xfail('polygamma', 'polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
xfail('polygamma', 'polygamma_n_1'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
xfail('polygamma', 'polygamma_n_2'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
xfail('polygamma', 'polygamma_n_3'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
xfail('polygamma', 'polygamma_n_4'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
xfail('put', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
xfail('qr', ''), # aten.linalg_qr.default - couldn't find symbolic meta function/decomposition
xfail('rad2deg', ''), # aten.rad2deg.default - couldn't find symbolic meta function/decomposition
xfail('renorm', ''), # aten.renorm.default - couldn't find symbolic meta function/decomposition
xfail('reshape_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('resize_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
xfail('roll', ''), # Tensors of type TensorImpl do not have numel
xfail('round', ''), # aten.round.default - couldn't find symbolic meta function/decomposition
xfail('round', 'decimals_0'), # aten.round.decimals - couldn't find symbolic meta function/decomposition
xfail('round', 'decimals_3'), # aten.round.decimals - couldn't find symbolic meta function/decomposition
xfail('round', 'decimals_neg_3'), # aten.round.decimals - couldn't find symbolic meta function/decomposition
xfail('scatter_add', ''), # aten.scatter_add.default - couldn't find symbolic meta function/decomposition
xfail('scatter', ''), # aten.scatter.src - couldn't find symbolic meta function/decomposition
xfail('scatter_reduce', 'amax'), # aten.scatter_reduce.two - couldn't find symbolic meta function/decomposition
xfail('scatter_reduce', 'amin'), # aten.scatter_reduce.two - couldn't find symbolic meta function/decomposition
xfail('scatter_reduce', 'mean'), # aten.scatter_reduce.two - couldn't find symbolic meta function/decomposition
xfail('scatter_reduce', 'prod'), # aten.scatter_reduce.two - couldn't find symbolic meta function/decomposition
xfail('scatter_reduce', 'sum'), # aten.scatter_reduce.two - couldn't find symbolic meta function/decomposition
xfail('searchsorted', ''), # Could not run 'aten::searchsorted.Tensor' with arguments from the 'Meta' backend. ...
xfail('segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta function/decomposition
xfail('select', ''), # aten.select.int - couldn't find symbolic meta function/decomposition
xfail('select_scatter', ''), # aten.select_scatter.default - couldn't find symbolic meta function/decomposition
xfail('slice_scatter', ''), # aten.slice_scatter.default - couldn't find symbolic meta function/decomposition
xfail('sort', ''), # aten.sort.default - couldn't find symbolic meta function/decomposition
xfail('special.airy_ai', ''), # aten.special_airy_ai.default - couldn't find symbolic meta function/decomposition
xfail('special.bessel_y0', ''), # aten.special_bessel_y0.default - couldn't find symbolic meta function/decomposition
xfail('special.bessel_y1', ''), # aten.special_bessel_y1.default - couldn't find symbolic meta function/decomposition
xfail('special.chebyshev_polynomial_t', ''), # aten.special_chebyshev_polynomial_t.default - couldn't find symbolic me...
xfail('special.chebyshev_polynomial_u', ''), # aten.special_chebyshev_polynomial_u.default - couldn't find symbolic me...
xfail('special.entr', ''), # aten.special_entr.default - couldn't find symbolic meta function/decomposition
xfail('special.erfcx', ''), # aten.special_erfcx.default - couldn't find symbolic meta function/decomposition
xfail('special.hermite_polynomial_h', ''), # aten.special_hermite_polynomial_h.default - couldn't find symbolic meta f...
xfail('special.hermite_polynomial_he', ''), # aten.special_hermite_polynomial_he.default - couldn't find symbolic meta...
xfail('special.laguerre_polynomial_l', ''), # aten.special_laguerre_polynomial_l.default - couldn't find symbolic meta...
xfail('special.log_ndtr', ''), # aten.special_log_ndtr.default - couldn't find symbolic meta function/decomposition
xfail('special.modified_bessel_i0', ''), # aten.special_modified_bessel_i0.default - couldn't find symbolic meta funct...
xfail('special.modified_bessel_i1', ''), # aten.special_modified_bessel_i1.default - couldn't find symbolic meta funct...
xfail('special.modified_bessel_k0', ''), # aten.special_modified_bessel_k0.default - couldn't find symbolic meta funct...
xfail('special.modified_bessel_k1', ''), # aten.special_modified_bessel_k1.default - couldn't find symbolic meta funct...
xfail('special.ndtri', ''), # aten.special_ndtri.default - couldn't find symbolic meta function/decomposition
xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/...
xfail('special.scaled_modified_bessel_k0', ''), # aten.special_scaled_modified_bessel_k0.default - couldn't find symbo...
xfail('special.scaled_modified_bessel_k1', ''), # aten.special_scaled_modified_bessel_k1.default - couldn't find symbo...
xfail('special.xlog1py', ''), # aten.special_xlog1py.default - couldn't find symbolic meta function/decomposition
xfail('split', ''), # 'torch._C.SymIntNode' and 'int'
xfail('split', 'list_args'), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('split_with_sizes', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('stft', ''), # argument 'size' must be tuple of ints, but found element of type torch._C.SymIntNode at...
xfail('sum_to_size', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('svd', ''), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition
xfail('svd_lowrank', ''), # aten.mm.default - couldn't find symbolic meta function/decomposition
xfail('symeig', ''), # aten.symeig.default - couldn't find symbolic meta function/decomposition
xfail('take_along_dim', ''), # dtype of indices should be Long but got Float
xfail('take', ''), # aten.take.default - couldn't find symbolic meta function/decomposition
xfail('tensordot', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('topk', ''), # aten.topk.default - couldn't find symbolic meta function/decomposition
xfail('trapz', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('trapezoid', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/decomposition
xfail('unfold', ''), # aten.unfold.default - couldn't find symbolic meta function/decomposition
xfail('view_as_complex', ''), # aten.view_as_complex.default - couldn't find symbolic meta function/decomposition
xfail('view_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('vsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('unbind', ''), # aten.unbind.int - couldn't find symbolic meta function/decomposition
}
# Entries for ops that (going by the name) crash the process under symbolic
# tracing; folded into symbolic_tensor_failures so skipOps covers them too.
# NB: use set() -- an empty `{}` literal is a dict, not a set.
symbolic_tensor_segfaults = set()
symbolic_tensor_failures.update(symbolic_tensor_segfaults)
def _test_make_fx_helper(self, device, dtype, op, tracing_mode):
    """Trace ``op`` with make_fx and check the traced graph reproduces eager results.

    Args:
        self: the calling TestCase (used for ``skipTest`` / ``assertEqual``).
        device: device the OpInfo sample inputs are created on.
        dtype: dtype of the OpInfo sample inputs.
        op: the OpInfo under test.
        tracing_mode: make_fx tracing mode (callers pass "real", "fake" or "symbolic").

    Skips the whole test if tracing raises DynamicOutputShapeException.
    """
    def f(args, kwargs):
        return op.op(*args, **kwargs)

    sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
    # Limit ourselves to first 100 inputs so symbolic tracing tests don't take too long
    for sample_input in itertools.islice(sample_inputs_itr, 100):
        args = [sample_input.input] + list(sample_input.args)
        kwargs = sample_input.kwargs
        try:
            new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs)
        except DynamicOutputShapeException:
            self.skipTest("Dynamic output shape operation in trace")

        # Overwrite float tensor inputs in place with values in [0, 1) before
        # comparing. NOTE(review): presumably keeps inputs in a benign range; confirm.
        for arg in args:
            if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
                arg.uniform_(0, 1)

        try:
            # Seed the eager reference run the same way as the traced run below,
            # so random ops that don't seed themselves still compare equal.
            old_out = wrapper_set_seed(f, args, kwargs)
        except Exception:
            # Deliberate best-effort: samples the eager op itself rejects are skipped.
            continue
        new_out = wrapper_set_seed(new_f, args, kwargs)
        self.assertEqual(new_out, old_out)
class TestProxyTensorOpInfo(TestCase):
    """Exhaustively trace every float32 OpInfo in op_db with make_fx in each tracing mode."""

    @ops(op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
    def test_make_fx_exhaustive(self, device, dtype, op):
        """Real-tensor tracing."""
        _test_make_fx_helper(self, device, dtype, op, tracing_mode="real")

    @ops(op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive',
             make_fx_failures | fake_tensor_failures)
    def test_make_fx_fake_exhaustive(self, device, dtype, op):
        """Fake-tensor tracing; inherits the real-mode failure list."""
        _test_make_fx_helper(self, device, dtype, op, tracing_mode="fake")

    @skipIfNoSympy
    @ops(op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive',
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures)
    def test_make_fx_symbolic_exhaustive(self, device, dtype, op):
        """Symbolic-shape tracing; inherits the real- and fake-mode failure lists."""
        _test_make_fx_helper(self, device, dtype, op, tracing_mode="symbolic")
# A 1-tuple of device types. Note the trailing comma: ("cpu") would be the
# plain string "cpu", and any membership test against it would check
# substrings rather than whole device names.
only_for = ("cpu",)
instantiate_device_type_tests(TestProxyTensorOpInfo, globals(), only_for=only_for)
if __name__ == '__main__':
    run_tests()
|