# Owner(s): ["oncall: jit"]
import io
import os
import sys
import torch
from torch.testing import FileCheck
from enum import Enum
from textwrap import dedent
from typing import Dict, List, Optional, Tuple, Union
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, make_global
# This module is only meant to be collected through test/test_jit.py; fail
# loudly with usage instructions when it is executed as a script.
if __name__ == '__main__':
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
class TestUnion(JitTestCase):
    """
    This class tests the functionality of `Union`.

    Note: It's important to be able to refine the type of a `Union` to
    one of its internal types. Currently, there are differences in the
    way Python expects `isinstance` checks and the way TorchScript
    expects `isinstance` checks. This means that we can't use
    `checkScript` in our test cases because either the eager mode or the
    script mode wouldn't run! So, some test cases have separate but
    equivalent functions to emulate `checkScript`.
    """

    # ------------------------------------------------------------------
    # Annotation printing and argument checking
    # ------------------------------------------------------------------

    def test_check_union_annotation(self):
        def test_func(a: Union[int, float], b: Optional[int]):
            return 0

        scripted_func = torch.jit.script(test_func)
        graph_rep = str(scripted_func.graph)
        code_rep = str(scripted_func.code)
        # TS graph IR for Union should be annotated as Union()
        FileCheck().check("Union(").check("int?").run(graph_rep)
        # Serialized code for Union should be annotated as Union[]
        FileCheck().check("Union[").check("Optional[int]").run(code_rep)
        self.checkScript(test_func, (5, 6))
        # this shouldn't error out
        torch._C.parse_ir(str(scripted_func.graph))

    def test_union_with_scalar_values(self):
        def fn(x: Union[int, float]) -> str:
            return "foo"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (1.0,))

        scripted = torch.jit.script(fn)

        # A value outside the Union must be rejected at call time.
        with self.assertRaisesRegex(RuntimeError, "Expected a member of"
                                    r" Union\[float, int\] but "
                                    "instead found type str"):
            scripted("1")

    def test_union_with_collections(self):
        def fn(x: Union[Dict[str, int], List[int]]) -> str:
            return "foo"

        self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))
        self.checkScript(fn, ([1, 2, 3],))

        scripted = torch.jit.script(fn)

        with self.assertRaisesRegex(RuntimeError, "Expected a member of"
                                    r" Union\[List\[int\], Dict\[str, "
                                    r"int\]\] but instead found type "
                                    r"Dict\[str, str\]"):
            scripted({"foo": "bar", "baz": "qux"})

        with self.assertRaisesRegex(RuntimeError, "Expected a member of"
                                    r" Union\[List\[int\], Dict\[str, "
                                    r"int\]\] but instead found type "
                                    r"List\[str\]"):
            scripted(["foo", "bar", "baz"])

        with self.assertRaisesRegex(RuntimeError, "Expected a member of"
                                    r" Union\[List\[int\], Dict\[str, "
                                    r"int\]\] but instead found type "
                                    "str"):
            scripted("1")

    def test_union_with_enum(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        def fn(x: Union[str, Color]) -> str:
            return "foo"

        self.checkScript(fn, (Color.RED,))
        self.checkScript(fn, ("red",))

        scripted = torch.jit.script(fn)

        with self.assertRaisesRegex(RuntimeError, "Expected a member of"
                                    r" Union\[__torch__.jit.test_union."
                                    r"Color, str\] but instead found "
                                    "type int"):
            scripted(1)

    def test_union_in_class_constructor(self):
        @torch.jit.script  # noqa: B903
        class A(object):  # noqa: B903
            def __init__(self, x: Union[int, str]) -> None:
                self.x = x

        def fn(x: Union[str, int]) -> A:
            return A(x)

        self.assertEqual(fn("foo").x, "foo")
        self.assertEqual(fn(1).x, 1)

        scripted = torch.jit.script(fn)

        with self.assertRaisesRegex(RuntimeError, "Expected a member of"
                                    r" Union\[int, str\] but instead "
                                    r"found type List\[str\]"):
            scripted(["foo", "bar", "baz"])

    # ------------------------------------------------------------------
    # Union as a return type / local annotation
    # ------------------------------------------------------------------

    def test_union_return_type(self):
        def fn(x: int) -> Union[int, str]:
            return "foo"

        self.checkScript(fn, (1,))

    def test_union_as_annotation(self):
        def fn() -> Union[int, str]:
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())

    def test_union_as_annotation_in_typed_container(self):
        def fn() -> None:
            l: List[Union[int, str]] = []
            u1: Union[int, str] = "foo"
            u2: Union[int, str] = 1
            l.append(u1)
            l.append(u2)

        self.checkScript(fn, ())

    def test_union_as_annotation_py2(self):
        def fn():
            # type: () -> Union[int, str]
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())

    def test_union_as_internal_tuple_type(self):
        def fn():
            t: Tuple[Union[int, str], Union[int, str]] = (1, "foo")
            return t

        self.checkScript(fn, ())

    def test_union_variable_can_be_reassigned(self):
        @torch.jit.script
        def aux1(i: int):
            return int(i ** 2)

        @torch.jit.script
        def aux2(s: str):
            return s + s

        def fn() -> Union[int, str]:
            x: Union[int, str] = "foo"
            i: int = 1
            x = i
            y: int = aux1(x)
            z: str = aux2(str(y))
            x = z
            return x

        self.checkScript(fn, ())

    # ------------------------------------------------------------------
    # An existing (non-Union) annotation must not be silently widened
    # ------------------------------------------------------------------

    def test_union_does_not_replace_existing_annotated_type(self):
        def fn():
            x: List[int] = [1, 2, 3]
            x.append("foo")
            return x

        with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
            scripted = torch.jit.script(fn)
            scripted()

    def test_union_does_not_replace_existing_annotated_type_union(self):
        def fn():
            x: List[Union[int, str]] = [1, "foo", 3]
            x.append(2.0)
            return x

        with self.assertRaisesRegex(RuntimeError, "Could not match type float"):
            scripted = torch.jit.script(fn)
            scripted()

    def test_union_does_not_replace_existing_annotated_type_empty_container(self):
        def fn():
            x: List[int] = []
            x.append("foo")
            return x

        with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
            scripted = torch.jit.script(fn)
            scripted()

    # ------------------------------------------------------------------
    # Canonicalization of the Union type as printed in the graph
    # ------------------------------------------------------------------

    def test_unions_of_unions_are_flattened(self):
        @torch.jit.script
        def fn(x: Union[Union[int, str], float]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(float, int, str)").run(s)

    def test_unions_of_a_single_argument_vanish(self):
        @torch.jit.script
        def fn(x: Union[int]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : int").run(s)

    def test_union_redundant_arguments_are_skipped(self):
        @torch.jit.script
        def fn(x: Union[int, str, int]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(int, str)").run(s)

    def test_union_redundant_arguments_are_skipped_optional(self):
        @torch.jit.script
        def fn(x: Union[int, Optional[float], Optional[int]]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(float, int, NoneType)").run(s)

    def test_union_redundant_arguments_are_skipped_subtyping(self):
        @torch.jit.script
        def fn(x: Union[str, Tuple[Optional[int], int], Tuple[int, int]]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union((int?, int), str)").run(s)

    def test_union_redundant_arguments_are_skipped_container(self):
        @torch.jit.script
        def fn(x: Union[List[str], List[float], List[str]]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(float[], str[])").run(s)

    def test_union_argument_order_is_ignored(self):
        @torch.jit.script
        def fn1(x: Union[int, str]) -> str:
            return "foo"

        @torch.jit.script
        def fn2(x: Union[str, int]) -> str:
            return "foo"

        # Both declaration orders must print identically.
        for s in (fn1.graph, fn2.graph):
            FileCheck().check("x : Union(int, str)").run(s)

    def test_union_argument_order_is_ignored_container(self):
        @torch.jit.script
        def fn1(x: Union[List[str], List[int]]) -> str:
            return "foo"

        @torch.jit.script
        def fn2(x: Union[List[int], List[str]]) -> str:
            return "foo"

        for s in (fn1.graph, fn2.graph):
            FileCheck().check("x : Union(int[], str[])").run(s)

    def test_union_T_None_is_equivalent_to_optional_T(self):
        # Direction 1: the callee is declared with `Union[int, None]` and
        # the caller passes `Optional[int]` values.
        @torch.jit.script
        def inner(x: Union[int, None]) -> int:
            if x is not None:
                return x
            else:
                return 5

        @torch.jit.script
        def fn1() -> int:
            a: Optional[int] = 5
            b: Optional[int] = None
            a_ = inner(a)
            b_ = inner(b)
            return a_ + b_

        self.assertEqual(fn1(), 10)

        # Direction 2: the callee is declared with `Optional[int]` and the
        # caller passes `Union[int, None]` values. `fn2` must call
        # `inner2`; calling `inner` here would only repeat direction 1 and
        # leave `inner2` unexercised.
        @torch.jit.script
        def inner2(x: Optional[int]) -> int:
            if x is not None:
                return x
            else:
                return 5

        @torch.jit.script
        def fn2() -> int:
            a: Union[int, None] = 5
            b: Union[int, None] = None
            a_ = inner2(a)
            b_ = inner2(b)
            return a_ + b_

        self.assertEqual(fn2(), 10)

    def test_union_optional_of_union_is_flattened(self):
        @torch.jit.script
        def fn(flag: int) -> Union[str, int, None]:
            y: Union[int, str, None] = "foo"
            if flag == 0:
                x: Optional[Union[int, str]] = y
            elif flag == 1:
                x: Optional[Union[int, str]] = 1
            else:
                x: Optional[Union[int, str]] = None
            return x

        # Can't use `checkScript` because it will flag the fact that
        # the original code has `Optional[Union[int, str]]` but the
        # saved/loaded code has `Union[int, NoneType, str]` (even
        # though this is exactly what we want)
        self.assertEqual(fn(0), "foo")
        self.assertEqual(fn(1), 1)
        self.assertEqual(fn(2), None)

        # Round-trip through save/load and inspect the reloaded code.
        buffer = io.BytesIO()
        torch.jit.save(fn, buffer)
        buffer = io.BytesIO(buffer.getvalue())
        l = torch.jit.load(buffer)

        s = l.code

        FileCheck().check("Union[int, NoneType, str]") \
                   .check("Union[int, NoneType, str]") \
                   .run(s)

    def test_union_subclasses_larger_union(self):
        def fn() -> Union[int, str, torch.Tensor]:
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())

    # ------------------------------------------------------------------
    # Union inside containers and modules
    # ------------------------------------------------------------------

    # TODO: We would like to eventually support this. The issue is being
    # tracked at https://github.com/pytorch/pytorch/issues/58167
    def test_union_as_dict_key(self):
        def fn():
            x: Dict[Union[int, str], str] = {}
            x["foo"] = "bar"
            x[1] = 2
            return x[1]

        with self.assertRaisesRegex(RuntimeError, "only int, float, "
                                    "complex, Tensor, device and string keys "
                                    "are supported"):
            torch.jit.script(fn)

    def test_union_as_dict_value(self):
        def fn():
            x: Dict[str, Union[int, str]] = {}
            x["foo"] = "bar"
            x["baz"] = 2
            return x["baz"]

        self.checkScript(fn, ())

    def test_union_module_with_union_instance_variable(self):
        class M(torch.nn.Module):

            x: Union[int, str]

            def __init__(self, x: Union[int, str]):
                super().__init__()
                self.x: Union[int, str] = x

            def forward(self, y: Union[int, str]):
                self.x = y
                return self.x

        self.checkModule(M(2,), (1,))
        self.checkModule(M("bar"), ("foo",))

    def test_union_module_with_union_class_variable(self):
        class M(torch.nn.Module):
            x: Union[int, str] = "foo"

            def __init__(self, y: int):
                super().__init__()
                x = y

            def forward(self, z: str):
                x = z
                return x

        self.checkModule(M(1), ("foo",))

    # ------------------------------------------------------------------
    # Type refinement via isinstance
    # ------------------------------------------------------------------

    def test_union_type_refinement(self):
        def fn(x: Union[int, str]) -> str:
            if isinstance(x, str):
                z = x + "bar"
                return x
            else:
                return "baz"

        self.checkScript(fn, ("foo",))
        self.checkScript(fn, (1,))

    def test_union_type_refinement_union_rhs(self):
        def fn(x: int) -> str:
            if torch.jit.isinstance(x, Union[int, str]):
                return "bar"
            else:
                return "baz"

        self.checkScript(fn, (1,))

    def test_union_type_refinement_tuple_rhs(self):
        def fn(x: Union[int, float, List[str]]) -> str:
            if isinstance(x, (int, float)):
                if isinstance(x, int):
                    return str(x)
                else:
                    return "foo"
            else:
                if len(x):
                    return x[0]
                else:
                    return "bar"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (1.0,))
        self.checkScript(fn, (["a", "b", "c"],))

    def test_union_type_refinement_tuple_rhs_noncontained_type(self):
        def fn(x: Union[int, List[str]]) -> str:
            if isinstance(x, (int, float)):
                y = x + x
                return str(y)
            else:
                if len(x):
                    return x[0]
                else:
                    return "bar"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (["a", "b", "c"],))

    def test_union_type_refinement_tuple_rhs_union(self):
        @torch.jit.script
        def fn(x: int) -> str:
            if torch.jit.isinstance(x, (Union[int, str], float)):
                y = x + x
                return str(y)
            else:
                return "foo"

        # TODO: There's currently an unrelated bug in
        # `torch.jit.isinstance` that makes it fail for tuple literals.
        # Posted here: https://github.com/pytorch/pytorch/issues/60095
        # Change `assertEqual` to `checkScript` when the bug is fixed
        self.assertEqual(fn(1), "2")

    def test_union_type_refinement_statically_false(self):
        @torch.jit.script
        def fn(x: int) -> str:
            if torch.jit.isinstance(x, (Union[str, float], List[str], str)):
                z = x + "foo"
                return z
            else:
                return "bar"

        s = fn.graph

        # Check that we don't have any branching statements
        FileCheck().check_not("block0()") \
                   .check_not("block1()") \
                   .run(s)

    def test_union_type_refinement_statically_true(self):
        @torch.jit.script
        def fn(x: Union[List[int], int]) -> Union[List[int], int]:
            if not torch.jit.isinstance(x, (int, List[int])):
                return x
            else:
                l = [1, 2, 3]
                y: Union[List[int], int] = l
                return y

        s = fn.graph

        # Check that we don't have any branching statements
        FileCheck().check_not("block0()") \
                   .check_not("block1()") \
                   .run(s)

    def test_union_type_refinement_partial_static_refinement_tuple_rhs(self):
        def fn(x: Union[List[int], int]) -> int:
            if torch.jit.isinstance(x, (int, float, str)):
                # We should know that `x` is an `int` here
                z = x + 1
                return z
            else:
                return 100

        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn, (1,))

    def test_union_type_refinement_partial_static_refinement_union_rhs(self):
        def fn(x: Union[List[int], int]) -> int:
            if torch.jit.isinstance(x, Union[int, float, str]):
                # We should know that `x` is an `int` here
                z = x + 1
                return z
            else:
                return 100

        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn, (1,))

    def test_union_type_refinement_internal_declaration(self):
        def fn(flag: bool) -> str:
            x: Union[int, str, None] = None
            if (flag):
                y = "foo"
            else:
                y = 1
            if isinstance(x, str):
                return x
            else:
                return "bar"

        self.checkScript(fn, (True,))
        self.checkScript(fn, (False,))

    # ------------------------------------------------------------------
    # Branching
    # ------------------------------------------------------------------

    def test_union_branching_with_union_return_and_homogenous_types(self):
        def fn(x: int) -> Union[int, str]:
            if x % 2:
                return "foo"
            else:
                return "bar"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (8,))

    def test_union_branching_does_not_autoinfer_undeclared_union(self):
        def fn(x: int) -> str:
            if x % 2:
                y = "foo"
            else:
                y = x
            if isinstance(y, str):
                return y
            else:
                return "bar"

        with self.assertRaisesRegex(RuntimeError, "y is set to type str"
                                    " in the true branch and type int "
                                    "in the false branch"):
            torch.jit.script(fn)

    def test_union_branching_does_not_widen_existing_inferred_type(self):
        def fn(x: int) -> str:
            y = "foo"
            if x % 2:
                y = "bar"
            else:
                y = x
            if isinstance(y, str):
                return y
            else:
                return "baz"

        with self.assertRaisesRegex(RuntimeError, "previously had type "
                                    "str but is now being assigned to a"
                                    " value of type int"):
            torch.jit.script(fn)

    def test_union_schema_matching_on_internal_type(self):
        def fn(x: Union[List[int], Dict[str, int]]) -> int:
            if torch.jit.isinstance(x, List[int]):
                return x[0]
            else:
                return list(x.values())[0]

        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))

    def test_union_subtractive_refinement(self):
        def fn(x: Union[List[int], int]) -> int:
            if not isinstance(x, int):
                x.append(1)
                return x[0]
            else:
                return x

        self.checkScript(fn, (1,))
        self.checkScript(fn, ([1, 2, 3],))

    def test_union_subtractive_refinement_with_container(self):
        def fn(x: Union[List[int], int]) -> int:
            if not torch.jit.isinstance(x, List[int]):
                return x
            else:
                x.append(1)
                return x[0]

        self.checkScript(fn, (1,))
        self.checkScript(fn, ([1, 2, 3],))

    def test_union_memory_aliasing(self):
        def fn():
            x: List[torch.Tensor] = []
            z: List[Optional[List[torch.Tensor]]] = []
            z.append(x)
            x_alias = z[0]
            if torch.jit.isinstance(x_alias, List[torch.Tensor]):
                x_alias.append(torch.tensor(3))
            return x

        self.checkScript(fn, ())

    def test_union_serialization_preserves_type_annotations(self):
        # This function will fail after being torch.jit.save'd and
        # torch.jit.load'd if the type annotations aren't preserved
        # for Union during serialization. We need the `Union[str, int]`
        # annotation to make sure that `y` is typed as a Union instead
        # of as a str in one branch and an int in the other
        def fn(x: int) -> str:
            if x % 2:
                y: Union[str, int] = "bar"
            else:
                y: Union[str, int] = x
            if isinstance(y, str):
                return y
            else:
                return "baz"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (8,))

    # ------------------------------------------------------------------
    # String-frontend helpers used by the assignment tests below
    # ------------------------------------------------------------------

    def _assert_passes(self, template: str, ann: str, lhs: str):
        # Fill `{ann}` / `{lhs}` in `template` and run the resulting
        # source through `checkScript` as the function named "fn".
        code = template.format(ann=ann, lhs=lhs)
        self.checkScript(code, (), name="fn")

    def _assert_raises(self, template: str, ann: str, lhs: str, msg: str):
        # Fill `{ann}` / `{lhs}` in `template` and require that building
        # a CompilationUnit from it (and looking up "fn") raises a
        # RuntimeError matching the regex `msg`.
        code = template.format(ann=ann, lhs=lhs)
        with self.assertRaisesRegex(RuntimeError, msg):
            cu = torch.jit.CompilationUnit(code, _frames_up=1)
            string_frontend = getattr(cu, "fn")  # noqa: B009

    def test_union_with_list_assignment(self):
        # The body of `fn` must keep its relative indentation: this text
        # is compiled as TorchScript source after `dedent`.
        template = dedent('''
            def fn():
                x: {ann} = {lhs}
                if torch.jit.isinstance(x, List[torch.Tensor]):
                    x.append(torch.tensor(3))
                return x
        ''')

        lhs = {
            "list_literal_empty": "[]",
            "list_literal_of_tensor": "[torch.arange(3), torch.arange(5)]",
            "list_literal_of_str": '["foo", "bar", "baz"]',
            "list_literal_of_mixed": "[torch.arange(5), 1]",
            "list_comprehension_of_tensor":
                "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]",
            "list_comprehension_of_str":
                '[x + "!" for x in ["foo", "bar", "baz"]]',
            "list_comprehension_of_mixed":
                "[torch.add(1, x) for x in [torch.arange(5), 1]]",
        }

        #
        # Union[List[str], List[torch.Tensor]]
        #
        self._assert_raises(template,
                            "Union[List[str], List[torch.Tensor]]",
                            lhs["list_literal_empty"],
                            "there are multiple possible List type "
                            "candidates in the Union annotation")

        self._assert_passes(template,
                            "Union[List[str], List[torch.Tensor]]",
                            lhs["list_literal_of_tensor"])

        self._assert_passes(template,
                            "Union[List[str], List[torch.Tensor]]",
                            lhs["list_literal_of_str"])

        self._assert_raises(template,
                            "Union[List[str], List[torch.Tensor]]",
                            lhs["list_literal_of_mixed"],
                            "none of those types match the types of the"
                            " given list elements")

        self._assert_passes(template,
                            "Union[List[str], List[torch.Tensor]]",
                            lhs["list_comprehension_of_tensor"])

        self._assert_passes(template,
                            "Union[List[str], List[torch.Tensor]]",
                            lhs["list_comprehension_of_str"])

        # TODO: Support mixed list comprehensions
        self._assert_raises(template,
                            "Union[List[str], List[torch.Tensor]]",
                            lhs["list_comprehension_of_mixed"],
                            "Arguments for call are not valid")

        #
        # Union[int, torch.Tensor]
        #
        self._assert_raises(template,
                            "Union[int, torch.Tensor]",
                            lhs["list_literal_empty"],
                            "Expected an Union type annotation with an "
                            "inner List type")

        self._assert_raises(template, "Union[int, torch.Tensor]",
                            lhs["list_literal_of_tensor"],
                            "Expected an Union type annotation with an "
                            "inner List type")

        self._assert_raises(template, "Union[int, torch.Tensor]",
                            lhs["list_comprehension_of_tensor"],
                            "Expected an Union type annotation with an "
                            "inner List type")

        #
        # Union[List[torch.Tensor], int]
        #
        self._assert_passes(template,
                            "Union[List[torch.Tensor], int]",
                            lhs["list_literal_empty"])

        self._assert_passes(template,
                            "Union[List[torch.Tensor], int]",
                            lhs["list_literal_of_tensor"])

        self._assert_raises(template, "Union[List[torch.Tensor], int]",
                            lhs["list_literal_of_str"],
                            r"List type annotation `List\[Tensor\]` did "
                            "not match the types of the given list "
                            "elements")

        self._assert_raises(template, "Union[List[torch.Tensor], int]",
                            lhs["list_literal_of_mixed"],
                            r"List type annotation `List\[Tensor\]` did "
                            "not match the types of the given list "
                            "elements")

        self._assert_passes(template,
                            "Union[List[torch.Tensor], int]",
                            lhs["list_comprehension_of_tensor"])

        self._assert_raises(template,
                            "Union[List[torch.Tensor], int]",
                            lhs["list_comprehension_of_str"],
                            r"List type annotation `List\[Tensor\]` did "
                            "not match the types of the given list "
                            "elements")

        # TODO(@ansley): Support mixed list comprehensions
        self._assert_raises(template,
                            "Union[List[torch.Tensor], int]",
                            lhs["list_comprehension_of_mixed"],
                            "Arguments for call are not valid")

    def test_union_with_dict_assignment(self):
        # The body of `fn` must keep its relative indentation: this text
        # is compiled as TorchScript source after `dedent`.
        template = dedent('''
            def fn():
                x: {ann} = {lhs}
                if torch.jit.isinstance(x, Dict[str, torch.Tensor]):
                    x["foo"] = torch.tensor(3)
                return x
        ''')

        # Several fixtures are currently referenced only from the
        # commented-out assertions below; they are kept well-formed so
        # those assertions can be re-enabled as-is.
        lhs = {
            "dict_literal_empty": "{}",
            "dict_literal_of_str_tensor":
                '{"foo" : torch.arange(3), "bar" : torch.arange(5)}',
            "dict_literal_of_str_int":
                '{"foo" : 1, "bar" : 2}',
            "dict_literal_of_mixed":
                '{"foo" : torch.arange(3), "bar" : 2}',
            "dict_comprehension_of_str_tensor": (
                '{x : torch.add(y, 1) for x, y in '
                'zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])}'
            ),
            # Closing parenthesis of `zip(...)` was missing.
            "dict_comprehension_of_str_int": (
                '{x : torch.add(y, 1) for x, y in '
                'zip(["foo", "bar"], [1, 2])}'
            ),
            "dict_comprehension_of_mixed": (
                '{x : torch.add(y, 1) for x, y in '
                'zip(["foo", "bar"], [torch.arange(3), 2])}'
            ),
            "dict_keyword":
                "dict(foo=torch.arange(3), baz=torch.arange(5))",
            "dict_keyword_with_iterable":
                'dict([("foo", torch.arange(3)), ("bar", torch.arange(5))])',
            "dict_keyword_with_empty_iterable":
                "dict([])",
            # Closing parenthesis of `dict(...)` was missing.
            "dict_keyword_with_internal_aggregate_function":
                'dict(zip(["foo", "bar"], [torch.arange(3), torch.arange(5)]))',
            "dict_keyword_with_mapping":
                'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)})',
            "dict_keyword_with_mapping_and_kwargs":
                'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)}, baz=torch.arange(7))',
        }

        #
        # Union[Dict[str, torch.Tensor], Dict[str, int]]
        #
        # NOTE(review): this first case uses a List-only annotation even
        # though the section is about Dict unions -- confirm whether that
        # is intended before changing the annotation or expected message.
        self._assert_raises(template,
                            "Union[List[str], List[torch.Tensor]]",
                            lhs["dict_literal_empty"],
                            "Expected an Union type annotation with an "
                            "inner Dict type")

        self._assert_passes(template,
                            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
                            lhs["dict_literal_of_str_tensor"])

        self._assert_passes(template,
                            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
                            lhs["dict_literal_of_str_int"])

        self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
                            lhs["dict_literal_of_mixed"],
                            "none of those dict types can hold the "
                            "types of the given keys and values")

        # TODO: String frontend does not support tuple unpacking
        # https://github.com/pytorch/pytorch/issues/64096
        # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
        #              lhs["dict_comprehension_of_str_tensor"])

        # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
        #              lhs["dict_comprehension_of_str_int"])

        # self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
        #              lhs["dict_comprehension_of_mixed"],
        #              "foobar")

        # self._assert_passes(template,
        #                    "Union[Dict[str, torch.Tensor], Dict[str, int]]",
        #                    lhs["dict_keyword_with_internal_aggregate_function"])

        # TODO(@ansley): Follow-up project needed for full type
        # inference with dict keyword (supported for dict comprehension
        # and dict literal already; should not be a blocker for anyone)
        self._assert_raises(template,
                            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
                            lhs["dict_keyword"],
                            "full type inference is not yet supported")

        self._assert_raises(template,
                            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
                            lhs["dict_keyword_with_iterable"],
                            "full type inference is not yet supported")

        self._assert_raises(template,
                            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
                            lhs["dict_keyword_with_empty_iterable"],
                            "full type inference is not yet supported")

        self._assert_raises(template,
                            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
                            lhs["dict_keyword_with_mapping"],
                            "full type inference is not yet supported")

        self._assert_raises(template,
                            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
                            lhs["dict_keyword_with_mapping_and_kwargs"],
                            "full type inference is not yet supported")

        #
        # Union[int, torch.Tensor]
        #
        self._assert_raises(template,
                            "Union[int, torch.Tensor]",
                            lhs["dict_literal_empty"],
                            "Expected an Union type annotation with "
                            "an inner Dict type")

        self._assert_raises(template,
                            "Union[int, torch.Tensor]",
                            lhs["dict_literal_of_str_tensor"],
                            "Expected an Union type annotation with "
                            "an inner Dict type")

        # See above--string frontend does not support tuple unpacking
        # self._assert_raises(template, "Union[int, torch.Tensor]",
        #              lhs["dict_comprehension_of_str_tensor"],
        #              "foobar")

        #
        # Union[Dict[str, torch.Tensor], int]
        #
        self._assert_passes(template,
                            "Union[Dict[str, torch.Tensor], int]",
                            lhs["dict_literal_empty"])

        self._assert_passes(template,
                            "Union[Dict[str, torch.Tensor], int]",
                            lhs["dict_literal_of_str_tensor"])

        self._assert_raises(template,
                            "Union[Dict[str, torch.Tensor], int]",
                            lhs["dict_literal_of_str_int"],
                            "Type annotation was inferred to be "
                            r"`Dict\[str, Tensor\]`, but the type of "
                            "values given by the dict literal is")

        self._assert_raises(template,
                            "Union[Dict[str, torch.Tensor], int]",
                            lhs["dict_literal_of_mixed"],
                            "Type annotation was inferred to be "
                            r"`Dict\[str, Tensor\]`, but the type of "
                            "values given by the dict literal is")

        self._assert_passes(template,
                            "Union[Dict[str, torch.Tensor], int]",
                            lhs["dict_keyword"])

        self._assert_passes(template,
                            "Union[Dict[str, torch.Tensor], int]",
                            lhs["dict_keyword_with_iterable"])

        self._assert_passes(template,
                            "Union[Dict[str, torch.Tensor], int]",
                            lhs["dict_keyword_with_empty_iterable"])

        self._assert_passes(template,
                            "Union[Dict[str, torch.Tensor], int]",
                            lhs["dict_keyword_with_mapping"])

        self._assert_passes(template,
                            "Union[Dict[str, torch.Tensor], int]",
                            lhs["dict_keyword_with_mapping_and_kwargs"])

        # See above--string frontend does not support tuple unpacking
        # self._assert_passes(template,
        #                    "Union[Dict[str, torch.Tensor], int]",
        #                    lhs["dict_keyword_with_internal_aggregate_function"])
        #
        # self._assert_passes(template,
        #                    "Union[Dict[str, torch.Tensor], int]",
        #                    lhs["dict_comprehension_of_str_tensor"])

        # self._assert_raises(template,
        #                    "Union[Dict[str, torch.Tensor], int]",
        #                    lhs["dict_comprehension_of_str_int"],
        #                    "foobar")

        # self._assert_raises(template,
        #                    "Union[Dict[str, torch.Tensor], int]",
        #                    lhs["dict_comprehension_of_mixed"],
        #                    "foobar")