1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098
|
# Owner(s): ["module: inductor"]
import itertools
import unittest
import torch
import torch._dynamo.testing
from torch._higher_order_ops.associative_scan import associative_scan
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import (
decorateIf,
instantiate_parametrized_tests,
parametrize,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
from torch.testing._internal.triton_utils import requires_gpu
def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1):
result = []
device = inputs[0].device
# iterate over the cartesian product of predicate values
for values in itertools.product(*([possible_values] * num_to_prepend)):
prepended = [torch.tensor(v, device=device) for v in values]
result.append((*prepended, *inputs))
return result
def prepend_predicates(inputs, num_predicates=1):
    """Prepend every combination of ``num_predicates`` boolean predicate tensors.

    Each predicate independently takes the value False or True, so the result
    holds ``2 ** num_predicates`` tuples of the form ``(*predicates, *inputs)``.
    """
    boolean_values = [False, True]
    return _prepend_product_of_values(
        inputs,
        possible_values=boolean_values,
        num_to_prepend=num_predicates,
    )
def prepend_counters(inputs, num_counters=1, counter_values=(0, 1, 5)):
    """Prepend every combination of ``num_counters`` loop-counter tensors.

    Each counter independently takes a value from ``counter_values``, so the
    result holds ``len(counter_values) ** num_counters`` tuples of the form
    ``(*counters, *inputs)``.
    """
    combinations = _prepend_product_of_values(
        inputs,
        possible_values=counter_values,
        num_to_prepend=num_counters,
    )
    return combinations
