import pytest
from sqlparse import parse
from pgcli.packages.parseutils.ctes import (
    token_start_pos,
    extract_ctes,
    extract_column_names as _extract_column_names,
)


def extract_column_names(sql):
    """Parse *sql* and return the column names of its first statement.

    Convenience wrapper so the tests below can pass plain SQL strings to the
    parser-level ``extract_column_names``, which expects a parsed statement.
    """
    first_statement = parse(sql)[0]
    names = _extract_column_names(first_statement)
    return names


def test_token_str_pos():
    """token_start_pos gives the character offset of a token, counting newlines."""
    for prefix in ("SELECT * FROM ", "SELECT * FROM \n"):
        statement = parse(prefix + "xxx")[0]
        last_index = statement.token_index(statement.tokens[-1])
        assert token_start_pos(statement.tokens, last_index) == len(prefix)


def test_single_column_name_extraction():
    """A bare column reference yields its own name."""
    expected = ("abc",)
    assert extract_column_names("SELECT abc FROM xxx") == expected


def test_aliased_single_column_name_extraction():
    """An alias following a column replaces the column's own name."""
    expected = ("def",)
    assert extract_column_names("SELECT abc def FROM xxx") == expected


def test_aliased_expression_name_extraction():
    """A literal expression that is given an alias is named by that alias."""
    expected = ("abc",)
    assert extract_column_names("SELECT 99 abc FROM xxx") == expected


def test_multiple_column_name_extraction():
    """Comma-separated columns are returned in select-list order."""
    expected = ("abc", "def")
    assert extract_column_names("SELECT abc, def FROM xxx") == expected


def test_missing_column_name_handled_gracefully():
    """Unaliased literals contribute no name; the named columns are still returned."""
    cases = (
        ("SELECT abc, 99 FROM xxx", ("abc",)),
        ("SELECT abc, 99, def FROM xxx", ("abc", "def")),
    )
    for query, expected in cases:
        assert extract_column_names(query) == expected


def test_aliased_multiple_column_name_extraction():
    """Each aliased column contributes its alias, in select-list order."""
    expected = ("def", "jkl")
    assert extract_column_names("SELECT abc def, ghi jkl FROM xxx") == expected


def test_table_qualified_column_name_extraction():
    """Qualified references (``table.column``) yield only the column part."""
    expected = ("def", "jkl")
    assert extract_column_names("SELECT abc.def, ghi.jkl FROM xxx") == expected


@pytest.mark.parametrize(
    "sql",
    (
        "INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y",
        "DELETE FROM foo WHERE x > y RETURNING x, y",
        "UPDATE foo SET x = 9 RETURNING x, y",
    ),
)
def test_extract_column_names_from_returning_clause(sql):
    """For INSERT/DELETE/UPDATE, column names come from the RETURNING clause."""
    expected = ("x", "y")
    assert extract_column_names(sql) == expected


def test_simple_cte_extraction():
    """A single CTE is extracted with its columns and span; the main query remains."""
    head = "WITH a AS "
    cte_body = "(SELECT abc FROM xxx)"
    sql = head + cte_body + " SELECT * FROM a"

    ctes, remainder = extract_ctes(sql)

    # The span runs from the CTE body's opening paren to just past its closing paren.
    expected = (("a", ("abc",), len(head), len(head + cte_body)),)
    assert tuple(ctes) == expected
    assert remainder.strip() == "SELECT * FROM a"


def test_cte_extraction_around_comments():
    """CTEs are found after a leading SQL comment, and their offsets count the comment."""
    sql = """--blah blah blah
            WITH a AS (SELECT abc def FROM x)
            SELECT * FROM a"""
    # The comment contains no parentheses, so the first "(" and ")" delimit the CTE body.
    body_start = sql.index("(")
    body_stop = sql.index(")") + 1

    ctes, remainder = extract_ctes(sql)
    expected = (("a", ("def",), body_start, body_stop),)
    assert tuple(ctes) == expected
    assert remainder.strip() == "SELECT * FROM a"


def test_multiple_cte_extraction():
    """Each CTE in a comma-separated WITH list is extracted with its own span.

    Offsets are absolute positions in the original SQL. The start is the
    opening parenthesis of the CTE body, and the stop is just past its
    closing parenthesis. The main query after the CTE list must be returned
    as the remainder, as in the single-CTE tests.
    """
    sql = """WITH
            x AS (SELECT abc, def FROM x),
            y AS (SELECT ghi, jkl FROM y)
            SELECT * FROM a, b"""

    # Locate each CTE body's parentheses directly in the query text, rather
    # than repeating multi-line prefixes whose whitespace must match exactly.
    start1 = sql.index("(")
    stop1 = sql.index(")", start1) + 1
    start2 = sql.index("(", stop1)
    stop2 = sql.index(")", start2) + 1

    ctes, remainder = extract_ctes(sql)
    assert tuple(ctes) == (
        ("x", ("abc", "def"), start1, stop1),
        ("y", ("ghi", "jkl"), start2, stop2),
    )
    # Previously unchecked: the main query must survive the CTE extraction.
    assert remainder.strip() == "SELECT * FROM a, b"
