# Owner(s): ["oncall: jit"]

import os
import sys

import torch
from torch.testing import FileCheck
from enum import Enum
from typing import Any, List

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, make_global

# This module only defines test cases; it is collected through the test/test_jit.py driver.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

class TestEnum(JitTestCase):
    """Tests for TorchScript support of Python ``enum.Enum`` types."""

    # Int-, float- and string-valued enums are accepted as scripted arguments;
    # a tensor-valued enum is rejected at compile time.
    def test_enum_value_types(self):
        class IntEnum(Enum):
            FOO = 1
            BAR = 2

        class FloatEnum(Enum):
            FOO = 1.2
            BAR = 2.3

        class StringEnum(Enum):
            FOO = "foo as in foo bar"
            BAR = "bar as in foo bar"

        make_global(IntEnum, FloatEnum, StringEnum)

        @torch.jit.script
        def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
            return (a.name, b.name, c.name)

        FileCheck() \
            .check("IntEnum") \
            .check("FloatEnum") \
            .check("StringEnum") \
            .run(str(supported_enum_types.graph))

        class TensorEnum(Enum):
            FOO = torch.tensor(0)
            BAR = torch.tensor(1)

        make_global(TensorEnum)

        def unsupported_enum_types(a: TensorEnum):
            return a.name

        # TODO: rewrite code so that the highlight is not empty.
        with self.assertRaisesRegexWithHighlight(RuntimeError, "Cannot create Enum with value type 'Tensor'", ""):
            torch.jit.script(unsupported_enum_types)

    # Equality between two enum values of the same class lowers to aten::eq.
    def test_enum_comp(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def enum_comp(x: Color, y: Color) -> bool:
            return x == y

        FileCheck().check("aten::eq").run(str(enum_comp.graph))

        self.assertEqual(enum_comp(Color.RED, Color.RED), True)
        self.assertEqual(enum_comp(Color.RED, Color.GREEN), False)

    # Members of two distinct Enum classes compare unequal, even when their
    # names and values match.
    def test_enum_comp_diff_classes(self):
        class Foo(Enum):
            ITEM1 = 1
            ITEM2 = 2

        class Bar(Enum):
            ITEM1 = 1
            ITEM2 = 2

        make_global(Foo, Bar)

        @torch.jit.script
        def enum_comp(x: Foo) -> bool:
            return x == Bar.ITEM1

        FileCheck() \
            .check("prim::Constant") \
            .check_same("Bar.ITEM1") \
            .check("aten::eq") \
            .run(str(enum_comp.graph))

        self.assertEqual(enum_comp(Foo.ITEM1), False)

    # An Enum whose members have different value types cannot be scripted.
    def test_heterogenous_value_type_enum_error(self):
        class Color(Enum):
            RED = 1
            GREEN = "green"

        make_global(Color)

        def enum_comp(x: Color, y: Color) -> bool:
            return x == y

        # TODO: rewrite code so that the highlight is not empty.
        with self.assertRaisesRegexWithHighlight(RuntimeError, "Could not unify type list", ""):
            torch.jit.script(enum_comp)

    # `.name` on an enum value lowers to prim::EnumName.
    def test_enum_name(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def enum_name(x: Color) -> str:
            return x.name

        FileCheck() \
            .check("Color") \
            .check_next("prim::EnumName") \
            .check_next("return") \
            .run(str(enum_name.graph))

        self.assertEqual(enum_name(Color.RED), Color.RED.name)
        self.assertEqual(enum_name(Color.GREEN), Color.GREEN.name)

    # `.value` on an enum value lowers to prim::EnumValue.
    def test_enum_value(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def enum_value(x: Color) -> int:
            return x.value

        FileCheck() \
            .check("Color") \
            .check_next("prim::EnumValue") \
            .check_next("return") \
            .run(str(enum_value.graph))

        self.assertEqual(enum_value(Color.RED), Color.RED.value)
        self.assertEqual(enum_value(Color.GREEN), Color.GREEN.value)

    # An enum member referenced in scripted code becomes a prim::Constant.
    def test_enum_as_const(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def enum_const(x: Color) -> bool:
            return x == Color.RED

        FileCheck() \
            .check("prim::Constant[value=__torch__.jit.test_enum.Color.RED]") \
            .check_next("aten::eq") \
            .check_next("return") \
            .run(str(enum_const.graph))

        self.assertEqual(enum_const(Color.RED), True)
        self.assertEqual(enum_const(Color.GREEN), False)

    # Referencing a member that the Enum does not define is a compile error
    # whose highlight points at the offending attribute access.
    def test_non_existent_enum_value(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        def enum_const(x: Color) -> bool:
            if x == Color.PURPLE:
                return True
            else:
                return False

        with self.assertRaisesRegexWithHighlight(RuntimeError, "has no attribute 'PURPLE'", "Color.PURPLE"):
            torch.jit.script(enum_const)

    # isinstance() against an Enum class works on an Any-typed value.
    def test_enum_ivalue_type(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def is_color_enum(x: Any):
            return isinstance(x, Color)

        FileCheck() \
            .check("prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]") \
            .check_next("return") \
            .run(str(is_color_enum.graph))

        self.assertEqual(is_color_enum(Color.RED), True)
        self.assertEqual(is_color_enum(Color.GREEN), True)
        self.assertEqual(is_color_enum(1), False)

    # Accessing `.value` through a closed-over alias of the Enum class, or of
    # one of its members, compiles down to a single constant.
    def test_closed_over_enum_constant(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        a = Color

        @torch.jit.script
        def closed_over_aliased_type():
            return a.RED.value

        FileCheck() \
            .check(f"prim::Constant[value={a.RED.value}]") \
            .check_next("return") \
            .run(str(closed_over_aliased_type.graph))

        self.assertEqual(closed_over_aliased_type(), Color.RED.value)

        b = Color.RED

        @torch.jit.script
        def closed_over_aliased_value():
            return b.value

        FileCheck() \
            .check(f"prim::Constant[value={b.value}]") \
            .check_next("return") \
            .run(str(closed_over_aliased_value.graph))

        self.assertEqual(closed_over_aliased_value(), Color.RED.value)

    # An enum stored as a module attribute can be read in a scripted forward().
    def test_enum_as_module_attribute(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        class TestModule(torch.nn.Module):
            def __init__(self, e: Color):
                super().__init__()
                self.e = e

            def forward(self):
                return self.e.value

        m = TestModule(Color.RED)
        scripted = torch.jit.script(m)

        FileCheck() \
            .check("TestModule") \
            .check_next("Color") \
            .check_same("prim::GetAttr[name=\"e\"]") \
            .check_next("prim::EnumValue") \
            .check_next("return") \
            .run(str(scripted.graph))

        self.assertEqual(scripted(), Color.RED.value)

    # Same as above, but with a string-valued enum, returning both name and value.
    def test_string_enum_as_module_attribute(self):
        class Color(Enum):
            RED = "red"
            GREEN = "green"

        class TestModule(torch.nn.Module):
            def __init__(self, e: Color):
                super().__init__()
                self.e = e

            def forward(self):
                return (self.e.name, self.e.value)

        make_global(Color)
        m = TestModule(Color.RED)
        scripted = torch.jit.script(m)

        self.assertEqual(scripted(), (Color.RED.name, Color.RED.value))

    # A scripted function can return enum members, which round-trip to Python.
    def test_enum_return(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def return_enum(cond: bool):
            if cond:
                return Color.RED
            else:
                return Color.GREEN

        self.assertEqual(return_enum(True), Color.RED)
        self.assertEqual(return_enum(False), Color.GREEN)

    # A scripted module can return an enum-typed attribute directly.
    def test_enum_module_return(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        class TestModule(torch.nn.Module):
            def __init__(self, e: Color):
                super().__init__()
                self.e = e

            def forward(self):
                return self.e

        make_global(Color)
        m = TestModule(Color.RED)
        scripted = torch.jit.script(m)

        FileCheck() \
            .check("TestModule") \
            .check_next("Color") \
            .check_same("prim::GetAttr[name=\"e\"]") \
            .check_next("return") \
            .run(str(scripted.graph))

        self.assertEqual(scripted(), Color.RED)

    # Iterating over an Enum class in scripted code visits every member.
    def test_enum_iterate(self):
        class Color(Enum):
            RED = 1
            GREEN = 2
            BLUE = 3

        def iterate_enum(x: Color):
            res: List[int] = []
            for e in Color:
                if e != x:
                    res.append(e.value)
            return res

        make_global(Color)
        scripted = torch.jit.script(iterate_enum)

        FileCheck() \
            .check("Enum<__torch__.jit.test_enum.Color>[]") \
            .check_same("Color.RED") \
            .check_same("Color.GREEN") \
            .check_same("Color.BLUE") \
            .run(str(scripted.graph))

        # Members are visited in Python's Enum definition order (RED, GREEN, BLUE),
        # so BLUE always appears last.
        self.assertEqual(scripted(Color.RED), [Color.GREEN.value, Color.BLUE.value])
        self.assertEqual(scripted(Color.GREEN), [Color.RED.value, Color.BLUE.value])

    # Tests that explicitly and/or repeatedly scripting an Enum class is permitted.
    def test_enum_explicit_script(self):
        @torch.jit.script
        class Color(Enum):
            RED = 1
            GREEN = 2

        torch.jit.script(Color)
