from __future__ import annotations

from typing import Any

import sqlparse
from sqlparse.sql import Comparison, Identifier, Where, Token
from .parseutils import last_word, extract_tables, find_prev_keyword
from .special.main import parse_special_command


def suggest_type(full_text: str, text_before_cursor: str) -> list[dict[str, Any]]:
    """Takes the full_text that is typed so far and also the text before the
    cursor to suggest completion type and scope.

    Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
    A scope for a column category will be a list of tables.
    """

    word_before_cursor = last_word(text_before_cursor, include="many_punctuations")

    identifier: Identifier | None = None

    # here should be removed once sqlparse has been fixed
    try:
        # If we've partially typed a word then word_before_cursor won't be an empty
        # string. In that case we want to remove the partially typed string before
        # sending it to the sqlparser. Otherwise the last token will always be the
        # partially typed string which renders the smart completion useless because
        # it will always return the list of keywords as completion.
        if word_before_cursor:
            if word_before_cursor.endswith("(") or word_before_cursor.startswith("\\"):
                parsed = sqlparse.parse(text_before_cursor)
            else:
                parsed = sqlparse.parse(text_before_cursor[: -len(word_before_cursor)])

                # word_before_cursor may include a schema qualification, like
                # "schema_name.partial_name" or "schema_name.", so parse it
                # separately
                p = sqlparse.parse(word_before_cursor)[0]

                if p.tokens and isinstance(p.tokens[0], Identifier):
                    identifier = p.tokens[0]
        else:
            parsed = sqlparse.parse(text_before_cursor)
    except (TypeError, AttributeError):
        return [{"type": "keyword"}]

    if len(parsed) > 1:
        # Multiple statements being edited -- isolate the current one by
        # cumulatively summing statement lengths to find the one that bounds the
        # current position
        current_pos = len(text_before_cursor)
        stmt_start, stmt_end = 0, 0

        for statement in parsed:
            stmt_len = len(str(statement))
            stmt_start, stmt_end = stmt_end, stmt_end + stmt_len

            if stmt_end >= current_pos:
                text_before_cursor = full_text[stmt_start:current_pos]
                full_text = full_text[stmt_start:]
                break

    elif parsed:
        # A single statement
        statement = parsed[0]
    else:
        # The empty string
        statement = None

    # Check for special commands and handle those separately
    if statement:
        # Be careful here because trivial whitespace is parsed as a statement,
        # but the statement won't have a first token
        tok1 = statement.token_first()
        if tok1 and tok1.value.startswith("."):
            return suggest_special(text_before_cursor)
        elif tok1 and tok1.value.startswith("\\"):
            return suggest_special(text_before_cursor)
        elif tok1 and tok1.value.startswith("source"):
            return suggest_special(text_before_cursor)
        elif text_before_cursor and text_before_cursor.startswith(".open "):
            return suggest_special(text_before_cursor)

    last_token = statement and statement.token_prev(len(statement.tokens))[1] or ""

    return suggest_based_on_last_token(last_token, text_before_cursor, full_text, identifier)


def suggest_special(text: str) -> list[dict[str, Any]]:
    text = text.lstrip()
    cmd, _, arg = parse_special_command(text)

    if cmd == text:
        # Trying to complete the special command itself
        return [{"type": "special"}]

    if cmd in ("\\u", "\\r"):
        return [{"type": "database"}]

    if cmd in ("\\T"):
        return [{"type": "table_format"}]

    if cmd in ["\\f", "\\fs", "\\fd"]:
        return [{"type": "favoritequery"}]

    if cmd in ["\\d", "\\dt", "\\dt+", ".schema", ".indexes"]:
        return [
            {"type": "table", "schema": []},
            {"type": "view", "schema": []},
            {"type": "schema"},
        ]

    if cmd in ["\\.", "source", ".open", ".read"]:
        return [{"type": "file_name"}]

    if cmd in [".import"]:
        # Usage: .import filename table
        if _expecting_arg_idx(arg, text) == 1:
            return [{"type": "file_name"}]
        else:
            return [{"type": "table", "schema": []}]

    return [{"type": "keyword"}, {"type": "special"}]


