# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import re
import psycopg
from psycopg.types import TypeInfo
from psycopg.adapt import Loader
from psycopg import sql
from psycopg.client_cursor import ClientCursor
from .exceptions import *
from .builder import parseAgeValue


_EXCEPTION_NoConnection = NoConnection()
_EXCEPTION_GraphNotSet = GraphNotSet()

WHITESPACE = re.compile('\s')


class AgeDumper(psycopg.adapt.Dumper):
    def dump(self, obj: Any) -> bytes | bytearray | memoryview:
        pass    
    
    
class AgeLoader(psycopg.adapt.Loader):    
    def load(self, data: bytes | bytearray | memoryview) -> Any | None:
        if isinstance(data, memoryview):
            data_bytes = data.tobytes()
        else:
            data_bytes = data

        return parseAgeValue(data_bytes.decode('utf-8'))


def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=False):
    with conn.cursor() as cursor:
        if load_from_plugins:
            cursor.execute("LOAD '$libdir/plugins/age';")
        else:
            cursor.execute("LOAD 'age';")

        cursor.execute("SET search_path = ag_catalog, '$user', public;")

        ag_info = TypeInfo.fetch(conn, 'agtype')

        if not ag_info:
            raise AgeNotSet()

        conn.adapters.register_loader(ag_info.oid, AgeLoader)
        conn.adapters.register_loader(ag_info.array_oid, AgeLoader)

        # Check graph exists
        if graphName != None:
            checkGraphCreated(conn, graphName)

# Create the graph, if it does not exist
def checkGraphCreated(conn:psycopg.connection, graphName:str):
    with conn.cursor() as cursor:
        cursor.execute(sql.SQL("SELECT count(*) FROM ag_graph WHERE name={graphName}").format(graphName=sql.Literal(graphName)))
        if cursor.fetchone()[0] == 0:
            cursor.execute(sql.SQL("SELECT create_graph({graphName});").format(graphName=sql.Literal(graphName)))
            conn.commit()


def deleteGraph(conn:psycopg.connection, graphName:str):
    with conn.cursor() as cursor:
        cursor.execute(sql.SQL("SELECT drop_graph({graphName}, true);").format(graphName=sql.Literal(graphName)))
        conn.commit()


def buildCypher(graphName:str, cypherStmt:str, columns:list) ->str:
    if graphName == None:
        raise _EXCEPTION_GraphNotSet
    
    columnExp=[]
    if columns != None and len(columns) > 0:
        for col in columns:
            if col.strip() == '':
                continue
            elif WHITESPACE.search(col) != None:
                columnExp.append(col)
            else:
                columnExp.append(col + " agtype")
    else:
        columnExp.append('v agtype')

    stmtArr = []
    stmtArr.append("SELECT * from cypher(NULL,NULL) as (")
    stmtArr.append(','.join(columnExp))
    stmtArr.append(");")
    return "".join(stmtArr)

def execSql(conn:psycopg.connection, stmt:str, commit:bool=False, params:tuple=None) -> psycopg.cursor :
    if conn == None or conn.closed:
        raise _EXCEPTION_NoConnection
    
    cursor = conn.cursor()
    try:
        cursor.execute(stmt, params)
        if commit:
            conn.commit()

        return cursor
    except SyntaxError as cause:
        conn.rollback()
        raise cause
    except Exception as cause:
        conn.rollback()
        raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + stmt +")", cause)


def querySql(conn:psycopg.connection, stmt:str, params:tuple=None) -> psycopg.cursor :
    return execSql(conn, stmt, False, params)

# Execute cypher statement and return cursor.
# If cypher statement changes data (create, set, remove),
# You must commit session(ag.commit())
# (Otherwise the execution cannot make any effect.)
def execCypher(conn:psycopg.connection, graphName:str, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor :
    if conn == None or conn.closed:
        raise _EXCEPTION_NoConnection

    cursor = conn.cursor()
    #clean up the string for mogrification
    cypherStmt = cypherStmt.replace("\n", "")
    cypherStmt = cypherStmt.replace("\t", "")
    cypher = str(cursor.mogrify(cypherStmt, params))
    cypher = cypher.strip()

    preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})"

    cursor = conn.cursor()
    try:
        cursor.execute(sql.SQL(preparedStmt).format(graphName=sql.Literal(graphName),cypherStmt=sql.Literal(cypher)))
    except SyntaxError as cause:
        conn.rollback()
        raise cause
    except Exception as cause:
        conn.rollback()
        raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + preparedStmt +")", cause)

    stmt = buildCypher(graphName, cypher, cols)

    cursor = conn.cursor()
    try:
        cursor.execute(stmt)
        return cursor
    except SyntaxError as cause:
        conn.rollback()
        raise cause
    except Exception as cause:
        conn.rollback()
        raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + stmt +")", cause)


def cypher(cursor:psycopg.cursor, graphName:str, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor :
    #clean up the string for mogrification
    cypherStmt = cypherStmt.replace("\n", "")
    cypherStmt = cypherStmt.replace("\t", "")
    cypher = str(cursor.mogrify(cypherStmt, params))
    cypher = cypher.strip()

    preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})"
    cursor.execute(sql.SQL(preparedStmt).format(graphName=sql.Literal(graphName),cypherStmt=sql.Literal(cypher)))

    stmt = buildCypher(graphName, cypher, cols)
    cursor.execute(stmt)


# def execCypherWithReturn(conn:psycopg.connection, graphName:str, cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor :
#     stmt = buildCypher(graphName, cypherStmt, columns)
#     return execSql(conn, stmt, False, params)

# def queryCypher(conn:psycopg.connection, graphName:str, cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor :
#     return execCypherWithReturn(conn, graphName, cypherStmt, columns, params)


class Age:
    def __init__(self):
        self.connection = None    # psycopg connection]
        self.graphName = None

    # Connect to PostgreSQL Server and establish session and type extension environment.
    def connect(self, graph:str=None, dsn:str=None, connection_factory=None, cursor_factory=ClientCursor,
                load_from_plugins:bool=False, **kwargs):
        conn = psycopg.connect(dsn, cursor_factory=cursor_factory, **kwargs)
        setUpAge(conn, graph, load_from_plugins)
        self.connection = conn
        self.graphName = graph
        return self

    def close(self):
        self.connection.close()

    def setGraph(self, graph:str):
        checkGraphCreated(self.connection, graph)
        self.graphName = graph
        return self

    def commit(self):
        self.connection.commit()

    def rollback(self):
        self.connection.rollback()

    def execCypher(self, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor :
        return execCypher(self.connection, self.graphName, cypherStmt, cols=cols, params=params)

    def cypher(self, cursor:psycopg.cursor, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor :
        return cypher(cursor, self.graphName, cypherStmt, cols=cols, params=params)

    # def execSql(self, stmt:str, commit:bool=False, params:tuple=None) -> psycopg.cursor :
    #     return execSql(self.connection, stmt, commit, params)


    # def execCypher(self, cypherStmt:str, commit:bool=False, params:tuple=None) -> psycopg.cursor :
    #     return execCypher(self.connection, self.graphName, cypherStmt, commit, params)

    # def execCypherWithReturn(self, cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor :
    #     return execCypherWithReturn(self.connection, self.graphName, cypherStmt, columns, params)

    # def queryCypher(self, cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor :
    #     return queryCypher(self.connection, self.graphName, cypherStmt, columns, params)
