#------------------------------------------------------------------------------
# Copyright (c) 2018-2024, Nucleic Development Team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file LICENSE, distributed with this software.
#------------------------------------------------------------------------------
import os
import sys
import ast
import enaml
import pytest
import traceback
from textwrap import dedent

from enaml.compat import PY310, PY311


def validate_ast(py_node, enaml_node, dump_ast=False, offset=0):
    """Validate each node of an ast against another ast.

    Typically used to compare an AST generated by the Python parser and one
    generated by the enaml parser. Any mismatch (node type, field value or
    list length) is reported by raising an AssertionError.

    Parameters
    ----------
    py_node : ast.AST, list or plain value
        Reference node (or field value) produced by the Python parser.
    enaml_node : ast.AST, list or plain value
        Node (or field value) produced by the enaml parser to check.
    dump_ast : bool, optional
        Print a dump of both top-level nodes before comparing them.
    offset : int, optional
        Indentation level of the progress messages, incremented on recursion.

    """
    if dump_ast:
        print('Python node:\n', ast.dump(py_node))
        print('Enaml node:\n', ast.dump(enaml_node))
    assert type(py_node) is type(enaml_node)
    if isinstance(py_node, ast.AST):
        for name, field in ast.iter_fields(py_node):
            if name == 'ctx':
                # Context nodes (Load/Store/Del) have no fields, so only
                # their type can differ.
                assert type(field) is type(getattr(enaml_node, name))
            else:
                field2 = getattr(enaml_node, name, None)
                print('    '*offset, 'Validating:', name)
                validate_ast(field, field2, offset=offset+1)
    elif isinstance(py_node, list):
        # A length mismatch must fail loudly: a returned value would be
        # ignored by the (recursive) callers and let the comparison pass.
        assert len(py_node) == len(enaml_node), (
            f'List length mismatch: {len(py_node)} != {len(enaml_node)}'
        )
        for i, (n1, n2) in enumerate(zip(py_node, enaml_node)):
            print('    '*offset, 'Validating', i+1, 'th element')
            validate_ast(n1, n2, offset=offset+1)
    else:
        assert py_node == enaml_node


def test_syntax_error_traceback_correct_path(tmpdir):
    """ Test that a syntax error retains the path to the file

    ``test_main.enaml`` imports ``view.enaml``, which contains a syntax error.
    The formatted traceback must point at ``view.enaml`` (line 5), not at the
    importing module.

    """
    test_module_path = os.path.join(tmpdir.strpath, 'view.enaml')

    with open(os.path.join(tmpdir.strpath, 'test_main.enaml'), 'w') as f:
        f.write(dedent("""
        from enaml.widgets.api import Window, Container, Label
        from view import CustomView

        enamldef MyWindow(Window): main:
            CustomView:
                pass

        """))

    with open(test_module_path, 'w') as f:
        f.write(dedent("""
        from enaml.widgets.api import Container, Label

        enamldef CustomLabel(Container):
            Label # : missing intentionally
                text = "Hello world"

        """))

    sys.path.append(tmpdir.strpath)
    try:
        with enaml.imports():
            from test_main import MyWindow  # noqa: F401
    except SyntaxError:
        tb = traceback.format_exc()
        print(tb)
        # The location of the error is expected in the last few lines of
        # the formatted traceback.
        lines = tb.strip().split("\n")
        line = '\n'.join(lines[-4:])
        expected = 'File "{}", line 5'.format(test_module_path)
        assert expected in line
    else:
        # Kept out of the ``try`` body so the handler cannot swallow it.
        pytest.fail("Should raise a syntax error")
    finally:
        sys.path.remove(tmpdir.strpath)


def test_syntax_error_traceback_show_line(tmpdir):
    """ Test that a syntax error traceback shows the offending source line

    The module contains an item missing its trailing colon; the text of that
    line must appear at the end of the formatted traceback.

    """
    test_module_path = os.path.join(tmpdir.strpath, 'test_syntax.enaml')

    with open(test_module_path, 'w') as f:
        f.write(dedent("""
        from enaml.widgets.api import Container, Label

        enamldef CustomLabel(Container):
            Label # : missing intentionally
                text = "Hello world"
        """))

    sys.path.append(tmpdir.strpath)
    try:
        with enaml.imports():
            from test_syntax import CustomLabel  # noqa: F401
    except SyntaxError:
        tb = traceback.format_exc()
        print(tb)
        # The offending source line is expected in the last few lines of
        # the formatted traceback.
        lines = tb.strip().split("\n")
        line = '\n'.join(lines[-4:])

        expected = 'Label # : missing intentionally'
        assert expected in line
    else:
        # Kept out of the ``try`` body so the handler cannot swallow it.
        pytest.fail("Should raise a syntax error")
    finally:
        sys.path.remove(tmpdir.strpath)


