from __future__ import unicode_literals
import re
import sys
import logging
import click
import io
import shlex
import sqlparse
import psycopg
from os.path import expanduser
from .namedqueries import NamedQueries
from . import export
from .main import show_extra_help_command, special_command

NAMED_QUERY_PLACEHOLDERS = frozenset({"$1", "$*", "$@"})

DEFAULT_WATCH_SECONDS = 2

_logger = logging.getLogger(__name__)


@export
def editor_command(command):
    """
    Is this an external editor command?  (\\e or \\ev)

    :param command: string

    Returns the specific external editor command found.
    """
    # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check
    # for both conditions.

    stripped = command.strip()
    for sought in ("\\e ", "\\ev ", "\\ef "):
        if stripped.startswith(sought):
            return sought.strip()
    for sought in ("\\e",):
        if stripped.endswith(sought):
            return sought


@export
def get_filename(sql):
    if sql.strip().startswith("\\e"):
        command, _, filename = sql.partition(" ")
        return filename.strip() or None


@export
@show_extra_help_command(
    "\\watch",
    f"\\watch [sec={DEFAULT_WATCH_SECONDS}]",
    "Execute query every `sec` seconds.",
)
def get_watch_command(command):
    match = re.match(r"(.*?)[\s]*\\watch(\s+\d+)?\s*;?\s*$", command, re.DOTALL)
    if match:
        groups = match.groups(default=f"{DEFAULT_WATCH_SECONDS}")
        return groups[0], int(groups[1])
    return None, None


@export
def get_editor_query(sql):
    """Get the query part of an editor command."""
    sql = sql.strip()

    # The reason we can't simply do .strip('\e') is that it strips characters,
    # not a substring. So it'll strip "e" in the end of the sql also!
    # Ex: "select * from style\e" -> "select * from styl".
    pattern = re.compile(r"(^\\e|\\e$)")
    while pattern.search(sql):
        sql = pattern.sub("", sql)

    return sql


@export
def open_external_editor(filename=None, sql=None, editor=None):
    """
    Open external editor, wait for the user to type in his query,
    return the query.
    :return: list with one tuple, query as first element.
    """

    message = None
    filename = filename.strip().split(" ", 1)[0] if filename else None

    sql = sql or ""
    MARKER = "# Type your query above this line.\n"

    # Populate the editor buffer with the partial sql (if available) and a
    # placeholder comment.
    query = click.edit(
        "{sql}\n\n{marker}".format(sql=sql, marker=MARKER),
        filename=filename,
        extension=".sql",
        editor=editor,
    )

    if filename:
        try:
            query = read_from_file(filename)
        except IOError:
            message = "Error reading file: %s." % filename

    if query is not None:
        query = query.split(MARKER, 1)[0].rstrip("\n")
    else:
        # Don't return None for the caller to deal with.
        # Empty string is ok.
        query = sql

    return (query, message)


def read_from_file(path):
    with io.open(expanduser(path), encoding="utf-8") as f:
        contents = f.read()
    return contents


def _index_of_file_name(tokenlist):
    for idx, token in reversed(list(enumerate(tokenlist[:-2]))):
        if token.is_keyword and token.value.upper() in ("TO", "FROM"):
            return idx + 2
    raise Exception("Missing keyword in \\copy command. Either TO or FROM is required.")


@special_command(
    "\\copy",
    "\\copy [tablename] to/from [filename]",
    "Copy data between a file and a table.",
    case_sensitive=False,
)
def copy(cur, pattern, verbose):
    """Copies table data to/from files"""

    # Replace the specified file destination with STDIN or STDOUT
    parsed = sqlparse.parse(pattern)
    tokens = parsed[0].tokens
    idx = _index_of_file_name(tokens)
    file_name = tokens[idx].value
    before_file_name = "".join(t.value for t in tokens[:idx])
    after_file_name = "".join(t.value for t in tokens[idx + 1 :])

    direction = tokens[idx - 2].value.upper()
    replacement_file_name = "STDIN" if direction == "FROM" else "STDOUT"
    query = f"{before_file_name} {replacement_file_name} {after_file_name}"
    open_mode = "r" if direction == "FROM" else "wb"
    if file_name.startswith("'") and file_name.endswith("'"):
        file = io.open(expanduser(file_name.strip("'")), mode=open_mode)
    elif "stdin" in file_name.lower():
        file = sys.stdin.buffer
    elif "stdout" in file_name.lower():
        file = sys.stdout.buffer
    else:
        raise Exception("Enclose filename in single quotes")

    if direction == "FROM":
        with cur.copy("copy " + query) as pgcopy:
            while True:
                data = file.read(8192)
                if not data:
                    break
                pgcopy.write(data)
    else:
        with cur.copy("copy " + query) as pgcopy:
            for data in pgcopy:
                file.write(bytes(data))

    if cur.description:
        headers = [x.name for x in cur.description]
        return [(None, None, headers, cur.statusmessage)]
    else:
        return [(None, None, None, cur.statusmessage)]


