# Owner(s): ["oncall: jit"]

import os
import sys

import torch
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import IS_WINDOWS
from collections import namedtuple
from typing import List, Tuple, Optional, Dict

# Make the helper files in test/ importable: append the parent of the directory
# containing this file (after realpath resolves any symlinks) to sys.path.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

# This module only defines test cases. Running it directly is an error whose
# message tells the user to go through test/test_jit.py instead.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

class TestTyping(JitTestCase):
    def test_dict_in_not_in(self):
        def test_in_dict(x):
            # type: (Dict[str, int]) -> bool
            return 'hi' in x

        self.checkScript(test_in_dict, ({'hi': 2, 'bye': 3},))
        self.checkScript(test_in_dict, ({'bye': 3},))

        # Check evaluation order
        @torch.jit.script
        def a():
            print("a")
            return 3

        @torch.jit.script
        def b():
            print("b")
            return {3: 2, 4: 1}

        @torch.jit.script
        def fn():
            return a() in b()

        with self.capture_stdout() as captured:
            self.assertTrue(fn())
        if not IS_WINDOWS:
            # no stdout capturing on windows
            self.assertEqual(captured[0], "a\nb\n")

        def test_not_in_dict(a):
            # type: (Dict[str, int]) -> bool
            if "hello" not in a:
                return False
            else:
                return True

        self.checkScript(test_not_in_dict, ({"hello": 1, "world": 2}, ))
        self.checkScript(test_not_in_dict, ({"world": 2}, ))

        def test_dict_tensor_key(a, t):
            # type: (Dict[Tensor, int], Tensor) -> bool
            if t in a:
                return True
            else:
                return False

        inp1 = torch.tensor(3)
        inp2 = torch.tensor(5)
        dict_a = {inp1: 1, inp2: 3}
        self.checkScript(test_dict_tensor_key, (dict_a, torch.tensor(4)))
        self.checkScript(test_dict_tensor_key, (dict_a, torch.tensor(3)))
        self.checkScript(test_dict_tensor_key, (dict_a, inp1))
        self.checkScript(test_dict_tensor_key, (dict_a, inp2))

    def test_list_type_refinement_annotation_element_mismatch(self):
        def fn():
            l: List[int] = [1, 2, "foo", 3]
            return l

        with self.assertRaisesRegex(RuntimeError, "List type annotation"
                                    r" `List\[int\]` did not match the "
                                    "types of the given list elements"):
            torch.jit.script(fn)

    def test_dict_type_refinement_annotation_key_mismatch(self):
        def fn():
            l1 = [1, 2, "foo", 3]
            l2 = ["foo", "bar", "baz", "qux"]
            d: Dict[int, str] = {k : v for k, v in zip(l1, l2)}
            return d

        with self.assertRaisesRegex(RuntimeError, "Dicts may only "
                                    "contain homogeneous keys, but the "
                                    "type of the first generated key "
                                    r"was Union\[int, str\]"):
            torch.jit.script(fn)

    def test_dict_type_refinement_annotation_value_mismatch(self):
        def fn():
            l1 = ["foo", "bar", "baz", "qux"]
            l2 = [1, 2, "foo", 3]
            d: Dict[str, int] = {k : v for k, v in zip(l1, l2)}
            return d

        with self.assertRaisesRegex(RuntimeError, "Dict type annotation"
                                    r" `Dict\[str, int\]` did not match"
                                    " the type of an actual value type"
                                    r" `Union\[int, str\]`"):
            torch.jit.script(fn)

    def test_dict_invalid_annotations(self):
        # Check for invalid value type annotation
        def wrong_value_type(dictionary: Dict[str, torch.jit.ScriptModule]):
            return
        with self.assertRaisesRegex(ValueError, "Unknown type annotation"):
            torch.jit.script(wrong_value_type)

        # Check for invalid key type annotation
        def wrong_key_type(dictionary: Dict[torch.jit.ScriptModule, str]):
            return
        with self.assertRaisesRegex(ValueError, "Unknown type annotation"):
            torch.jit.script(wrong_key_type)

        # Check for invalid key and value type annotation
        def wrong_key_value_type(dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule]):
            return
        with self.assertRaisesRegex(ValueError, "Unknown type annotation"):
            torch.jit.script(wrong_key_value_type)

    def test_tuple_specialization(self):
        @torch.jit.script
        def f(t, s):
            # type: (Tuple[Tensor, Tuple[int, Tensor]], str) -> Tensor
            x, t2 = t
            _, y = t2
            return x + y

        t = torch.randn(2, 2), (1, torch.randn(2, 2)),
        f(t, "hi")
        graph = f.graph_for(t, "hi")
        input_types = list(next(graph.inputs()).type().elements())
        w = input_types[0]
        self.assertEqual(input_types[0].kind(), 'TensorType')
        self.assertEqual(input_types[1].elements()[1].kind(), 'TensorType')

    def test_tuple_io(self):
        # A Tuple[Tensor, Tensor] is accepted as input and returned (swapped);
        # checkScript compares the scripted result against eager Python.
        def stuff(x):
            # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
            a, b = x
            return b, a

        a = (torch.rand(3), torch.rand(3))
        self.checkScript(stuff, (a,))

    def test_tuple_keyword(self):
        # tuple() applied to a tuple literal compiles and matches eager Python.
        def bar():
            f = tuple((1, 2))  # noqa: C409
            return f

        self.checkScript(bar, ())

        # tuple() called with two positional arguments must be rejected; the
        # error mentions "1 argument".
        def foo():
            return tuple(1, 2)

        self.checkScriptRaisesRegex(foo, (), Exception,
                                    "1 argument")

        # tuple() of a list literal: scripting fails because the resulting
        # tuple's size cannot be statically inferred.
        def cant_infer_size():
            return tuple([1, 2, 3])  # noqa: C409

        with self.assertRaisesRegex(Exception, "cannot statically infer the expected"):
            torch.jit.script(cant_infer_size)

    def test_tuple_create_return(self):
        # A tuple of Tensors built inside the function is returned as
        # Tuple[Tensor, Tensor].
        def stuff2(x):
            # type: (int) -> Tuple[Tensor, Tensor]
            a = (torch.ones(x), torch.zeros(x))
            return a
        self.checkScript(stuff2, (3,))

    def test_list_io(self):
        # A List[int] argument is used as the size for torch.ones and is also
        # returned unchanged.
        def stuff3(x):
            # type: (List[int]) -> Tuple[Tensor, List[int]]
            return torch.ones(x), x
        self.checkScript(stuff3, ([3, 2],))

    def test_bool_list_io(self):
        @torch.jit.script
        def stuff4(x):
            # type: (List[bool]) -> Tuple[List[bool], List[bool], List[List[bool]]]
            return x, [True, False], [[True]]

        li_1, li_2, li_3 = stuff4([True])
        li_3 = li_3[0]
        for li in [li_1, li_2, li_3]:
            self.assertTrue(type(li[0]) == type(True))

    def test_nested_list(self):
        # Unpack a tuple argument and index into its List[List[int]] element.
        def foo(z):
            # type: (Tuple[int, List[List[int]]]) -> int
            x, y = z
            return y[0][1]
        self.checkScript(foo, ((1, [[1, 2], [3, 4]]),))

    def test_list_sum(self):
        def fn(x: List[int]) -> int:
            return sum(x)

        def fn1(x: List[float]):
            return sum(x)

        def fn2(x: List[bool]):
            return sum(x)

        self.checkScript(fn, ([1, 2, 3], ))
        self.checkScript(fn1, ([1.0, 2.0, 3.0], ))
        self.checkScript(fn1, ([1, 2.8, 3], ))
        self.checkScript(fn2, ([True, False, False], ))
        self.checkScript(fn2, ([False, False, False], ))
        self.checkScript(fn2, ([0, 1, 1, 0], ))

    def test_list_unification(self):
        def fn():
            return [1, None, 2]

        def fn2(x):
            return [torch.ones(2, 2), None, x]

        self.checkScript(fn, [])
        self.checkScript(fn2, (torch.ones(2, 2),))

    # to avoid defining sum_list in multiple tests
    def get_sum_list_fn(self):
        # Return a plain Python function that sums a List[int] with an explicit
        # loop. The test_sum_list_* methods pass it to checkScript with
        # different inputs.
        def sum_list(a):
            # type: (List[int]) -> int
            sum = 0
            for i in a:
                sum += i

            return sum

        return sum_list

    def test_sum_list_diff_elms(self):
        self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],))

    def test_sum_list_empty(self):
        self.checkScript(self.get_sum_list_fn(), ([],))

    def test_sum_list_one(self):
        self.checkScript(self.get_sum_list_fn(), ([1],))

    def test_sum_list_literal(self):
        # Iterate directly over a list literal inside the scripted function.

        def sum_list():
            # type: () -> int
            sum = 0
            for i in [1, 2, 3, 4, 5]:
                sum += i

            return sum

        self.checkScript(sum_list, ())

    def test_sum_list_wrong_type(self):

        with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
            @torch.jit.script
            def sum_list(a):
                # type: (int) -> int
                sum = 0
                for i in a:  # noqa: T484
                    sum += i

                return sum

            sum_list(1)

    def test_list_iterables(self):
        with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'):
            cu = torch.jit.CompilationUnit('''
            def list_iterables(x):
                for i, j in [2, 3, 4], [5, 6, 7]:
                    x += i
                    x += j
                return x
            ''')

    def test_for_in_string(self):
        def test_strings(x):
            # type: (str) -> str
            reverse = ""
            for c in x:
                reverse = c + reverse
            return reverse

        self.checkScript(test_strings, ("hello",))
        self.checkScript(test_strings, ("",))

        def test_list_strings(x):
            # type: (List[str]) -> str
            result = ""
            for sub_str in x:
                result += sub_str
            return result

        self.checkScript(test_list_strings, (["hello", "world"],))
        self.checkScript(test_list_strings, (["hello", " ", "world", ""],))

    def test_for_in_dict(self):
        def test_dicts(x):
            # type: (Dict[str, int]) -> int
            sum = 0
            for key in x:
                sum += x[key]
            return sum

        self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},))

        def test_dict_keys_values(x):
            # type: (Dict[str, int]) -> Tuple[str, int]
            key_str = ""
            sum = 0
            for key in x.keys():
                key_str += key
            for val in x.values():
                sum += val
            return key_str, sum

        self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},))

    def test_for_tuple_unpack(self):
        # Each inner list of the iterated list literal is unpacked into `i, j`.
        def for_tuple_unpack(x, y):
            for i, j in [[3, 4], [5, 6], [7, 8]]:
                x += i
                y += j
            return x, y

        self.checkScript(for_tuple_unpack, (torch.tensor(3), torch.tensor(5)))

        # Nested loop targets `i, (j, k), v` over zip() of a list, enumerate()
        # of the same list, and a second list.
        def nested_tuple_unpack(x, y):
            # type: (List[int], List[int]) -> int
            sum = 0
            for i, (j, k), v in zip(x, enumerate(x), y):
                sum += i + j + k + v
            return sum

        self.checkScript(nested_tuple_unpack, ([1, 3, 5], [2, 4, 6]))

    def test_dict_comprehension(self):
        # An unannotated dict comprehension; checkScript compares the scripted
        # result against eager Python.
        def fn():
            return {i : chr(i + 65) for i in range(4)}
        self.checkScript(fn, ())

    def test_dict_comprehension_with_type_annotation(self):
        def fn():
            d: Dict[int, str] = {i : chr(i + 65) for i in range(4)}
            return d
        self.checkScript(fn, ())

        with self.assertRaisesRegex(RuntimeError, ""):
            with self.assertRaisesRegex(AssertionError, "Expected Dict "
                                        "type annotation for dict "
                                        "comprehension, found "
                                        "Tuple[int, str]"):
                @torch.jit.script
                def fn():
                    d: Tuple[int, str] = {i : chr(i + 65) for i in range(4)}
                    return d

    def test_dict_comprehension_scope(self):
        # A dict comprehension can read locals of the enclosing function.
        def comprehension_can_access_outer_scope_variables():
            lst = ["foo", "bar", "baz"]
            return {l : len(l) for l in lst}

        self.checkScript(comprehension_can_access_outer_scope_variables, ())

        # The comprehension's loop variable `i` is not visible after it, so
        # reading `i` afterwards is a compile error ("undefined value i").
        with self.assertRaisesRegex(RuntimeError, "undefined value i"):
            @torch.jit.script
            def outer_scope_cannot_access_comprehension_variables():
                d = {i : chr(i + 65) for i in range(4)}
                i = i + 1

    def test_for_tuple_assign(self):
        def test_simple_assign(x):
            # type: (Tuple[int, float]) -> float
            sum = 0.0
            for a in x:
                sum += float(a)
            return sum

        self.checkScript(test_simple_assign, ((1, 2.5),))

        def test_tuple_assign(x):
            # type: (Tuple[Tuple[int, int], Tuple[int, int]]) -> int
            sum = 0
            for a in x:
                sum += a[0]
                sum += a[1]
            return sum

        self.checkScript(test_tuple_assign, (((1, 2), (4, 7)), ))

        def test_single_starred_lhs(self):
            with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence'
                                                      ' of another non-starred expression'):
                cu = torch.jit.CompilationUnit('''
                def single_starred_lhs(x):
                    a = (x, x, x)
                    *b, = a
                    return b
                ''')

    def test_singleton_tuple_unpack(self):
        # `b, = (a,)` unpacks a one-element tuple.
        def foo(a):
            b, = (a,)
            return b + 1
        self.checkScript(foo, (torch.rand(3),))

    def test_tuple_assignments(self):
        def var_tuple_assign(x, y):
            # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor
            (a, b), c = x, y
            return a + b + c

        tuple_inputs = (torch.randn(1, 4), torch.randn(3, 4))
        self.checkScript(var_tuple_assign, (tuple_inputs, torch.randn(3, 4)))

        def nested_tuple_assign(x, y, z):
            # type: (int, Tuple[int, Tuple[int, int]], Tuple[int, int]) -> int
            a, (b, (c, d)), (e, f) = x, y, z
            return a + b + c + d + e + f

        self.checkScript(nested_tuple_assign, ((1, (2, (3, 4)), (5, 6))))

        def subscript_tuple_assign(a, x, i):
            # type: (List[int], Tensor, int) -> Tuple[int, Tensor, int]
            a[i], (x[i], b) = 1, (2, 3)
            return a[i] + 1, x + 5, b

        self.checkScript(subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0))

        def star_tuple_assign():
            # type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]]
            a, (b, *c), *d = 1, (2, 3, 4), 5, 6
            return a, b, c, d

        self.checkScript(star_tuple_assign, ())

        def subscript_tuple_augmented_assign(a):
            # type: (Tuple[int, int]) -> Tuple[int, int]
            a[0] += 1
            return a

        with self.assertRaisesRegex(RuntimeError, 'does not support augmented assign'):
            scripted_aug_assign = torch.jit.script(subscript_tuple_augmented_assign)

    def test_multiple_assign(self):
        # Chained assignment binds every target (`a`, `b, c`, `d, f`) to the
        # same value. An in-place op result is bound to two names, and a
        # tuple swap's right-hand side is evaluated before assignment.
        # checkScript compares all of it against eager Python.
        def test():
            a = b, c = d, f = (1, 1)

            # side effect
            ten = torch.tensor(1)
            ten1 = ten2 = ten.add_(1)

            # ordering
            x = 1
            y = 3
            x, y = y, x + y

            return a, b, c, d, f, ten, ten1, ten2, x, y

        self.checkScript(test, ())

    def test_opt_opt_refinement(self):
        @torch.jit.script
        def test_unify(weight, bias):
            # type: (Optional[int], Optional[int]) -> Optional[int]
            if weight is not None:
                opt = None
            else:
                if bias is not None:
                    opt = 1
                else:
                    opt = None

            return opt

    def test_optional_refinement(self):
        @torch.jit.script
        def test_if_none_assignment(x):
            # type: (Optional[int]) -> int
            if x is None:
                x = 1
            return x + 1

        self.assertEqual(test_if_none_assignment(1), 2)

    def test_optional_conversion(self):
        @torch.jit.script
        def other_fn(x=None):
            # type: (Optional[int]) -> int
            return torch.jit._unwrap_optional(x)


        @torch.jit.script
        def fn(x):
            # type: (int) -> int
            return other_fn(x)

        self.assertEqual(fn(2), 2)

        @torch.jit.script
        def unify_to_optional(x):
            # type: (bool) -> Optional[int]
            if x:
                a = None
            else:
                a = 2
            return a

        self.assertEqual(unify_to_optional(True), None)
        self.assertEqual(unify_to_optional(False), 2)

        @torch.jit.script
        def opt_list(x):
            # type: (Optional[List[float]]) -> int
            return 2

        @torch.jit.script
        def broadcast_opt_list(x):
            # type: (Optional[BroadcastingList2[float]]) -> int
            return 2

        @torch.jit.script
        def opt_list_tuple_caller(x):
            # type: (Tuple[float, float]) -> int
            return opt_list(x) + broadcast_opt_list(x)

        self.assertEqual(opt_list_tuple_caller((2., 3.)), 4)

    def test_optional_tuple(self):
        def fn(x=None):
            # type: (Optional[Tuple[int, int]]) -> Tuple[int, int]
            if x is None:
                new_x = (1, 2)
            else:
                new_x = x
            return new_x

        self.checkScript(fn, ((3, 4),))
        self.checkScript(fn, ())

    def test_namedtuple_redefine(self):
        # Bound as module globals; presumably so the type comment below can
        # resolve them (compare the [local resolution in python] note used by
        # the other namedtuple tests).
        global _1, _2
        # Two distinct namedtuple classes share the Python type name
        # 'GoogLeNetOutputs'.
        _1 = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
        _2 = namedtuple('GoogLeNetOutputs', ['different'])

        # Using both in one signature must fail with an error mentioning 'redefine'.
        with self.assertRaisesRegex(RuntimeError, r'redefine'):
            @torch.jit.script
            def foo(x, y):
                # type: (_1, _2) -> _1
                return x

    def test_namedtuple_py2(self):
        global _GoogLeNetOutputs  # see [local resolution in python]
        _GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])

        @torch.jit.script
        def foo(x):
            # type: (_GoogLeNetOutputs) -> _GoogLeNetOutputs
            return x

        vals = torch.rand(3), torch.rand(4), torch.rand(5)
        out = foo(_GoogLeNetOutputs(logits=vals[0], aux_logits2=vals[1], aux_logits1=vals[2]))
        self.assertEqual(out.logits, vals[0])
        self.assertEqual(out.aux_logits2, vals[1])
        self.assertEqual(out.aux_logits1, vals[2])

    def test_namedtuple_good_error(self):
        global _GoogLeNetOutputs  # see [local resolution in python]
        _GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])

        @torch.jit.script
        def foo(x):
            # type: (_GoogLeNetOutputs) -> _GoogLeNetOutputs
            return x

        with self.assertRaisesRegex(RuntimeError,
                                    r'aka NamedTuple\(logits, aux_logits2, aux_logits1\)'):
            out = foo(_GoogLeNetOutputs(logits="3", aux_logits2="4", aux_logits1="5"))