# Indentation error cases consumed by ``test_indent_error_traceback_show_line``.
# Each entry maps a label (also used to build the temporary module name) to a
# ``(source, expected)`` pair:
#   - ``source``: enaml code with a deliberately wrong indentation; the test
#     strips the leading newline and dedents it before writing it to disk, so
#     only the *relative* indentation inside the string matters.
#   - ``expected``: a fragment of the offending line that must appear in the
#     last lines of the formatted IndentationError traceback.
INDENTATION_TESTS = {
     "enamldef-block": (
        """
        from enaml.widgets.api import Window, Container, Label

        enamldef MainWindow(Window):
        attr x = 1
        """,
        "attr x = 1",
    ),
    "childdef-block": (
        """
        from enaml.widgets.api import Window, Container, Label

        enamldef MainWindow(Window):
            Container:
            Label: # no indent
                text = "Hello world"
        """,
        "Label: # no indent",
    ),
    "childdef-indent-mismatch": (
        """
        from enaml.widgets.api import Window, Container, Label

        enamldef MainWindow(Window):
            Container:
                Label:
                    text = "Hello world"
                 Label: # indent mismatch
                    text = "Hello world"
        """,
        "Label: # indent mismatch",
    ),
    "childdef-attr": (
        """
        from enaml.widgets.api import Window, Container, Label

        enamldef MainWindow(Window):
            Container:
                Label:
                text = 'Hello world'
        """,
        "text = 'Hello world'"
    ),
    "if-block": (
        """
        from enaml.widgets.api import Window

        enamldef MainWindow(Window):
            func go():
                if True:
                x = 1
                else:
                    x = 0
        """,
        "x = 1",
    ),
    "for-block": (
        """
        from enaml.widgets.api import Window

        enamldef MainWindow(Window):
            func go():
                x = 0
                for i in range(4):
                x += 1
        """,
        "x += 1"
    ),
    "try-block": (
        """
        from enaml.widgets.api import Window
        enamldef MainWindow(Window):
            func go():
                try:
                x = 1/0
                except Exception as e:
                    print(e)
        """,
        "x = 1/0"
    ),
    "except-block": (
        """
        from enaml.widgets.api import Window

        enamldef MainWindow(Window):
            func go():
                try:
                    x = 0
                except Exception as e:
                print(e)
        """,
        "print(e)"
    ),
    "finally-block": (
        """
        from enaml.widgets.api import Window

        enamldef MainWindow(Window):
            func go():
                try:
                    x = 0
                finally:
                x = 2
                return 3
        """,
        "x = 2"
    ),
    # Plain Python class inside an enaml file: ``def`` is dedented by one
    # column, so it matches no enclosing indentation level.
    "class": (
        """
        from enaml.widgets.api import Window, Container, Label

        class Foo:
            x = 1
           def add():
               self.x += 1
        """,
        "def add()",
    ),
}

@pytest.mark.parametrize("label", INDENTATION_TESTS.keys())
def test_indent_error_traceback_show_line(tmpdir, label):
    """ Test that an indentation error traceback shows the offending line

    ``label`` selects an ``INDENTATION_TESTS`` entry: an enaml source with a
    bad indentation and a fragment of the line the error must point at.

    """
    test_module_path = os.path.join(tmpdir.strpath, f'test_indent_{label}.enaml')
    source, expected = INDENTATION_TESTS[label]
    with open(test_module_path, 'w') as f:
        f.write(dedent(source.lstrip("\n")))
    sys.path.append(tmpdir.strpath)
    try:
        with enaml.imports():
            # The module name contains the label (possibly with hyphens),
            # so it can only be imported through ``__import__``.
            __import__(f"test_indent_{label}")
    except IndentationError:
        tb = traceback.format_exc()
        print(tb)
        # The offending line is expected in the last few lines of the
        # formatted traceback.
        lines = tb.strip().split("\n")
        line = '\n'.join(lines[-4:])
        assert expected in line
    else:
        pytest.fail("Should raise an indentation error")
    finally:
        sys.path.remove(tmpdir.strpath)
