# Owner(s): ["oncall: jit"]

import os
import sys
import warnings

import torch
from typing import List, Dict, Optional

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")

class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):

    # NB: There are no tests for `Tuple` or `NamedTuple` here. In fact,
    # reassigning a non-empty Tuple to an attribute previously typed
    # as containing an empty Tuple SHOULD fail. See note in `_check.py`

    def test_annotated_falsy_base_type(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.x: int = 0

            def forward(self, x: int):
                self.x = x
                return 1

        with warnings.catch_warnings(record=True) as w:
            self.checkModule(M(), (1,))
        assert(len(w) == 0)

    def test_annotated_nonempty_container(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.x: List[int] = [1, 2, 3]

            def forward(self, x: List[int]):
                self.x = x
                return 1

        with warnings.catch_warnings(record=True) as w:
            self.checkModule(M(), ([1, 2, 3],))
        assert(len(w) == 0)

    def test_annotated_empty_tensor(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.x: torch.Tensor = torch.empty(0)

            def forward(self, x: torch.Tensor):
                self.x = x
                return self.x

        with warnings.catch_warnings(record=True) as w:
            self.checkModule(M(), (torch.rand(2, 3),))
        assert(len(w) == 0)

    def test_annotated_with_jit_attribute(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.x = torch.jit.Attribute([], List[int])

            def forward(self, x: List[int]):
                self.x = x
                return self.x

        with warnings.catch_warnings(record=True) as w:
            self.checkModule(M(), ([1, 2, 3],))
        assert(len(w) == 0)

    def test_annotated_class_level_annotation_only(self):
        class M(torch.nn.Module):

            x: List[int]

            def __init__(self):
                super().__init__()
                self.x = []

            def forward(self, y: List[int]):
                self.x = y
                return self.x

        with warnings.catch_warnings(record=True) as w:
            self.checkModule(M(), ([1, 2, 3],))
        assert(len(w) == 0)


    def test_annotated_class_level_annotation_and_init_annotation(self):
        class M(torch.nn.Module):

            x: List[int]

            def __init__(self):
                super().__init__()
                self.x: List[int] = []

            def forward(self, y: List[int]):
                self.x = y
                return self.x

        with warnings.catch_warnings(record=True) as w:
            self.checkModule(M(), ([1, 2, 3],))
        assert(len(w) == 0)

    def test_annotated_class_level_jit_annotation(self):
        class M(torch.nn.Module):

            x: List[int]

            def __init__(self):
                super().__init__()
                self.x: List[int] = torch.jit.annotate(List[int], [])

            def forward(self, y: List[int]):
                self.x = y
                return self.x

        with warnings.catch_warnings(record=True) as w:
            self.checkModule(M(), ([1, 2, 3],))
        assert(len(w) == 0)

    def test_annotated_empty_list(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.x: List[int] = []

            def forward(self, x: List[int]):
                self.x = x
                return 1

        with self.assertRaisesRegexWithHighlight(RuntimeError,
                                                 "Tried to set nonexistent attribute",
                                                 "self.x = x"):
            with self.assertWarnsRegex(UserWarning, "doesn't support "
                                       "instance-level annotations on "
                                       "empty non-base types"):
                torch.jit.script(M())

    def test_annotated_empty_dict(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.x: Dict[str, int] = {}

            def forward(self, x: Dict[str, int]):
                self.x = x
                return 1

        with self.assertRaisesRegexWithHighlight(RuntimeError,
                                                 "Tried to set nonexistent attribute",
                                                 "self.x = x"):
            with self.assertWarnsRegex(UserWarning, "doesn't support "
                                       "instance-level annotations on "
                                       "empty non-base types"):
                torch.jit.script(M())

    def test_annotated_empty_optional(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.x: Optional[str] = None

            def forward(self, x: Optional[str]):
                self.x = x
                return 1

        with self.assertRaisesRegexWithHighlight(RuntimeError,
                                                 "Wrong type for attribute assignment",
                                                 "self.x = x"):
            with self.assertWarnsRegex(UserWarning, "doesn't support "
                                       "instance-level annotations on "
                                       "empty non-base types"):
                torch.jit.script(M())

    def test_annotated_with_jit_empty_list(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.x = torch.jit.annotate(List[int], [])

            def forward(self, x: List[int]):
                self.x = x
                return 1

        with self.assertRaisesRegexWithHighlight(RuntimeError,
                                                 "Tried to set nonexistent attribute",
                                                 "self.x = x"):
            with self.assertWarnsRegex(UserWarning, "doesn't support "
                                       "instance-level annotations on "
                                       "empty non-base types"):
                torch.jit.script(M())

    def test_annotated_with_jit_empty_dict(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.x = torch.jit.annotate(Dict[str, int], {})

            def forward(self, x: Dict[str, int]):
                self.x = x
                return 1

        with self.assertRaisesRegexWithHighlight(RuntimeError,
                                                 "Tried to set nonexistent attribute",
                                                 "self.x = x"):
            with self.assertWarnsRegex(UserWarning, "doesn't support "
                                       "instance-level annotations on "
                                       "empty non-base types"):
                torch.jit.script(M())

    def test_annotated_with_jit_empty_optional(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.x = torch.jit.annotate(Optional[str], None)

            def forward(self, x: Optional[str]):
                self.x = x
                return 1

        with self.assertRaisesRegexWithHighlight(RuntimeError,
                                                 "Wrong type for attribute assignment",
                                                 "self.x = x"):
            with self.assertWarnsRegex(UserWarning, "doesn't support "
                                       "instance-level annotations on "
                                       "empty non-base types"):
                torch.jit.script(M())

    def test_annotated_with_torch_jit_import(self):
        from torch import jit

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.x = jit.annotate(Optional[str], None)

            def forward(self, x: Optional[str]):
                self.x = x
                return 1

        with self.assertRaisesRegexWithHighlight(RuntimeError,
                                                 "Wrong type for attribute assignment",
                                                 "self.x = x"):
            with self.assertWarnsRegex(UserWarning, "doesn't support "
                                       "instance-level annotations on "
                                       "empty non-base types"):
                torch.jit.script(M())