def _expecting_arg_idx(arg: str, text: str) -> int:
    """Return the index of expecting argument.

    >>> _expecting_arg_idx("./da", ".import ./da")
    1
    >>> _expecting_arg_idx("./data.csv", ".import ./data.csv")
    1
    >>> _expecting_arg_idx("./data.csv", ".import ./data.csv ")
    2
    >>> _expecting_arg_idx("./data.csv t", ".import ./data.csv t")
    2
    """
    args = arg.split()
    return len(args) + int(text[-1].isspace())


def suggest_based_on_last_token(
    token: str | Token | None,
    text_before_cursor: str,
    full_text: str,
    identifier: Identifier | None,
) -> list[dict[str, Any]]:
    if isinstance(token, str):
        token_v = token.lower()
    elif isinstance(token, Comparison):
        # If 'token' is a Comparison type such as
        # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
        # token.value on the comparison type will only return the lhs of the
        # comparison. In this case a.id. So we need to do token.tokens to get
        # both sides of the comparison and pick the last token out of that
        # list.
        token_v = token.tokens[-1].value.lower()
    elif isinstance(token, Where):
        # sqlparse groups all tokens from the where clause into a single token
        # list. This means that token.value may be something like
        # 'where foo > 5 and '. We need to look "inside" token.tokens to handle
        # suggestions in complicated where clauses correctly
        prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
        return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier)
    else:
        assert token is not None
        token_v = token.value.lower()

    def is_operand(x: str | None) -> bool:
        if not x:
            return False
        return any([x.endswith(op) for op in ["+", "-", "*", "/"]])

    if not token:
        return [{"type": "keyword"}, {"type": "special"}]
    elif token_v.endswith("("):
        p = sqlparse.parse(text_before_cursor)[0]

        if p.tokens and isinstance(p.tokens[-1], Where):
            # Four possibilities:
            #  1 - Parenthesized clause like "WHERE foo AND ("
            #        Suggest columns/functions
            #  2 - Function call like "WHERE foo("
            #        Suggest columns/functions
            #  3 - Subquery expression like "WHERE EXISTS ("
            #        Suggest keywords, in order to do a subquery
            #  4 - Subquery OR array comparison like "WHERE foo = ANY("
            #        Suggest columns/functions AND keywords. (If we wanted to be
            #        really fancy, we could suggest only array-typed columns)

            column_suggestions = suggest_based_on_last_token("where", text_before_cursor, full_text, identifier)

            # Check for a subquery expression (cases 3 & 4)
            where = p.tokens[-1]
            idx, prev_tok = where.token_prev(len(where.tokens) - 1)

            if isinstance(prev_tok, Comparison):
                # e.g. "SELECT foo FROM bar WHERE foo = ANY("
                prev_tok = prev_tok.tokens[-1]

            prev_tok = prev_tok.value.lower()
            if prev_tok == "exists":
                return [{"type": "keyword"}]
            else:
                return column_suggestions

        # Get the token before the parens
        idx, prev_tok = p.token_prev(len(p.tokens) - 1)
        if prev_tok and prev_tok.value and prev_tok.value.lower() == "using":
            # tbl1 INNER JOIN tbl2 USING (col1, col2)
            tables = extract_tables(full_text)

            # suggest columns that are present in more than one table
            return [{"type": "column", "tables": tables, "drop_unique": True}]
        elif p.token_first().value.lower() == "select":
            # If the lparen is preceded by a space chances are we're about to
            # do a sub-select.
            if last_word(text_before_cursor, "all_punctuations").startswith("("):
                return [{"type": "keyword"}]
        elif p.token_first().value.lower() == "show":
            return [{"type": "show"}]

        # We're probably in a function argument list
        return [{"type": "column", "tables": extract_tables(full_text)}]
    elif token_v in ("set", "order by", "distinct"):
        return [{"type": "column", "tables": extract_tables(full_text)}]
    elif token_v == "as":
        # Don't suggest anything for an alias
        return []
    elif token_v in ("show"):
        return [{"type": "show"}]
    elif token_v in ("to",):
        p = sqlparse.parse(text_before_cursor)[0]
        if p.token_first().value.lower() == "change":
            return [{"type": "change"}]
        else:
            return [{"type": "user"}]
    elif token_v in ("user", "for"):
        return [{"type": "user"}]
    elif token_v in ("select", "where", "having"):
        # Check for a table alias or schema qualification
        parent = (identifier and identifier.get_parent_name()) or []

        tables = extract_tables(full_text)
        if parent:
            tables = [t for t in tables if identifies(parent, *t)]
            return [
                {"type": "column", "tables": tables},
                {"type": "table", "schema": parent},
                {"type": "view", "schema": parent},
                {"type": "function", "schema": parent},
            ]
        else:
            aliases = [alias or table for (schema, table, alias) in tables]
            return [
                {"type": "column", "tables": tables},
                {"type": "function", "schema": []},
                {"type": "alias", "aliases": aliases},
                {"type": "keyword"},
            ]
    elif (token_v.endswith("join") and isinstance(token, Token) and token.is_keyword) or (
        token_v in ("copy", "from", "update", "into", "describe", "truncate", "desc", "explain")
    ):
        schema = (identifier and identifier.get_parent_name()) or []

        # Suggest tables from either the currently-selected schema or the
        # public schema if no schema has been specified
        suggest = [{"type": "table", "schema": schema}]

        if not schema:
            # Suggest schemas
            suggest.insert(0, {"type": "schema"})

        # Only tables can be TRUNCATED, otherwise suggest views
        if token_v != "truncate":
            suggest.append({"type": "view", "schema": schema})

        return suggest

    elif token_v in ("table", "view", "function"):
        # E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>'
        rel_type = token_v
        schema = (identifier and identifier.get_parent_name()) or []
        if schema:
            return [{"type": rel_type, "schema": schema}]
        else:
            return [{"type": "schema"}, {"type": rel_type, "schema": []}]
    elif token_v == "on":
        tables = extract_tables(full_text)  # [(schema, table, alias), ...]
        parent = (identifier and identifier.get_parent_name()) or []
        if parent:
            # "ON parent.<suggestion>"
            # parent can be either a schema name or table alias
            tables = [t for t in tables if identifies(parent, *t)]
            return [
                {"type": "column", "tables": tables},
                {"type": "table", "schema": parent},
                {"type": "view", "schema": parent},
                {"type": "function", "schema": parent},
            ]
        else:
            # ON <suggestion>
            # Use table alias if there is one, otherwise the table name
            aliases = [alias or table for (schema, table, alias) in tables]
            suggest = [{"type": "alias", "aliases": aliases}]

            # The lists of 'aliases' could be empty if we're trying to complete
            # a GRANT query. eg: GRANT SELECT, INSERT ON <tab>
            # In that case we just suggest all tables.
            if not aliases:
                suggest.append({"type": "table", "schema": parent})
            return suggest

    elif token_v in ("use", "database", "template", "connect"):
        # "\c <db", "use <db>", "DROP DATABASE <db>",
        # "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
        return [{"type": "database"}]
    elif token_v == "tableformat":
        return [{"type": "table_format"}]
    elif token_v.endswith(",") or is_operand(token_v) or token_v in ["=", "and", "or"]:
        prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
        if prev_keyword:
            return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier)
        else:
            return []
    else:
        return [{"type": "keyword"}]


def identifies(id: Any, schema: str | None, table: str, alias: str | None) -> bool:
    return (id == alias) or (id == table) or (schema is not None and (id == schema + "." + table))
