import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import (
    _literal_as_text,
    ClauseElement,
    ColumnElement,
    Executable,
    FunctionElement
)
from sqlalchemy.sql.functions import GenericFunction

from sqlalchemy_utils.functions.orm import quote


class explain(Executable, ClauseElement):
    """
    Define EXPLAIN element.

    http://www.postgresql.org/docs/devel/static/sql-explain.html
    """
    def __init__(
        self,
        stmt,
        analyze=False,
        verbose=False,
        costs=True,
        buffers=False,
        timing=True,
        format='text'
    ):
        self.statement = _literal_as_text(stmt)
        self.analyze = analyze
        self.verbose = verbose
        self.costs = costs
        self.buffers = buffers
        self.timing = timing
        self.format = format


class explain_analyze(explain):
    def __init__(self, stmt, **kwargs):
        super(explain_analyze, self).__init__(
            stmt,
            analyze=True,
            **kwargs
        )


@compiles(explain, 'postgresql')
def pg_explain(element, compiler, **kw):
    text = "EXPLAIN "
    options = []
    if element.analyze:
        options.append('ANALYZE true')
    if not element.timing:
        options.append('TIMING false')
    if element.buffers:
        options.append('BUFFERS true')
    if element.format != 'text':
        options.append('FORMAT %s' % element.format)
    if element.verbose:
        options.append('VERBOSE true')
    if not element.costs:
        options.append('COSTS false')
    if options:
        text += '(%s) ' % ', '.join(options)
    text += compiler.process(element.statement)
    return text


class array_get(FunctionElement):
    name = 'array_get'


@compiles(array_get)
def compile_array_get(element, compiler, **kw):
    args = list(element.clauses)
    if len(args) != 2:
        raise Exception(
            "Function 'array_get' expects two arguments (%d given)." %
            len(args)
        )

    if not hasattr(args[1], 'value') or not isinstance(args[1].value, int):
        raise Exception(
            "Second argument should be an integer."
        )
    return '(%s)[%s]' % (
        compiler.process(args[0]),
        sa.text(str(args[1].value + 1))
    )


class row_to_json(GenericFunction):
    name = 'row_to_json'
    type = postgresql.JSON


@compiles(row_to_json, 'postgresql')
def compile_row_to_json(element, compiler, **kw):
    return "%s(%s)" % (element.name, compiler.process(element.clauses))


class json_array_length(GenericFunction):
    name = 'json_array_length'
    type = sa.Integer


@compiles(json_array_length, 'postgresql')
def compile_json_array_length(element, compiler, **kw):
    return "%s(%s)" % (element.name, compiler.process(element.clauses))


class array_agg(GenericFunction):
    name = 'array_agg'
    type = postgresql.ARRAY

    def __init__(self, arg, default=None, **kw):
        self.type = postgresql.ARRAY(arg.type)
        self.default = default
        GenericFunction.__init__(self, arg, **kw)


@compiles(array_agg, 'postgresql')
def compile_array_agg(element, compiler, **kw):
    compiled = "%s(%s)" % (element.name, compiler.process(element.clauses))
    if element.default is None:
        return compiled
    return str(sa.func.coalesce(
        sa.text(compiled),
        sa.cast(postgresql.array(element.default), element.type)
    ).compile(compiler))


class Asterisk(ColumnElement):
    def __init__(self, selectable):
        self.selectable = selectable


@compiles(Asterisk)
def compile_asterisk(element, compiler, **kw):
    return '%s.*' % quote(compiler.dialect, element.selectable.name)