def subst_favorite_query_args(query, args):
    """replace positional parameters ($1,$2,...$n) in query."""
    is_query_with_aggregation = ("$*" in query) or ("$@" in query)

    # In case of arguments aggregation we replace all positional arguments until the
    # first one not present in the query. Then we aggregate all the remaining ones and
    # replace the placeholder with them.
    for idx, val in enumerate(args, start=1):
        subst_var = "$" + str(idx)
        if subst_var not in query:
            if is_query_with_aggregation:
                # remove consumed arguments ( - 1 to include current value)
                args = args[idx - 1 :]
                break

            return [
                None,
                "query does not have substitution parameter " + subst_var + ":\n  " + query,
            ]

        query = query.replace(subst_var, val)
    # we consumed all arguments
    else:
        args = []

    if is_query_with_aggregation and not args:
        return [None, "missing substitution for $* or $@ in query:\n" + query]

    if "$*" in query:
        query = query.replace("$*", ", ".join(args))
    elif "$@" in query:
        query = query.replace("$@", ", ".join(map("'{}'".format, args)))

    match = re.search("\\$\\d+", query)
    if match:
        return [
            None,
            "missing substitution for " + match.group(0) + " in query:\n  " + query,
        ]

    return [query, None]


@special_command("\\n", "\\n[+] [name] [param1 param2 ...]", "List or execute named queries.")
def execute_named_query(cur, pattern, **_):
    """Returns (title, rows, headers, status)"""
    if pattern == "":
        return list_named_queries(True)

    params = shlex.split(pattern)
    pattern = params.pop(0)

    query = NamedQueries.instance.get(pattern)
    title = "> {}".format(query)
    if query is None:
        message = "No named query: {}".format(pattern)
        return [(None, None, None, message)]

    try:
        if any(p in query for p in NAMED_QUERY_PLACEHOLDERS):
            query, params = subst_favorite_query_args(query, params)
            if query is None:
                raise Exception("Bad arguments\n" + params)
        cur.execute(query)
    except psycopg.errors.SyntaxError:
        if "%s" in query:
            raise Exception('Bad arguments: please use "$1", "$2", etc. for named queries instead of "%s"')
        else:
            raise
    except (IndexError, TypeError):
        raise Exception("Bad arguments")

    if cur.description:
        headers = [x.name for x in cur.description]
        return [(title, cur, headers, cur.statusmessage)]
    else:
        return [(title, None, None, cur.statusmessage)]


def list_named_queries(verbose):
    """List of all named queries.
    Returns (title, rows, headers, status)"""
    if not verbose:
        rows = [[r] for r in NamedQueries.instance.list()]
        headers = ["Name"]
    else:
        headers = ["Name", "Query"]
        rows = [[r, NamedQueries.instance.get(r)] for r in NamedQueries.instance.list()]

    if not rows:
        status = NamedQueries.instance.usage
    else:
        status = ""
    return [("", rows, headers, status)]


@special_command("\\np", "\\np name_pattern", "Print a named query.")
def get_named_query(pattern, **_):
    """Get a named query that matches name_pattern.

    The named pattern can be a regular expression. Returns (title,
    rows, headers, status)

    """

    usage = "Syntax: \\np name.\n\n" + NamedQueries.instance.usage
    if not pattern:
        return [(None, None, None, usage)]

    name = pattern.strip()
    if not name:
        return [(None, None, None, usage + "Err: A name is required.")]

    headers = ["Name", "Query"]
    rows = [(r, NamedQueries.instance.get(r)) for r in NamedQueries.instance.list() if re.search(name, r)]

    status = ""
    if not rows:
        status = "No match found"

    return [("", rows, headers, status)]


@special_command("\\ns", "\\ns name query", "Save a named query.")
def save_named_query(pattern, **_):
    """Save a new named query.
    Returns (title, rows, headers, status)"""

    usage = "Syntax: \\ns name query.\n\n" + NamedQueries.instance.usage
    if not pattern:
        return [(None, None, None, usage)]

    name, _, query = pattern.partition(" ")

    # If either name or query is missing then print the usage and complain.
    if (not name) or (not query):
        return [(None, None, None, usage + "Err: Both name and query are required.")]

    NamedQueries.instance.save(name, query)
    return [(None, None, None, "Saved.")]


@special_command("\\nd", "\\nd [name]", "Delete a named query.")
def delete_named_query(pattern, **_):
    """Delete an existing named query."""
    usage = "Syntax: \\nd name.\n\n" + NamedQueries.instance.usage
    if not pattern:
        return [(None, None, None, usage)]

    status = NamedQueries.instance.delete(pattern)

    return [(None, None, None, status)]
