# Owner(s): ["oncall: jit"]

import io
import os
import sys

import torch
from torch.testing import FileCheck
from enum import Enum
from textwrap import dedent
from typing import Dict, List, Optional, Tuple, Union

# Make the helper files in test/ importable: take the directory two levels
# above this file (after resolving symlinks with `realpath`; presumably the
# repository's `test/` directory) and append it to `sys.path`.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, make_global

# This module only defines test cases. Running it as a script fails fast and
# tells the user to go through test/test_jit.py instead.
if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")

class TestUnion(JitTestCase):
    """
    This class tests the functionality of `Union`.

    Note: It's important to be able to refine the type of a `Union` to
    one of its internal types. Currently, there are differences in the
    way Python expects `isinstance` checks and the way TorchScript
    expects `isinstance` checks. This means that we can't always use
    `checkScript` in our test cases, because either the eager-mode or
    the script-mode version of the function would fail to run. Those
    test cases instead use separate but equivalent functions (or plain
    `assertEqual` calls) to emulate `checkScript`.
    """

    def test_check_union_annotation(self):
        def test_func(a: Union[int, float], b: Optional[int]):
            return 0

        scripted_func = torch.jit.script(test_func)
        graph_rep = str(scripted_func.graph)
        code_rep = str(scripted_func.code)
        # TS graph IR for Union should be annotated as Union()
        FileCheck().check("Union(").check("int?").run(graph_rep)
        # Serialized code for Union should be annotated as Union[]
        FileCheck().check("Union[").check("Optional[int]").run(code_rep)
        self.checkScript(test_func, (5, 6))
        # this shouldn't error out
        torch._C.parse_ir(str(scripted_func.graph))

    def test_union_with_scalar_values(self):
        def fn(x: Union[int, float]) -> str:
            return "foo"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (1.0,))

        scripted = torch.jit.script(fn)

        with self.assertRaisesRegex(RuntimeError, "Expected a member of"
                                    r" Union\[float, int\] but "
                                    "instead found type str"):
            scripted("1")

    def test_union_with_collections(self):
        def fn(x: Union[Dict[str, int], List[int]]) -> str:
            return "foo"

        self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))
        self.checkScript(fn, ([1, 2, 3],))

        scripted = torch.jit.script(fn)

        with self.assertRaisesRegex(RuntimeError, "Expected a member of"
                                    r" Union\[List\[int\], Dict\[str, "
                                    r"int\]\] but instead found type "
                                    r"Dict\[str, str\]"):
            scripted({"foo": "bar", "baz": "qux"})

        with self.assertRaisesRegex(RuntimeError, "Expected a member of"
                                    r" Union\[List\[int\], Dict\[str, "
                                    r"int\]\] but instead found type "
                                    r"List\[str\]"):
            scripted(["foo", "bar", "baz"])

        with self.assertRaisesRegex(RuntimeError, "Expected a member of"
                                    r" Union\[List\[int\], Dict\[str, "
                                    r"int\]\] but instead found type "
                                    "str"):
            scripted("1")

    def test_union_with_enum(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        def fn(x: Union[str, Color]) -> str:
            return "foo"

        self.checkScript(fn, (Color.RED,))
        self.checkScript(fn, ("red",))

        scripted = torch.jit.script(fn)

        with self.assertRaisesRegex(RuntimeError, "Expected a member of"
                                    r" Union\[__torch__.jit.test_union."
                                    r"Color, str\] but instead found "
                                    "type int"):
            scripted(1)

    def test_union_in_class_constructor(self):

        @torch.jit.script  # noqa: B903
        class A(object):    # noqa: B903
            def __init__(self, x: Union[int, str]) -> None:
                self.x = x

        def fn(x: Union[str, int]) -> A:
            return A(x)

        self.assertEqual(fn("foo").x, "foo")
        self.assertEqual(fn(1).x, 1)

        scripted = torch.jit.script(fn)

        with self.assertRaisesRegex(RuntimeError, "Expected a member of"
                                    r" Union\[int, str\] but instead "
                                    r"found type List\[str\]"):
            scripted(["foo", "bar", "baz"])

    def test_union_return_type(self):
        def fn(x: int) -> Union[int, str]:
            return "foo"

        self.checkScript(fn, (1,))

    def test_union_as_annotation(self):
        def fn() -> Union[int, str]:
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())

    def test_union_as_annotation_in_typed_container(self):
        def fn() -> None:
            l: List[Union[int, str]] = []
            u1: Union[int, str] = "foo"
            u2: Union[int, str] = 1
            l.append(u1)
            l.append(u2)

        self.checkScript(fn, ())

    def test_union_as_annotation_py2(self):
        def fn():
            # type: () -> Union[int, str]
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())

    def test_union_as_internal_tuple_type(self):
        def fn():
            t: Tuple[Union[int, str], Union[int, str]] = (1, "foo")
            return t

        self.checkScript(fn, ())

    def test_union_variable_can_be_reassigned(self):
        @torch.jit.script
        def aux1(i: int):
            return int(i ** 2)

        @torch.jit.script
        def aux2(s: str):
            return s + s

        def fn() -> Union[int, str]:
            x: Union[int, str] = "foo"
            i: int = 1
            x = i
            y: int = aux1(x)
            z: str = aux2(str(y))
            x = z
            return x

        self.checkScript(fn, ())

    def test_union_does_not_replace_existing_annotated_type(self):
        def fn():
            x: List[int] = [1, 2, 3]
            x.append("foo")
            return x

        with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
            scripted = torch.jit.script(fn)
            scripted()

    def test_union_does_not_replace_existing_annotated_type_union(self):
        def fn():
            x: List[Union[int, str]] = [1, "foo", 3]
            x.append(2.0)
            return x

        with self.assertRaisesRegex(RuntimeError, "Could not match type float"):
            scripted = torch.jit.script(fn)
            scripted()

    def test_union_does_not_replace_existing_annotated_type_empty_container(self):
        def fn():
            x: List[int] = []
            x.append("foo")
            return x

        with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
            scripted = torch.jit.script(fn)
            scripted()

    def test_unions_of_unions_are_flattened(self):
        @torch.jit.script
        def fn(x: Union[Union[int, str], float]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(float, int, str)")    \
                   .run(s)

    def test_unions_of_a_single_argument_vanish(self):
        @torch.jit.script
        def fn(x: Union[int]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : int")    \
                   .run(s)

    def test_union_redundant_arguments_are_skipped(self):
        @torch.jit.script
        def fn(x: Union[int, str, int]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(int, str)")    \
                   .run(s)

    def test_union_redundant_arguments_are_skipped_optional(self):
        @torch.jit.script
        def fn(x: Union[int, Optional[float], Optional[int]]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(float, int, NoneType)")    \
                   .run(s)

    def test_union_redundant_arguments_are_skipped_subtyping(self):
        @torch.jit.script
        def fn(x: Union[str, Tuple[Optional[int], int], Tuple[int, int]]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union((int?, int), str)")    \
                   .run(s)

    def test_union_redundant_arguments_are_skipped_container(self):
        @torch.jit.script
        def fn(x: Union[List[str], List[float], List[str]]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(float[], str[])")     \
                   .run(s)

    def test_union_argument_order_is_ignored(self):
        @torch.jit.script
        def fn1(x: Union[int, str]) -> str:
            return "foo"

        @torch.jit.script
        def fn2(x: Union[str, int]) -> str:
            return "foo"

        for s in (fn1.graph, fn2.graph):
            FileCheck().check("x : Union(int, str)")     \
                .run(s)

    def test_union_argument_order_is_ignored_container(self):
        @torch.jit.script
        def fn1(x: Union[List[str], List[int]]) -> str:
            return "foo"

        @torch.jit.script
        def fn2(x: Union[List[int], List[str]]) -> str:
            return "foo"

        for s in (fn1.graph, fn2.graph):
            FileCheck().check("x : Union(int[], str[])")     \
                .run(s)

    def test_union_T_None_is_equivalent_to_optional_T(self):
        """Check that `Union[T, None]` and `Optional[T]` are interchangeable
        in both directions.

        `fn1` passes `Optional[int]` values to a `Union[int, None]`
        parameter (`inner`). `fn2` passes `Union[int, None]` values to an
        `Optional[int]` parameter (`inner2`).
        """
        @torch.jit.script
        def inner(x: Union[int, None]) -> int:
            if x is not None:
                return x
            else:
                return 5

        @torch.jit.script
        def fn1() -> int:
            a: Optional[int] = 5
            b: Optional[int] = None
            a_ = inner(a)
            b_ = inner(b)
            return a_ + b_

        self.assertEqual(fn1(), 10)

        @torch.jit.script
        def inner2(x: Optional[int]) -> int:
            if x is not None:
                return x
            else:
                return 5

        # Reverse direction: this must call `inner2` (the Optional[int]
        # parameter). Calling `inner` here would re-test the first case.
        @torch.jit.script
        def fn2() -> int:
            a: Union[int, None] = 5
            b: Union[int, None] = None
            a_ = inner2(a)
            b_ = inner2(b)
            return a_ + b_

        self.assertEqual(fn2(), 10)

    def test_union_optional_of_union_is_flattened(self):
        @torch.jit.script
        def fn(flag: int) -> Union[str, int, None]:
            y: Union[int, str, None] = "foo"
            if flag == 0:
                x: Optional[Union[int, str]] = y
            elif flag == 1:
                x: Optional[Union[int, str]] = 1
            else:
                x: Optional[Union[int, str]] = None
            return x

        # Can't use `checkScript` because it will flag the fact that
        # the original code has `Optional[Union[int, str]]` but the
        # saved/loaded code has `Union[int, NoneType, str]` (even
        # though this is exactly what we want)
        self.assertEqual(fn(0), "foo")
        self.assertEqual(fn(1), 1)
        self.assertEqual(fn(2), None)

        # Round-trip through save/load and inspect the serialized code
        buffer = io.BytesIO()
        torch.jit.save(fn, buffer)
        buffer = io.BytesIO(buffer.getvalue())
        loaded = torch.jit.load(buffer)

        s = loaded.code

        FileCheck().check("Union[int, NoneType, str]")     \
                   .check("Union[int, NoneType, str]")     \
                   .run(s)

    def test_union_subclasses_larger_union(self):
        def fn() -> Union[int, str, torch.Tensor]:
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())

    # TODO: We would like to eventually support this. The issue is being
    # tracked at https://github.com/pytorch/pytorch/issues/58167
    def test_union_as_dict_key(self):
        def fn():
            x: Dict[Union[int, str], str] = {}
            x["foo"] = "bar"
            x[1] = 2
            return x[1]

        with self.assertRaisesRegex(RuntimeError, "only int, float, "
                                    "complex, Tensor, device and string keys "
                                    "are supported"):
            torch.jit.script(fn)

    def test_union_as_dict_value(self):
        def fn():
            x: Dict[str, Union[int, str]] = {}
            x["foo"] = "bar"
            x["baz"] = 2
            return x["baz"]

        self.checkScript(fn, ())

    def test_union_module_with_union_instance_variable(self):
        class M(torch.nn.Module):

            x: Union[int, str]

            def __init__(self, x: Union[int, str]):
                super().__init__()
                self.x: Union[int, str] = x

            def forward(self, y: Union[int, str]):
                self.x = y
                return self.x

        self.checkModule(M(2,), (1,))
        self.checkModule(M("bar"), ("foo",))

    def test_union_module_with_union_class_variable(self):
        # NOTE(review): `x = y` in `__init__` and `x = z` in `forward` bind
        # *local* names, not the `x` class attribute. As written, this test
        # only checks that a Module with a Union-annotated class variable
        # scripts and runs; it never assigns to that attribute. Confirm
        # whether `self.x` was intended.
        class M(torch.nn.Module):
            x: Union[int, str] = "foo"

            def __init__(self, y: int):
                super().__init__()
                x = y

            def forward(self, z: str):
                x = z
                return x

        self.checkModule(M(1), ("foo",))

    def test_union_type_refinement(self):
        def fn(x: Union[int, str]) -> str:
            if isinstance(x, str):
                z = x + "bar"
                return x
            else:
                return "baz"

        self.checkScript(fn, ("foo",))
        self.checkScript(fn, (1,))

    def test_union_type_refinement_union_rhs(self):
        def fn(x: int) -> str:
            if torch.jit.isinstance(x, Union[int, str]):
                return "bar"
            else:
                return "baz"

        self.checkScript(fn, (1,))

    def test_union_type_refinement_tuple_rhs(self):
        def fn(x: Union[int, float, List[str]]) -> str:
            if isinstance(x, (int, float)):
                if isinstance(x, int):
                    return str(x)
                else:
                    return "foo"
            else:
                if len(x):
                    return x[0]
                else:
                    return "bar"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (1.0,))
        self.checkScript(fn, (["a", "b", "c"],))

    def test_union_type_refinement_tuple_rhs_noncontained_type(self):
        def fn(x: Union[int, List[str]]) -> str:
            if isinstance(x, (int, float)):
                y = x + x
                return str(y)
            else:
                if len(x):
                    return x[0]
                else:
                    return "bar"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (["a", "b", "c"],))

    def test_union_type_refinement_tuple_rhs_union(self):
        @torch.jit.script
        def fn(x: int) -> str:
            if torch.jit.isinstance(x, (Union[int, str], float)):
                y = x + x
                return str(y)
            else:
                return "foo"

        # TODO: There's currently an unrelated bug in
        # `torch.jit.isinstance` that makes it fail for tuple literals.
        # Posted here: https://github.com/pytorch/pytorch/issues/60095
        # Change `assertEqual` to `checkScript` when the bug is fixed
        self.assertEqual(fn(1), "2")

    def test_union_type_refinement_statically_false(self):
        """The `isinstance` check can never be true for an `int`, so the
        true branch (which would not type-check) must be eliminated."""
        @torch.jit.script
        def fn(x: int) -> str:
            if torch.jit.isinstance(x, (Union[str, float], List[str], str)):
                z = x + "foo"
                return z
            else:
                return "bar"

        s = fn.graph

        # Check that we don't have any branching statements
        FileCheck().check_not("block0()")    \
            .check_not("block1()")           \
            .run(s)

    def test_union_type_refinement_statically_true(self):
        """The `isinstance` check covers every member of the argument's
        Union, so the branch is resolved statically."""
        @torch.jit.script
        def fn(x: Union[List[int], int]) -> Union[List[int], int]:
            if not torch.jit.isinstance(x, (int, List[int])):
                return x
            else:
                l = [1, 2, 3]
                y: Union[List[int], int] = l
                return y

        s = fn.graph

        # Check that we don't have any branching statements
        FileCheck().check_not("block0()")    \
            .check_not("block1()")           \
            .run(s)

    def test_union_type_refinement_partial_static_refinement_tuple_rhs(self):
        def fn(x: Union[List[int], int]) -> int:
            if torch.jit.isinstance(x, (int, float, str)):
                # We should know that `x` is an `int` here
                z = x + 1
                return z
            else:
                return 100

        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn, (1,))

    def test_union_type_refinement_partial_static_refinement_union_rhs(self):
        def fn(x: Union[List[int], int]) -> int:
            if torch.jit.isinstance(x, Union[int, float, str]):
                # We should know that `x` is an `int` here
                z = x + 1
                return z
            else:
                return 100

        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn, (1,))

    def test_union_type_refinement_internal_declaration(self):
        def fn(flag: bool) -> str:
            x: Union[int, str, None] = None
            if (flag):
                y = "foo"
            else:
                y = 1
            if isinstance(x, str):
                return x
            else:
                return "bar"

        self.checkScript(fn, (True,))
        self.checkScript(fn, (False,))

    def test_union_branching_with_union_return_and_homogenous_types(self):
        def fn(x: int) -> Union[int, str]:
            if x % 2:
                return "foo"
            else:
                return "bar"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (8,))

    def test_union_branching_does_not_autoinfer_undeclared_union(self):
        def fn(x: int) -> str:
            if x % 2:
                y = "foo"
            else:
                y = x
            if isinstance(y, str):
                return y
            else:
                return "bar"

        with self.assertRaisesRegex(RuntimeError, "y is set to type str"
                                    " in the true branch and type int "
                                    "in the false branch"):
            torch.jit.script(fn)

    def test_union_branching_does_not_widen_existing_inferred_type(self):
        def fn(x: int) -> str:
            y = "foo"
            if x % 2:
                y = "bar"
            else:
                y = x
            if isinstance(y, str):
                return y
            else:
                return "baz"

        with self.assertRaisesRegex(RuntimeError, "previously had type "
                                    "str but is now being assigned to a"
                                    " value of type int"):
            torch.jit.script(fn)

    def test_union_schema_matching_on_internal_type(self):
        def fn(x: Union[List[int], Dict[str, int]]) -> int:
            if torch.jit.isinstance(x, List[int]):
                return x[0]
            else:
                return list(x.values())[0]

        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))

    def test_union_subtractive_refinement(self):
        def fn(x: Union[List[int], int]) -> int:
            if not isinstance(x, int):
                x.append(1)
                return x[0]
            else:
                return x

        self.checkScript(fn, (1,))
        self.checkScript(fn, ([1, 2, 3],))

    def test_union_subtractive_refinement_with_container(self):
        def fn(x: Union[List[int], int]) -> int:
            if not torch.jit.isinstance(x, List[int]):
                return x
            else:
                x.append(1)
                return x[0]

        self.checkScript(fn, (1,))
        self.checkScript(fn, ([1, 2, 3],))

    def test_union_memory_aliasing(self):
        """Appending through a refined alias (`x_alias`) must be visible
        through the original list `x`, because both refer to the same
        list."""
        def fn():
            x : List[torch.Tensor] = []
            z : List[Optional[List[torch.Tensor]]] = []
            z.append(x)
            x_alias = z[0]
            if torch.jit.isinstance(x_alias, List[torch.Tensor]):
                x_alias.append(torch.tensor(3))
            return x

        self.checkScript(fn, ())

    def test_union_serialization_preserves_type_annotations(self):
        # This function will fail after being torch.jit.save'd and
        # torch.jit.load'd if the type annotations aren't preserved
        # for Union during serialization. We need the `Union[str, int]`
        # annotation to make sure that `y` is typed as a Union instead
        # of as a str in one branch and an int in the other
        def fn(x: int) -> str:
            if x % 2:
                y: Union[str, int] = "bar"
            else:
                y: Union[str, int] = x
            if isinstance(y, str):
                return y
            else:
                return "baz"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (8,))

    def _assert_passes(self, template: str, ann: str, lhs: str):
        """Fill ``{ann}``/``{lhs}`` in ``template`` and check the resulting
        source with ``checkScript``, using the function named ``fn``."""
        code = template.format(ann=ann, lhs=lhs)
        self.checkScript(code, (), name="fn")

    def _assert_raises(self, template: str, ann: str, lhs: str, msg: str):
        """Fill ``{ann}``/``{lhs}`` in ``template`` and assert that compiling
        it as a ``CompilationUnit`` raises a ``RuntimeError`` whose message
        matches the regex ``msg``."""
        code = template.format(ann=ann, lhs=lhs)
        with self.assertRaisesRegex(RuntimeError, msg):
            # `_frames_up=1` controls which Python frame is used to resolve
            # names in `code`; keep this call directly in this method.
            cu = torch.jit.CompilationUnit(code, _frames_up=1)
            getattr(cu, "fn")    # noqa: B009

    def test_union_with_list_assignment(self):
        template = dedent('''
            def fn():
                x: {ann} = {lhs}
                if torch.jit.isinstance(x, List[torch.Tensor]):
                    x.append(torch.tensor(3))
                return x
        ''')

        lhs = {"list_literal_empty" : "[]",

               "list_literal_of_tensor" : "[torch.arange(3), torch.arange(5)]",

               "list_literal_of_str" : "[\"foo\", \"bar\", \"baz\"]",

               "list_literal_of_mixed" : "[torch.arange(5), 1]",

               "list_comprehension_of_tensor" :
               "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]",

               "list_comprehension_of_str" :
               "[x + \"!\" for x in [\"foo\", \"bar\", \"baz\"]]",

               "list_comprehension_of_mixed" :
               "[torch.add(1, x) for x in [torch.arange(5), 1]]"}

        # Union[List[str], List[torch.Tensor]]
        self._assert_raises(template,
                            "Union[List[str], List[torch.Tensor]]",
                            lhs["list_literal_empty"],
                            "there are multiple possible List type "
                            "candidates in the Union annotation")

        self._assert_passes(template,
                            "Union[List[str], List[torch.Tensor]]",
                            lhs["list_literal_of_tensor"])

        self._assert_passes(template,
                            "Union[List[str], List[torch.Tensor]]",
                            lhs["list_literal_of_str"])

        self._assert_raises(template,
                            "Union[List[str], List[torch.Tensor]]",
                            lhs["list_literal_of_mixed"],
                            "none of those types match the types of the"
                            " given list elements")

        self._assert_passes(template,
                            "Union[List[str], List[torch.Tensor]]",
                            lhs["list_comprehension_of_tensor"])

        self._assert_passes(template,
                            "Union[List[str], List[torch.Tensor]]",
                            lhs["list_comprehension_of_str"])

        # TODO: Support mixed list comprehensions
        self._assert_raises(template,
                            "Union[List[str], List[torch.Tensor]]",
                            lhs["list_comprehension_of_mixed"],
                            "Arguments for call are not valid")

        # Union[int, torch.Tensor]
        self._assert_raises(template,
                            "Union[int, torch.Tensor]",
                            lhs["list_literal_empty"],
                            "Expected an Union type annotation with an "
                            "inner List type")

        self._assert_raises(template, "Union[int, torch.Tensor]",
                            lhs["list_literal_of_tensor"],
                            "Expected an Union type annotation with an "
                            "inner List type")

        self._assert_raises(template, "Union[int, torch.Tensor]",
                            lhs["list_comprehension_of_tensor"],
                            "Expected an Union type annotation with an "
                            "inner List type")

        # Union[List[torch.Tensor], int]
        self._assert_passes(template,
                            "Union[List[torch.Tensor], int]",
                            lhs["list_literal_empty"])

        self._assert_passes(template,
                            "Union[List[torch.Tensor], int]",
                            lhs["list_literal_of_tensor"])

        self._assert_raises(template, "Union[List[torch.Tensor], int]",
                            lhs["list_literal_of_str"],
                            r"List type annotation `List\[Tensor\]` did "
                            "not match the types of the given list "
                            "elements")

        self._assert_raises(template, "Union[List[torch.Tensor], int]",
                            lhs["list_literal_of_mixed"],
                            r"List type annotation `List\[Tensor\]` did "
                            "not match the types of the given list "
                            "elements")

        self._assert_passes(template,
                            "Union[List[torch.Tensor], int]",
                            lhs["list_comprehension_of_tensor"])

        self._assert_raises(template,
                            "Union[List[torch.Tensor], int]",
                            lhs["list_comprehension_of_str"],
                            r"List type annotation `List\[Tensor\]` did "
                            "not match the types of the given list "
                            "elements")

        # TODO(@ansley): Support mixed list comprehensions
        self._assert_raises(template,
                            "Union[List[torch.Tensor], int]",
                            lhs["list_comprehension_of_mixed"],
                            "Arguments for call are not valid")

    def test_union_with_dict_assignment(self):
        template = dedent('''
            def fn():
                x: {ann} = {lhs}
                if torch.jit.isinstance(x, Dict[str, torch.Tensor]):
                    x["foo"] = torch.tensor(3)
                return x
        ''')

        lhs = {"dict_literal_empty" : "{}",

               "dict_literal_of_str_tensor" :
               "{\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)}",

               "dict_literal_of_str_int" :
               "{\"foo\" : 1, \"bar\" : 2}",

               "dict_literal_of_mixed" :
               "{\"foo\" : torch.arange(3), \"bar\" : 2}",

               "dict_comprehension_of_str_tensor" :
               "{x : torch.add(y, 1) for x, y in \
                    zip([\"foo\", \"bar\"], [torch.arange(3), torch.arange(5)])}",

               "dict_comprehension_of_str_int" :
               "{x : torch.add(y, 1) for x, y in \
                    zip([\"foo\", \"bar\"], [1, 2])}",

               "dict_comprehension_of_mixed" :
               "{x : torch.add(y, 1) for x, y in \
                    zip([\"foo\", \"bar\"], [torch.arange(3), 2])}",

               "dict_keyword" :
               "dict(foo=torch.arange(3), baz=torch.arange(5))",

               "dict_keyword_with_iterable" :
               "dict([(\"foo\", torch.arange(3)), (\"bar\", torch.arange(5))])",

               "dict_keyword_with_empty_iterable" :
               "dict([])",

               "dict_keyword_with_internal_aggregate_function" :
               "dict(zip([\"foo\", \"bar\"], [torch.arange(3), torch.arange(5)]))",

               "dict_keyword_with_mapping" :
               "dict({\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)})",

               "dict_keyword_with_mapping_and_kwargs" :
               "dict({\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)}, baz=torch.arange(7))",

               }

        # Union[Dict[str, torch.Tensor], Dict[str, int]]

        # NOTE(review): despite the section header, this case annotates with
        # a *List* union, so it only checks the "no inner Dict type" error.
        # A case for an empty dict literal against two Dict candidates is
        # still missing. TODO: confirm the expected error text and add it.
        self._assert_raises(template,
                            "Union[List[str], List[torch.Tensor]]",
                            lhs["dict_literal_empty"],
                            "Expected an Union type annotation with an "
                            "inner Dict type")

        self._assert_passes(template,
                            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
                            lhs["dict_literal_of_str_tensor"])

        self._assert_passes(template,
                            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
                            lhs["dict_literal_of_str_int"])

        self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
                            lhs["dict_literal_of_mixed"],
                            "none of those dict types can hold the "
                            "types of the given keys and values")

        # TODO: String frontend does not support tuple unpacking
        # https://github.com/pytorch/pytorch/issues/64096
        # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
        #              lhs["dict_comprehension_of_str_tensor"])

        # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
        #              lhs["dict_comprehension_of_str_int"])

        # self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
        #              lhs["dict_comprehension_of_mixed"],
        #              "foobar")

        # self._assert_passes(template,
        #                    "Union[Dict[str, torch.Tensor], Dict[str, int]]",
        #                    lhs["dict_keyword_with_internal_aggregate_function"])

        # TODO(@ansley): Follow-up project needed for full type
        # inference with dict keyword (supported for dict comprehension
        # and dict literal already; should not be a blocker for anyone)
        self._assert_raises(template,
                            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
                            lhs["dict_keyword"],
                            "full type inference is not yet supported")

        self._assert_raises(template,
                            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
                            lhs["dict_keyword_with_iterable"],
                            "full type inference is not yet supported")

        self._assert_raises(template,
                            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
                            lhs["dict_keyword_with_empty_iterable"],
                            "full type inference is not yet supported")

        self._assert_raises(template,
                            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
                            lhs["dict_keyword_with_mapping"],
                            "full type inference is not yet supported")

        self._assert_raises(template,
                            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
                            lhs["dict_keyword_with_mapping_and_kwargs"],
                            "full type inference is not yet supported")

        # Union[int, torch.Tensor]
        self._assert_raises(template,
                            "Union[int, torch.Tensor]",
                            lhs["dict_literal_empty"],
                            "Expected an Union type annotation with "
                            "an inner Dict type")

        self._assert_raises(template,
                            "Union[int, torch.Tensor]",
                            lhs["dict_literal_of_str_tensor"],
                            "Expected an Union type annotation with "
                            "an inner Dict type")

        # See above--string frontend does not support tuple unpacking
        # self._assert_raises(template, "Union[int, torch.Tensor]",
        #              lhs["dict_comprehension_of_str_tensor"],
        #              "foobar")

        # Union[Dict[str, torch.Tensor], int]
        self._assert_passes(template,
                            "Union[Dict[str, torch.Tensor], int]",
                            lhs["dict_literal_empty"])

        self._assert_passes(template,
                            "Union[Dict[str, torch.Tensor], int]",
                            lhs["dict_literal_of_str_tensor"])

        self._assert_raises(template,
                            "Union[Dict[str, torch.Tensor], int]",
                            lhs["dict_literal_of_str_int"],
                            "Type annotation was inferred to be "
                            r"`Dict\[str, Tensor\]`, but the type of "
                            "values given by the dict literal is")

        self._assert_raises(template,
                            "Union[Dict[str, torch.Tensor], int]",
                            lhs["dict_literal_of_mixed"],
                            "Type annotation was inferred to be "
                            r"`Dict\[str, Tensor\]`, but the type of "
                            "values given by the dict literal is")

        self._assert_passes(template,
                            "Union[Dict[str, torch.Tensor], int]",
                            lhs["dict_keyword"])

        self._assert_passes(template,
                            "Union[Dict[str, torch.Tensor], int]",
                            lhs["dict_keyword_with_iterable"])

        self._assert_passes(template,
                            "Union[Dict[str, torch.Tensor], int]",
                            lhs["dict_keyword_with_empty_iterable"])

        self._assert_passes(template,
                            "Union[Dict[str, torch.Tensor], int]",
                            lhs["dict_keyword_with_mapping"])

        self._assert_passes(template,
                            "Union[Dict[str, torch.Tensor], int]",
                            lhs["dict_keyword_with_mapping_and_kwargs"])

        # See above--string frontend does not support tuple unpacking
        # self._assert_passes(template,
        #                    "Union[Dict[str, torch.Tensor], int]",
        #                    lhs["dict_keyword_with_internal_aggregate_function"])
        #
        # self._assert_passes(template,
        #                    "Union[Dict[str, torch.Tensor], int]",
        #                    lhs["dict_comprehension_of_str_tensor"])

        # self._assert_raises(template,
        #                    "Union[Dict[str, torch.Tensor], int]",
        #                    lhs["dict_comprehension_of_str_int"],
        #                    "foobar")

        # self._assert_raises(template,
        #                    "Union[Dict[str, torch.Tensor], int]",
        #                    lhs["dict_comprehension_of_mixed"],
        #                    "foobar")