class CondModels:
    """Small ``torch.nn.Module`` models that exercise ``torch.cond``.

    Unless noted otherwise, ``forward`` takes the predicate tensor(s) first,
    followed by the tensor operands passed to the branches.
    """

    class Simple(torch.nn.Module):
        # Single un-nested cond: add in the true branch, subtract in the false one.
        def forward(self, p, a, b):
            def true_fn(x, y):
                return x + y

            def false_fn(x, y):
                return x - y

            return torch.cond(p, true_fn, false_fn, [a, b])

    class SimpleWithIntClosure(torch.nn.Module):
        # Branches close over a Python int attribute (self.num) and each
        # returns a single-element list.
        def __init__(self):
            super().__init__()
            self.num = 3

        def forward(self, p, a, b):
            # Keyword-argument form of torch.cond with lambda branches.
            return torch.cond(
                pred=p,
                true_fn=lambda a, b: [a + b + self.num],
                false_fn=lambda a, b: [a - b - self.num],
                operands=(a, b),
            )

    class Nested(torch.nn.Module):
        # Three levels of cond nesting: p0 selects the outer branch, p1 the
        # middle one and p2 the innermost one. Each leaf applies a different
        # formula, so every predicate combination gives a distinct result.
        def forward(self, p0, p1, p2, a, b, c):
            def true_fn(x0, y0, z0):
                def true_true_fn(x1, y1, z1):
                    return (x1 - y1 * z1) * 3.14

                def true_false_fn(x1, y1, z1):
                    def true_false_true_fn(x2, y2, z2):
                        return (x2 * y2 * z2) / 2.71

                    def true_false_false_fn(x2, y2, z2):
                        return (x2 + y2 + z2) * 1.23

                    # p2 is captured from the enclosing forward() scope.
                    return torch.cond(
                        p2, true_false_true_fn, true_false_false_fn, [x1, y1, z1]
                    )

                return torch.cond(p1, true_true_fn, true_false_fn, [x0, y0, z0])

            def false_fn(x0, y0, z0):
                def false_true_fn(x1, y1, z1):
                    def false_true_true_fn(x2, y2, z2):
                        return (x2 - y2 - z2) + 1.23

                    def false_true_false_fn(x2, y2, z2):
                        return (x2 / y2 / z2) - 3.14

                    return torch.cond(
                        p2, false_true_true_fn, false_true_false_fn, [x1, y1, z1]
                    )

                def false_false_fn(x1, y1, z1):
                    return (x1 - y1 * z1) / 2.71

                return torch.cond(p1, false_true_fn, false_false_fn, [x0, y0, z0])

            return torch.cond(p0, true_fn, false_fn, [a, b, c])

    class Parameters(torch.nn.Module):
        # The branches are nn.Modules with their own Linear parameters. Both
        # map a trailing feature dim of 20 to 30, so the branches produce
        # outputs of the same shape.
        class InnerModel1(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.layer = torch.nn.Linear(20, 30, device=device)

            def forward(self, x):
                return self.layer(x + 1) * 3.14

        class InnerModel2(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.layer1 = torch.nn.Linear(20, 10, device=device)
                self.layer2 = torch.nn.Linear(10, 30, device=device)

            def forward(self, x):
                return self.layer2(self.layer1(x - 2)) * 3.14

        def __init__(self, device):
            super().__init__()
            self.true_fn = self.InnerModel1(device)
            self.false_fn = self.InnerModel2(device)

        def forward(self, p, a):
            return torch.cond(p, self.true_fn, self.false_fn, [a])

    class ReinterpretView(torch.nn.Module):
        # Operands passed to cond are slices of the inputs, and both branches
        # return slices of intermediate results.
        def forward(self, p, a, b):
            def true_fn(x, y):
                z1 = x + y
                z2 = x - y
                return z1[2:], z2[:, 4:]

            def false_fn(x, y):
                z1 = x - y
                z2 = x + y
                return z1[2:], z2[:, 4:]

            return torch.cond(p, true_fn, false_fn, [a[:-1], b[:-1]])

    class MultipleOutputs(torch.nn.Module):
        # Each branch returns three outputs of different shapes, including a
        # reduction over dim 1.
        def forward(self, p, a, b, c):
            def true_fn(x, y, z):
                return x * y, z / 2.71, (y - x).sum(dim=1)

            def false_fn(x, y, z):
                return y / x, z * 3.14, (x + y).mean(dim=1)

            return torch.cond(p, true_fn, false_fn, [a, b, c])

    class OuterCode(torch.nn.Module):
        # Computation both before the cond (producing its operands) and after
        # it (consuming its result).
        def forward(self, p, a, b):
            c = a * b + 3.14
            d = a / b - 2.71

            def true_fn(x, y):
                return x + y

            def false_fn(x, y):
                return x - y

            e = torch.cond(p, true_fn, false_fn, [c, d])
            return e * e / 1.41

    class OuterBuffers(torch.nn.Module):
        # The branches read d and e from the enclosing scope (closure capture)
        # instead of receiving them as cond operands.
        def forward(self, p, a, b, c):
            d = a * 2
            e = b / 2

            def true_fn(x):
                return x + d

            def false_fn(x):
                return x - e

            return torch.cond(p, true_fn, false_fn, [c])

    class WithNonTensorPredicate(torch.nn.Module):
        # The predicate is a Python bool computed from input sizes rather than
        # a tensor, so forward takes no predicate argument.
        def forward(self, a, b):
            def true_fn(x, y):
                return x.sum(0) / 3.14

            def false_fn(x, y):
                return y.sum(0) * 2.71

            return torch.cond(a.size(0) > b.size(0), true_fn, false_fn, [a, b])
class CondTests(TestCase):
    def _run_test(
        self,
        model,
        inputs,
        device,
        dynamic=False,
        num_predicates=1,
    ):
        """Check that ``model`` compiled with inductor matches eager mode.

        For every combination of ``num_predicates`` boolean predicate tensors
        prepended to ``inputs``:
        - the eager and compiled results are compared, and
        - the inputs are checked not to have been mutated.

        With ``dynamic=True``, a second input set tiled 5x along dim 0 is also
        run, and the first dim of every input tensor in both sets is marked
        dynamic. Exactly one Dynamo frame compilation is expected across all
        runs.
        """
        cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
        compiled_model = torch.compile(backend=cnt, fullgraph=True)(model)

        inputs = [inp.to(device=device) for inp in inputs]
        input_sets = [inputs]
        if dynamic:
            larger_inputs = []
            for inp in inputs:
                # tile every first dim 5x
                tiling = [5] + [1] * (inp.ndim - 1)
                larger_inputs.append(torch.tile(inp, tiling))
            input_sets.append(larger_inputs)
            for input_set in input_sets:
                for inp in input_set:
                    # mark every first dim as dynamic
                    torch._dynamo.mark_dynamic(inp, 0)

        for input_set in input_sets:
            for inputs_with_predicates in prepend_predicates(
                input_set, num_predicates
            ):
                cloned_inputs = [inp.clone() for inp in inputs_with_predicates]
                result = model(*inputs_with_predicates)
                result_compiled = compiled_model(*inputs_with_predicates)
                # inputs must not be mutated
                torch.testing.assert_close(cloned_inputs, inputs_with_predicates)
                torch.testing.assert_close(result, result_compiled)

        self.assertEqual(cnt.frame_count, 1, "only one compilation expected")

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_cond_simple_control_flow(self, device, dynamic):
        # cond control flow without nesting
        self._run_test(
            model=CondModels.Simple(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=device,
            dynamic=dynamic,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    def test_cond_simple_with_int_closure(self, device):
        # branches close over a Python int attribute of the module
        self._run_test(
            model=torch.compile(CondModels.SimpleWithIntClosure(), dynamic=True),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=device,
        )

    @requires_gpu
    def test_cond_control_flow_with_precomputed_size(self):
        # Python-bool predicate; the second call uses a different spatial
        # size and an index that selects the other branch.
        class TestModel(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.conv2d = torch.nn.Conv2d(
                    512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
                )
                self.threshold = 20

            def forward(self, x: torch.Tensor, index) -> torch.Tensor:
                def true_fn(x: torch.Tensor):
                    return self.conv2d(x)

                def false_fn(x: torch.Tensor):
                    return self.conv2d(x)

                return torch.cond(
                    index < self.threshold and index >= 0, true_fn, false_fn, (x,)
                )

        main_model = TestModel().to(GPU_TYPE)

        x1 = torch.rand(2, 512, 128, 72).to(GPU_TYPE)
        x2 = torch.rand(2, 512, 96, 96).to(GPU_TYPE)

        opt_model = torch.compile(main_model)

        # index=1 lies in [0, threshold): true branch
        out1 = main_model(x1, 1)
        opt_out1 = opt_model(x1, 1)
        # assert_close reports the mismatch on failure (allclose only gives False)
        torch.testing.assert_close(opt_out1, out1, atol=1e-5, rtol=1e-5)

        # index=30 is >= threshold: false branch
        out2 = main_model(x2, 30)
        opt_out2 = opt_model(x2, 30)
        torch.testing.assert_close(opt_out2, out2, atol=1e-5, rtol=1e-5)

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_cond_nested_control_flow(self, device, dynamic):
        # cond control flow with nesting
        self._run_test(
            model=CondModels.Nested(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=device,
            dynamic=dynamic,
            num_predicates=3,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_cond_outer_code_before_after(self, device, dynamic):
        # some code before and after the conditional
        self._run_test(
            model=CondModels.OuterCode(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=device,
            dynamic=dynamic,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_cond_multiple_outputs(self, device, dynamic):
        # multiple outputs with different shapes
        self._run_test(
            model=CondModels.MultipleOutputs(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
                torch.randn(30, 40),
            ),
            device=device,
            dynamic=dynamic,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    def test_cond_advanced_dynamic_shapes(self, device):
        # subgraphs input shapes include symbolic expressions
        class Model(torch.nn.Module):
            def forward(self, p, a, b):
                def true_fn(x, y):
                    return torch.cat([x - 3, y * 3], dim=1)

                def false_fn(x, y):
                    return torch.cat([x / 3, y - 3], dim=1)

                c = torch.cat([a, b], dim=0)
                d = c * 2
                e = c / 2

                return torch.cond(p, true_fn, false_fn, [d, e])

        self._run_test(
            model=Model(),
            inputs=(
                torch.randn(2, 3, 3),
                torch.randn(4, 3, 3),
            ),
            device=device,
            dynamic=True,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    def test_cond_unbacked_symint_outer_to_inner(self, device):
        # an unbacked size created outside the cond flows into its operands
        class Model(torch.nn.Module):
            def forward(self, p, a):
                def true_fn(x):
                    return torch.cos(x)

                def false_fn(x):
                    return torch.sin(x)

                nz = torch.nonzero(a)
                b = torch.ones([nz.size(0), 8], device=nz.device)

                return torch.cond(p, true_fn, false_fn, [b])

        with torch._dynamo.config.patch(
            {
                "capture_dynamic_output_shape_ops": True,
            }
        ):
            self._run_test(
                model=Model(),
                inputs=(torch.randn(2, 3, 3),),
                device=device,
                dynamic=True,
            )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    def test_cond_unbacked_symint_inner(self, device):
        # unbacked sizes created and consumed entirely inside the branches
        class Model(torch.nn.Module):
            def forward(self, p, a):
                def true_fn(x):
                    nz = torch.nonzero(x)
                    b = torch.ones([nz.size(0), 8], device=nz.device)
                    return torch.cos(b)

                def false_fn(x):
                    nz = torch.nonzero(x)
                    b = torch.ones([nz.size(0), 8], device=nz.device)
                    return torch.sin(b)

                b = torch.sin(a)

                return torch.cond(p, true_fn, false_fn, [b])

        with torch._dynamo.config.patch(
            {
                "capture_dynamic_output_shape_ops": True,
            }
        ):
            self._run_test(
                model=Model(),
                inputs=(torch.randn(2, 3, 3),),
                device=device,
                dynamic=True,
            )

    @unittest.skip("unbacked symints from inner to outer graph not supported yet")
    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    def test_cond_unbacked_symint_inner_to_outer(self, device):
        # an unbacked size created inside the branches escapes via the output
        class Model(torch.nn.Module):
            def forward(self, p, a):
                def true_fn(x):
                    nz = torch.nonzero(x)
                    b = torch.ones([nz.size(0), 8], device=nz.device)
                    return torch.cos(b)

                def false_fn(x):
                    nz = torch.nonzero(x)
                    b = torch.ones([nz.size(0), 8], device=nz.device)
                    return torch.sin(b)

                b = torch.sin(a)
                y = torch.cond(p, true_fn, false_fn, [b])
                return torch.sin(y)

        with torch._dynamo.config.patch(
            {
                "capture_dynamic_output_shape_ops": True,
            }
        ):
            self._run_test(
                model=Model(),
                inputs=(torch.randn(2, 3, 3),),
                device=device,
                dynamic=True,
            )

    @requires_gpu
    def test_cond_use_buffers_from_outer_scope(self):
        # branches read tensors computed outside the cond (closure capture)
        self._run_test(
            model=CondModels.OuterBuffers(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=GPU_TYPE,
            dynamic=False,
        )

    @requires_gpu
    def test_cond_reintepret_view_inputs_outputs(self):
        # ReinterpretView in inputs and outputs of the subgraphs
        self._run_test(
            model=CondModels.ReinterpretView(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=GPU_TYPE,
            dynamic=True,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_cond_subgraphs_with_parameters(self, device, dynamic):
        # nested Modules with parameters
        self._run_test(
            model=CondModels.Parameters(device),
            inputs=(torch.randn(10, 20),),
            device=device,
            dynamic=dynamic,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_cond_non_tensor_predicates(self, device, dynamic):
        # model with a boolean predicate
        # b_size_0=5 makes the predicate True and 15 makes it False; Dynamo is
        # reset between them so each gets its own single compilation.
        for b_size_0 in [5, 15]:
            torch._dynamo.reset()
            self._run_test(
                model=CondModels.WithNonTensorPredicate(),
                inputs=(
                    torch.randn(10, 20),
                    torch.randn(b_size_0, 20),
                ),
                device=device,
                dynamic=dynamic,
                num_predicates=0,
            )

    @requires_gpu
    def test_cond_aliasing_outputs(self):
        # output aliasing in subgraphs: not supported
        class Model(torch.nn.Module):
            def forward(self, p, a, b):
                def true_fn(x, y):
                    z = x + y
                    return z, z[1:]

                def false_fn(x, y):
                    z = x - y
                    return z, z[1:]

                return torch.cond(p, true_fn, false_fn, [a, b])

        # AssertionError: Output aliasing is currently not supported...
        with self.assertRaises(torch._dynamo.exc.BackendCompilerFailed):
            torch.compile(Model())(
                torch.tensor(True),
                torch.randn(10, 20),
                torch.randn(10, 20),
            )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    def test_cond_decompose_ops_in_subgraph(self, device):
        # zeros_like / ones_like inside the branches
        class Model(torch.nn.Module):
            def forward(self, p, a):
                def true_fn(x):
                    return torch.zeros_like(x)

                def false_fn(x):
                    return torch.ones_like(x)

                b = torch.ones_like(a)
                c = torch.cond(p, true_fn, false_fn, [b])
                return c

        self._run_test(
            model=Model(),
            inputs=(torch.rand(10, 20),),
            device=device,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    def test_cond_decompose_ops_in_subgraph_recursive(self, device):
        # same as above, but the ops sit inside a nested cond
        def inner_fn1(x):
            return torch.zeros_like(x)

        def inner_fn2(x):
            return torch.ones_like(x)

        class Model(torch.nn.Module):
            def forward(self, p, a):
                def true_fn(x):
                    return torch.cond(p, inner_fn2, inner_fn1, [x])

                def false_fn(x):
                    return torch.cond(p, inner_fn1, inner_fn2, [x])

                b = torch.ones_like(a)
                c = torch.cond(p, true_fn, false_fn, [b])
                return c

        self._run_test(
            model=Model(),
            inputs=(torch.rand(10, 20),),
            device=device,
        )

    @requires_gpu
    def test_cond_inductor_fx_passes_recursively_applied(self):
        # custom pre-/post-grad passes must also run on the cond subgraphs
        counters = {"pre_grad": 0, "post_grad": 0}

        def pre_grad_pass_counter(gm):
            counters["pre_grad"] += 1

        def post_grad_pass_counter(gm):
            counters["post_grad"] += 1

        with torch._inductor.config.patch(
            {
                "pre_grad_custom_pass": pre_grad_pass_counter,
                "post_grad_custom_pre_pass": post_grad_pass_counter,
                # The above patches don't pickle
                "fx_graph_cache": False,
            }
        ):
            self._run_test(
                model=CondModels.Nested(),
                inputs=(
                    torch.randn(10, 20),
                    torch.randn(10, 20),
                    torch.randn(10, 20),
                ),
                device=GPU_TYPE,
                dynamic=True,
                num_predicates=3,
            )

        self.assertEqual(counters["pre_grad"], 11)
        self.assertEqual(counters["post_grad"], 11)
class WhileLoopModels:
    """Small ``torch.nn.Module`` models that exercise
    ``torch._higher_order_ops.while_loop``.

    ``forward`` takes the loop-counter tensor(s) first, followed by the values
    carried through the loop. Every loop here runs while its counter is > 0
    and decrements that counter by 1 per iteration.
    """

    class Simple(torch.nn.Module):
        # Single loop carrying two tensors that are updated from each other.
        def forward(self, ci, a, b):
            def cond_fn(i, x, y):
                return i > 0

            def body_fn(i, x, y):
                return i - 1, x + y, y - x

            return torch._higher_order_ops.while_loop(cond_fn, body_fn, [ci, a, b])

    class Nested(torch.nn.Module):
        # A while_loop inside another while_loop's body. The outer loop counts
        # down i; the inner loop counts down j. Since the inner loop leaves j
        # at <= 0 and the outer body passes j through unchanged (clone), the
        # inner body only runs during the first outer iteration.
        def forward(self, ci, cj, a, b):
            def cond_fn(i1, j1, x1, y1):
                return i1 > 0

            def body_fn(i1, j1, x1, y1):
                def cond_fn_nested(i2, j2, x2, y2):
                    return j2 > 0

                def body_fn_nested(i2, j2, x2, y2):
                    return i2.clone(), j2 - 1, x2 + 3.14, y2 - 2.71

                i1, j1, x1, y1 = torch._higher_order_ops.while_loop(
                    cond_fn_nested, body_fn_nested, [i1, j1, x1, y1]
                )

                return i1 - 1, j1.clone(), x1 * 2, y1 / 2

            return torch._higher_order_ops.while_loop(cond_fn, body_fn, (ci, cj, a, b))

    class Parameters(torch.nn.Module):
        # The loop body is an nn.Module with Linear parameters (20 -> 30 -> 20),
        # so the carried tensor keeps its trailing dim of 20.
        class InnerModel(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.layer1 = torch.nn.Linear(20, 30, device=device)
                self.layer2 = torch.nn.Linear(30, 20, device=device)

            def forward(self, c, x):
                return c - 1, self.layer2(self.layer1(x - 2)) * 3.14

        def __init__(self, device):
            super().__init__()
            self.body_fn = self.InnerModel(device)
            self.cond_fn = lambda c, x: c > 0

        def forward(self, c, a):
            return torch._higher_order_ops.while_loop(
                self.cond_fn, self.body_fn, [c, a]
            )

    class OuterCode(torch.nn.Module):
        # Computation before the loop (producing its carries) and after it
        # (consuming its outputs).
        def forward(self, c, a, b):
            d = a * b + 3.14
            e = a / b - 2.71

            def cond_fn(c, x, y):
                return c > 0

            def body_fn(c, x, y):
                return c - 1, y - x, x + y

            _, f, g = torch._higher_order_ops.while_loop(cond_fn, body_fn, [c, d, e])

            return f * g / 1.41

    # TODO(aakhundov): add while_loop test with outer buffers
    # with dynamic=True once dynamo / export allows while_loop
    # closure capture with mark_dynamic:
    # https://github.com/pytorch/pytorch/issues/123596
    class OuterBuffers(torch.nn.Module):
        # The loop body reads d and e from the enclosing scope (closure
        # capture) instead of carrying them.
        def forward(self, c, a, b):
            d = a * 2
            e = b / 2

            def cond_fn(c, x, y):
                return c > 0

            def body_fn(c, x, y):
                return c - 1, x + d, y - e

            return torch._higher_order_ops.while_loop(cond_fn, body_fn, [c, a, b])

    class PytreeCarry(torch.nn.Module):
        # The carried state is a nested pytree: ([x], {"x": y, "y": z}).
        # Each iteration rotates the three leaves through sin / cos / +1.
        def forward(self, it, pytree_input):
            def cond_fn(it, pytree_input):
                return it > 0

            def body_fn(it, pytree_input):
                x = pytree_input[0][0]
                y = pytree_input[1]["x"]
                z = pytree_input[1]["y"]
                new_x = y.sin()
                new_y = z.cos()
                new_z = x + 1

                return it - 1, ([new_x], {"x": new_y, "y": new_z})

            return torch._higher_order_ops.while_loop(
                cond_fn, body_fn, (it, pytree_input)
            )

    class DataDependentOpInSubgraph(torch.nn.Module):
        # The loop body runs masked_select (output size depends on the mask
        # values) on closure-captured a and b. It accumulates the minimum of
        # the selected values into a 1-element int64 carry.
        def forward(self, c, a, b):
            def cond_fn(c, reduced_carry):
                return c > 0

            def body_fn(c, reduced_carry):
                k = torch.masked_select(a, b)
                d = torch.concat([k, k * 2])
                return c - 1, torch.min(d).unsqueeze(0) + reduced_carry

            return torch._higher_order_ops.while_loop(
                cond_fn,
                body_fn,
                [c, torch.zeros([1], dtype=torch.int64, device=c.device)],
            )

    class DataDependentInOut(torch.nn.Module):
        # The carried tensor's first dim comes from a data-dependent .item()
        # value; the body keeps its shape and dtype (int64).
        def forward(self, c, a, b):
            inp = torch.zeros(
                a.sum().to(torch.int64).item(), 3, device=a.device, dtype=torch.int64
            )

            def cond_fn(c, inp):
                return c > 0

            def body_fn(c, inp):
                return c - 1, (inp.sin() + 1).to(torch.int64)

            return torch._higher_order_ops.while_loop(
                cond_fn,
                body_fn,
                [c, inp],
            )

    class DataDependentInOutMismatch(torch.nn.Module):
        # The body returns nonzero() of the carries, whose shapes differ from
        # the carried inputs. WhileLoopTests expects this to fail to compile.
        def forward(self, c, a, b):
            def cond_fn(c, a, b):
                return c > 0

            def body_fn(c, a, b):
                return c - 1, a.nonzero(), b.nonzero()

            return torch._higher_order_ops.while_loop(
                cond_fn,
                body_fn,
                [c, a, b],
            )
class WhileLoopTests(TestCase):
    def _run_test(
        self,
        model,
        inputs,
        device,
        dynamic=False,
        num_counters=1,
    ):
        """Check that ``model`` compiled with inductor matches eager mode.

        ``inputs`` may be an arbitrary pytree of tensors. Every combination of
        ``num_counters`` loop-counter tensors (see ``prepend_counters``) is
        prepended to the inputs. For each combination:
        - the eager and compiled (under ``no_grad``) results are compared, and
        - the inputs are checked not to have been mutated.

        With ``dynamic=True``, a second input set tiled 5x along dim 0 is also
        run, and the first dim of every tensor in both sets is marked dynamic.
        Exactly one Dynamo frame compilation is expected across all runs.
        """
        import torch.utils._pytree as pytree

        cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
        compiled_model = torch.compile(backend=cnt, fullgraph=True)(model)

        inputs = pytree.tree_map(lambda t: t.to(device=device), inputs)
        input_sets = [inputs]
        if dynamic:

            def mark_first_dim_dyn(inp):
                torch._dynamo.mark_dynamic(inp, 0)

            pytree.tree_map(mark_first_dim_dyn, input_sets)

            def tile_fn(inp):
                # tile every first dim 5x
                tiling = [5] + [1] * (inp.ndim - 1)
                t = torch.tile(inp, tiling)
                # Mark the first dim of the new, larger tensor as dynamic.
                # `inp` was already marked above; marking it again left `t`
                # unmarked.
                torch._dynamo.mark_dynamic(t, 0)
                return t

            larger_inputs = pytree.tree_map(tile_fn, inputs)
            input_sets.append(larger_inputs)

        for input_set in input_sets:
            flat_inputs, inp_spec = pytree.tree_flatten(input_set)
            for flat_inputs_with_counters in prepend_counters(
                flat_inputs, num_counters
            ):
                counters, flat = (
                    flat_inputs_with_counters[:num_counters],
                    flat_inputs_with_counters[num_counters:],
                )
                unflat_inputs = pytree.tree_unflatten(flat, inp_spec)
                inputs_with_counters = counters + unflat_inputs

                cloned_inputs = pytree.tree_map(
                    lambda t: t.clone(), inputs_with_counters
                )
                result = model(*inputs_with_counters)
                with torch.no_grad():
                    result_compiled = compiled_model(*inputs_with_counters)
                # inputs must not be mutated
                torch.testing.assert_close(cloned_inputs, inputs_with_counters)
                torch.testing.assert_close(
                    result, result_compiled, atol=1e-4, rtol=1e-4
                )

        self.assertEqual(cnt.frame_count, 1, "only one compilation expected")

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_while_loop_simple_control_flow(self, device, dynamic):
        # while_loop control flow without nesting
        self._run_test(
            model=WhileLoopModels.Simple(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=device,
            dynamic=dynamic,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_while_loop_nested_control_flow(self, device, dynamic):
        # while_loop control flow with nesting
        self._run_test(
            model=WhileLoopModels.Nested(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=device,
            dynamic=dynamic,
            num_counters=2,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_while_loop_with_outer_code(self, device, dynamic):
        # while_loop control flow with outer code
        self._run_test(
            model=WhileLoopModels.OuterCode(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=device,
            dynamic=dynamic,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_while_loop_with_parameters(self, device, dynamic):
        # while_loop control flow with parameters
        self._run_test(
            model=WhileLoopModels.Parameters(device),
            inputs=(torch.randn(10, 20),),
            device=device,
            dynamic=dynamic,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    # dynamic=True doesn't work now due to
    # https://github.com/pytorch/pytorch/issues/123596
    @parametrize("dynamic", [False])
    def test_while_loop_with_outer_buffers(self, device, dynamic):
        # while_loop body reading tensors captured from the outer scope
        self._run_test(
            model=WhileLoopModels.OuterBuffers(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=device,
            dynamic=dynamic,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    # NOTE(review): a previous comment said dynamic=True doesn't work because
    # lifted symbols aren't handled, yet dynamic=True is parametrized here.
    # Confirm whether that limitation still applies.
    @parametrize("dynamic", [True, False])
    def test_while_loop_with_pytree_inputs(self, device, dynamic):
        # carried state is a nested list/dict pytree
        self._run_test(
            model=WhileLoopModels.PytreeCarry(),
            inputs=(
                (
                    [torch.randn(10, 20)],
                    {"x": torch.randn(10, 20), "y": torch.randn(10, 20)},
                ),
            ),
            device=device,
            dynamic=dynamic,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [True, False])
    def test_while_loop_with_data_dependent_ops(self, device, dynamic):
        # data-dependent output sizes inside the loop body
        with torch._dynamo.config.patch(
            {
                "capture_dynamic_output_shape_ops": True,
            }
        ):
            self._run_test(
                model=WhileLoopModels.DataDependentOpInSubgraph(),
                inputs=(
                    torch.tensor([1, 2, 3, 4, 5]),
                    torch.tensor(
                        [True, True, True, True, True],
                    ),
                ),
                device=device,
                dynamic=dynamic,
            )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [True, False])
    def test_while_loop_with_data_dependent_in_out(self, device, dynamic):
        # carried tensor whose size comes from a data-dependent .item()
        with torch._dynamo.config.patch(
            {
                "capture_dynamic_output_shape_ops": True,
                "capture_scalar_outputs": True,
            }
        ):
            self._run_test(
                model=WhileLoopModels.DataDependentInOut(),
                inputs=(
                    torch.tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]),
                    torch.tensor(
                        [True, True, True, True, True],
                    ),
                ),
                device=device,
                dynamic=dynamic,
            )

    @parametrize("dynamic", [True, False])
    def test_while_loop_with_data_dependent_in_out_mismatch(self, dynamic):
        # body outputs whose shapes don't match the carried inputs must fail
        with self.assertRaisesRegex(
            torch._dynamo.exc.UncapturedHigherOrderOpError,
            r"while_loop doesn't work unless it is captured completely with torch.compile",
        ):
            with torch._dynamo.config.patch(
                {
                    "capture_dynamic_output_shape_ops": True,
                }
            ):
                self._run_test(
                    model=WhileLoopModels.DataDependentInOutMismatch(),
                    inputs=(
                        torch.tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]),
                        torch.tensor(
                            [True, True, True, True, True],
                        ),
                    ),
                    device="cpu",
                    dynamic=dynamic,
                )
class AssociativeScanTests(TestCase):
@requires_gpu
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("backend", ["inductor"])
@parametrize("device", [torch.device("cpu"), GPU_TYPE])
# This test will fail as flip in combination with particular input lenghts
# produces weird results.
# This is under investigations in
# https://github.com/pytorch/pytorch/issues/131805
@decorateIf(unittest.skip, lambda params: params["device"] == GPU_TYPE)
def test_associative_scan_CUDA_flip(self, combine_mode, backend, device):
def fct(x: torch.Tensor, y: torch.Tensor):
return x + y
# for n in range(10):
for n in [9]:
x = torch.arange(n, device=device)
torch.compiler.reset()
associative_scan1 = torch.compile(
associative_scan, backend=backend, fullgraph=True
)
associative_scan2 = associative_scan
if combine_mode == "pointwise" and device == torch.device("cpu"):
with self.assertRaisesRegex(Exception, r"."):
associative_scan1(
fct, x, 0, reverse=False, combine_mode=combine_mode
)
# Skipping test because combine_mode currently only suppors CUDA tensors
return
result1 = associative_scan1(
fct, x, 0, reverse=False, combine_mode=combine_mode
)
result2 = associative_scan2(
fct, x, 0, reverse=False, combine_mode=combine_mode
)
result3 = torch.cumsum(x, 0)
self.assertEqual(result1, result2)
self.assertEqual(result1, result3)
# Flip only non-compiled and compare with compiled reverse=True
result1 = associative_scan1(
fct, x, 0, reverse=True, combine_mode=combine_mode
)
result2 = torch.flip(
associative_scan2(
fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
),
[0],
)
result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])
self.assertEqual(result1, result2)
self.assertEqual(result1, result3)
# Flip only compiled and compare with non-compiled reverse=True
result1 = torch.flip(
associative_scan1(
fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
),
[0],
)
result2 = associative_scan2(
fct, x, 0, reverse=True, combine_mode=combine_mode
)
result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])
self.assertEqual(result1, result2)
self.assertEqual(result1, result3)
# Use reverse=False, but flip both results before and after
result1 = torch.flip(
associative_scan1(
fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
),
[0],
)
result2 = torch.flip(
associative_scan2(
fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
),
[0],
)
result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])
self.assertEqual(result1, result2)
self.assertEqual(result1, result3)
# Reverse=True
result1 = associative_scan1(
fct, x, 0, reverse=True, combine_mode=combine_mode
)
result2 = associative_scan2(
fct, x, 0, reverse=True, combine_mode=combine_mode
)
result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])
self.assertEqual(result1, result2)
self.assertEqual(result1, result3)
# Expand every @parametrize-decorated test method into concrete test cases.
for _parametrized_cls in (CondTests, WhileLoopTests, AssociativeScanTests):
    instantiate_parametrized_tests(_parametrized_cls)
# Remove the loop variable: a leftover module-level alias to a TestCase class
# could otherwise be collected by test discovery a second time.
del _parametrized_cls

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # Run only when inductor_utils reports a CPU or GPU backend as available.
    if HAS_GPU or HAS_CPU:
        run_tests(needs="filelock")
|