"""
/***************************************************************************
Name                 : DB Manager
Description          : Database manager plugin for QGIS
Date                 : May 23, 2011
copyright            : (C) 2011 by Giuseppe Sucameli
email                : brush.tyler@gmail.com

 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

from functools import partial
from qgis.PyQt.QtCore import (
    Qt,
    QObject,
    qDebug,
    QByteArray,
    QMimeData,
    QDataStream,
    QIODevice,
    QFileInfo,
    QAbstractItemModel,
    QModelIndex,
    pyqtSignal,
)
from qgis.PyQt.QtWidgets import QApplication, QMessageBox
from qgis.PyQt.QtGui import QIcon

from .db_plugins import supportedDbTypes, createDbPlugin
from .db_plugins.plugin import BaseError, Table, Database
from .dlg_db_error import DlgDbError
from .gui_utils import GuiUtils

from qgis.core import (
    QgsApplication,
    QgsDataSourceUri,
    QgsVectorLayer,
    QgsRasterLayer,
    QgsMimeDataUtils,
    QgsProviderConnectionException,
    QgsProviderRegistry,
    QgsAbstractDatabaseProviderConnection,
    QgsMessageLog,
)

from qgis.utils import OverrideCursor

try:
    from qgis.core import QgsVectorLayerExporter  # NOQA

    isImportVectorAvail = True
except ImportError:
    # QgsVectorLayerExporter is required by the layer import feature; without
    # it, dropping/importing layers into databases is disabled in the model.
    isImportVectorAvail = False

from osgeo import gdal


class TreeItem(QObject):
    """Base node of the DB Manager tree.

    A node wraps a data object (``itemData``; None for the invisible root)
    and keeps an ordered list of child nodes. The QObject parent is used as
    the tree parent.
    """

    deleted = pyqtSignal()
    changed = pyqtSignal()

    def __init__(self, data, parent=None):
        QObject.__init__(self, parent)
        self.itemData = data
        self.childItems = []
        # set to True once populate() has built the children of this node
        self.populated = False
        if parent:
            parent.appendChild(self)

    def childRemoved(self):
        """Slot for a child's `deleted` signal: this node has changed."""
        self.itemChanged()

    def itemChanged(self):
        """Emit `changed` so listeners can refresh this node."""
        self.changed.emit()

    def itemDeleted(self):
        """Emit `deleted`; the parent node listens to it (see appendChild)."""
        self.deleted.emit()

    def populate(self):
        """Build the children of this node; the base node has none.

        :return: True on success.
        """
        self.populated = True
        return True

    def getItemData(self):
        """Return the wrapped data object."""
        return self.itemData

    def appendChild(self, child):
        """Append *child* and get notified when it is deleted."""
        self.childItems.append(child)
        child.deleted.connect(self.childRemoved)

    def child(self, row):
        return self.childItems[row]

    def removeChild(self, row):
        """Remove the child at *row* and schedule its data object for deletion.

        Rows outside the child list are ignored.
        """
        if not 0 <= row < len(self.childItems):
            return
        removed = self.childItems[row]
        removed.itemData.deleteLater()
        removed.deleted.disconnect(self.childRemoved)
        del self.childItems[row]

    def childCount(self):
        return len(self.childItems)

    def columnCount(self):
        return 1

    def row(self):
        """Return the position of this node among its siblings (0 if none)."""
        parentItem = self.parent()
        if not parentItem:
            return 0
        return next(
            (
                position
                for position, sibling in enumerate(parentItem.childItems)
                if sibling is self
            ),
            0,
        )

    def data(self, column):
        """Return the display text for *column*: empty for column 0."""
        if column == 0:
            return ""
        return None

    def icon(self):
        return None

    def path(self):
        """Return the ancestors' path followed by this node's display name."""
        parentItem = self.parent()
        ancestors = parentItem.path() if parentItem else []
        return [*ancestors, self.data(0)]


class PluginItem(TreeItem):
    """Top-level tree node wrapping a database plugin (one per DB type).

    Its children are the connections defined for that plugin.
    """

    def __init__(self, dbplugin, parent=None):
        super().__init__(dbplugin, parent)

    def populate(self):
        """Create one ConnectionItem per connection of the plugin (once)."""
        if not self.populated:
            for conn in self.getItemData().connections():
                ConnectionItem(conn, self)
            self.populated = True
        return True

    def data(self, column):
        """Return the plugin's type name string for column 0, else None."""
        return self.getItemData().typeNameString() if column == 0 else None

    def icon(self):
        return self.getItemData().icon()

    def path(self):
        """Return a one-element path made of the plugin's type name."""
        return [self.getItemData().typeName()]


class ConnectionItem(TreeItem):
    """Tree node wrapping a database connection.

    Populating it connects to the database when needed. Its children are
    SchemaItems when the database exposes schemas, TableItems otherwise.
    """

    def __init__(self, connection, parent=None):
        TreeItem.__init__(self, connection, parent)
        # database object whose changed/deleted signals are currently
        # connected to this item (see populate())
        self._watchedDatabase = None
        connection.changed.connect(self.itemChanged)
        connection.deleted.connect(self.itemDeleted)

        # load (shared) icons with the first instance of connection item
        if not hasattr(ConnectionItem, "connectedIcon"):
            ConnectionItem.connectedIcon = GuiUtils.get_icon("plugged")
            ConnectionItem.disconnectedIcon = GuiUtils.get_icon("unplugged")

    def data(self, column):
        """Return the connection name for column 0, else None."""
        if column == 0:
            return self.getItemData().connectionName()
        return None

    def icon(self):
        return self.getItemData().connectionIcon()

    def populate(self):
        """Connect to the database if needed and create the child items.

        :return: False if connecting failed (a BaseError is shown in an
            error dialog), True otherwise.
        """
        if self.populated:
            return True

        connection = self.getItemData()
        if connection.database() is None:
            # connect to database
            try:
                if not connection.connect():
                    return False

            except BaseError as e:
                DlgDbError.showError(e, None)
                return False

        database = connection.database()
        # populate() runs again after every refresh of this item (the model
        # clears `populated` before repopulating). Connect the database
        # signals only once per database object: connecting them on every
        # run would stack duplicate connections, so a single database change
        # would trigger several refreshes.
        if database is not self._watchedDatabase:
            database.changed.connect(self.itemChanged)
            database.deleted.connect(self.itemDeleted)
            self._watchedDatabase = database

        schemas = database.schemas()
        if schemas is not None:
            for s in schemas:
                SchemaItem(s, self)
        else:
            tables = database.tables()
            for t in tables:
                TableItem(t, self)

        self.populated = True
        return True

    def isConnected(self):
        """Return True when the connection currently holds a database object."""
        return self.getItemData().database() is not None


class SchemaItem(TreeItem):
    """Tree node wrapping a database schema; its children are TableItems."""

    def __init__(self, schema, parent):
        super().__init__(schema, parent)
        schema.changed.connect(self.itemChanged)
        schema.deleted.connect(self.itemDeleted)

        # the icon is shared by every schema item: load it only once
        if not hasattr(SchemaItem, "schemaIcon"):
            SchemaItem.schemaIcon = GuiUtils.get_icon("namespace")

    def data(self, column):
        """Return the schema name for column 0, else None."""
        return self.getItemData().name if column == 0 else None

    def icon(self):
        return self.schemaIcon

    def populate(self):
        """Create one TableItem per table of the schema (once)."""
        if not self.populated:
            for table in self.getItemData().tables():
                TableItem(table, self)
            self.populated = True
        return True


class TableItem(TreeItem):
    """Leaf tree node wrapping a table, view or raster of a database."""

    def __init__(self, table, parent):
        super().__init__(table, parent)
        table.changed.connect(self.itemChanged)
        table.deleted.connect(self.itemDeleted)
        # tables have no children: mark the node as populated right away
        self.populate()

        # the icons are shared by every table item: load them only once
        if not hasattr(TableItem, "tableIcon"):
            themeIcon = QgsApplication.getThemeIcon
            TableItem.tableIcon = themeIcon("/mIconTableLayer.svg")
            TableItem.viewIcon = GuiUtils.get_icon("view")
            TableItem.viewMaterializedIcon = GuiUtils.get_icon("view_materialized")
            TableItem.layerPointIcon = themeIcon("/mIconPointLayer.svg")
            TableItem.layerLineIcon = themeIcon("/mIconLineLayer.svg")
            TableItem.layerPolygonIcon = themeIcon("/mIconPolygonLayer.svg")
            TableItem.layerRasterIcon = themeIcon("/mIconRasterLayer.svg")
            TableItem.layerUnknownIcon = GuiUtils.get_icon("layer_unknown")

    def data(self, column):
        """Return the table name (column 0) or, for vector tables, the
        geometry type (column 1); None otherwise."""
        table = self.getItemData()
        if column == 0:
            return table.name
        if column == 1 and table.type == Table.VectorType:
            return table.geomType
        return None

    def icon(self):
        """Pick an icon from the geometry type, raster type or view kind."""
        table = self.getItemData()
        if table.type == Table.VectorType:
            geom_type = table.geomType
            if geom_type is not None:
                if "POINT" in geom_type:
                    return self.layerPointIcon
                if "LINESTRING" in geom_type or geom_type in (
                    "CIRCULARSTRING",
                    "COMPOUNDCURVE",
                    "MULTICURVE",
                ):
                    return self.layerLineIcon
                if "POLYGON" in geom_type or geom_type == "MULTISURFACE":
                    return self.layerPolygonIcon
                return self.layerUnknownIcon
        elif table.type == Table.RasterType:
            return self.layerRasterIcon

        if not table.isView:
            return self.tableIcon
        # relation type "m" marks a materialized view
        if getattr(table, "_relationType", None) == "m":
            return self.viewMaterializedIcon
        return self.viewIcon

    def path(self):
        """Return the parent's path plus this table's name; vector tables
        are suffixed with ``::<geometry column>``."""
        parentItem = self.parent()
        ancestors = parentItem.path() if parentItem else []
        table = self.getItemData()
        if table.type == Table.VectorType:
            leaf = f"{self.data(0)}::{table.geomColumn}"
        else:
            leaf = self.data(0)
        return [*ancestors, leaf]


class DBModel(QAbstractItemModel):
    """Qt item model exposing the DB Manager tree.

    The tree hangs off an invisible root ``TreeItem``. Below it come one
    ``PluginItem`` per supported database type, then connections, schemas
    (when the database has them) and tables. The internal pointer of every
    valid ``QModelIndex`` is the ``TreeItem`` it refers to. Items are
    populated lazily, the first time their row count is requested.
    """

    # Requests a vector layer import: (source layer, target database, target
    # data source URI, index of the item to refresh after the import).
    importVector = pyqtSignal(QgsVectorLayer, Database, QgsDataSourceUri, QModelIndex)
    # Emitted with the index of an item whose populate() returned False.
    notPopulated = pyqtSignal(QModelIndex)

    def __init__(self, parent=None):
        """Build the model with one plugin item per supported database type.

        :param parent: parent object (presumably the tree view showing the
            model); also used as the parent widget of rename error dialogs.
        """
        QAbstractItemModel.__init__(self, parent)
        self.treeView = parent
        self.header = [self.tr("Databases")]

        if isImportVectorAvail:
            self.importVector.connect(self.vectorImport)

        self.hasSpatialiteSupport = "spatialite" in supportedDbTypes()
        self.hasGPKGSupport = "gpkg" in supportedDbTypes()

        self.rootItem = TreeItem(None, None)
        for dbtype in supportedDbTypes():
            dbpluginclass = createDbPlugin(dbtype)
            item = PluginItem(dbpluginclass, self.rootItem)
            item.changed.connect(partial(self.refreshItem, item))
        # All root children now exist and have their `changed` signal
        # connected. Mark the root as populated so that the first rowCount()
        # on it does not go through _refreshIndex(). That call would connect
        # each plugin item's `changed` signal a second time, and every plugin
        # refresh would then run twice.
        self.rootItem.populated = True

    def refreshItem(self, item):
        """Repopulate the tree row holding *item*.

        :param item: a TreeItem (located through its path) or a data object
            (located by searching the populated part of the tree).
        """
        if isinstance(item, TreeItem):
            # find the index for the tree item using the path
            index = self._rPath2Index(item.path())
        else:
            # find the index for the db item
            index = self._rItem2Index(item)
        if index.isValid():
            self._refreshIndex(index)
        else:
            qDebug("invalid index")

    def _rItem2Index(self, item, parent=None):
        """Depth-first search for the index whose data object is *item*.

        Only populated items are descended into, so the search never
        triggers lazy population.

        :return: the matching index, or an invalid index if none.
        """
        if parent is None:
            parent = QModelIndex()
        if item == self.getItem(parent):
            return parent

        if not parent.isValid() or parent.internalPointer().populated:
            for i in range(self.rowCount(parent)):
                index = self.index(i, 0, parent)
                index = self._rItem2Index(item, index)
                if index.isValid():
                    return index

        return QModelIndex()

    def _rPath2Index(self, path, parent=None, n=0):
        """Follow *path* (as returned by TreeItem.path()) down from *parent*.

        Children are enumerated through rowCount(), which populates
        unpopulated items along the way.

        :param n: depth of *parent*'s children in the tree (index into the
            children's own paths).
        :return: the index of the deepest item matched. When no child matches
            the next path element, *parent* itself is returned.
        """
        if parent is None:
            parent = QModelIndex()
        if path is None or len(path) == 0:
            return parent

        for i in range(self.rowCount(parent)):
            index = self.index(i, 0, parent)
            if self._getPath(index)[n] == path[0]:
                return self._rPath2Index(path[1:], index, n + 1)

        return parent

    def getItem(self, index):
        """Return the data object of the item at *index*, or None if invalid."""
        if not index.isValid():
            return None
        return index.internalPointer().getItemData()

    def _getPath(self, index):
        """Return the TreeItem path of *index*, or None if invalid."""
        if not index.isValid():
            return None
        return index.internalPointer().path()

    def columnCount(self, parent):
        return 1

    def data(self, index, role):
        """Return the item icon (decoration role, column 0) or its text
        (display/edit roles); None otherwise."""
        if not index.isValid():
            return None

        if role == Qt.ItemDataRole.DecorationRole and index.column() == 0:
            icon = index.internalPointer().icon()
            if icon:
                return icon

        if role != Qt.ItemDataRole.DisplayRole and role != Qt.ItemDataRole.EditRole:
            return None

        return index.internalPointer().data(index.column())

    def flags(self, index):
        """Return the item flags for *index*.

        Every valid item is enabled and selectable. In column 0:
        - schemas and tables are editable (renamable). The exception is raster
          tables when GeoPackage support is available and GDAL is older than
          3.10. NOTE(review): presumably renaming GeoPackage rasters needs
          GDAL >= 3.10; confirm.
        - tables can be dragged;
        - populated connections, schemas and tables accept drops of layers
          to import, when vector import is available;
        - every item accepts drops when SpatiaLite/GeoPackage support is
          available, so database files can be dropped anywhere.
        """
        if not index.isValid():
            return Qt.ItemFlag.NoItemFlags

        flags = Qt.ItemFlag.ItemIsEnabled | Qt.ItemFlag.ItemIsSelectable

        if index.column() == 0:
            item = index.internalPointer()

            if isinstance(item, SchemaItem) or (
                isinstance(item, TableItem)
                and not (
                    self.hasGPKGSupport
                    and item.getItemData().type == Table.RasterType
                    and int(gdal.VersionInfo()) < 3100000
                )
            ):
                flags |= Qt.ItemFlag.ItemIsEditable

            if isinstance(item, TableItem):
                flags |= Qt.ItemFlag.ItemIsDragEnabled

            # vectors/tables can be dropped on connected databases to be imported
            if isImportVectorAvail:
                if isinstance(item, ConnectionItem) and item.populated:
                    flags |= Qt.ItemFlag.ItemIsDropEnabled

                if isinstance(item, (SchemaItem, TableItem)):
                    flags |= Qt.ItemFlag.ItemIsDropEnabled

            # SL/Geopackage db files can be dropped everywhere in the tree
            if self.hasSpatialiteSupport or self.hasGPKGSupport:
                flags |= Qt.ItemFlag.ItemIsDropEnabled

        return flags

    def headerData(self, section, orientation, role):
        """Return the horizontal header text for *section*, else None."""
        if (
            orientation == Qt.Orientation.Horizontal
            and role == Qt.ItemDataRole.DisplayRole
            and section < len(self.header)
        ):
            return self.header[section]
        return None

    def index(self, row, column, parent):
        """Return the index of the child at (*row*, *column*) of *parent*."""
        if not self.hasIndex(row, column, parent):
            return QModelIndex()

        parentItem = parent.internalPointer() if parent.isValid() else self.rootItem
        childItem = parentItem.child(row)
        if childItem:
            return self.createIndex(row, column, childItem)
        return QModelIndex()

    def parent(self, index):
        """Return the parent index of *index* (invalid for top-level items)."""
        if not index.isValid():
            return QModelIndex()

        childItem = index.internalPointer()
        parentItem = childItem.parent()

        if parentItem is self.rootItem:
            return QModelIndex()

        return self.createIndex(parentItem.row(), 0, parentItem)

    def rowCount(self, parent):
        """Return the number of children of *parent*, populating it first
        if it has not been populated yet (lazy loading)."""
        parentItem = parent.internalPointer() if parent.isValid() else self.rootItem
        if not parentItem.populated:
            self._refreshIndex(parent, True)
        return parentItem.childCount()

    def hasChildren(self, parent):
        """Report children for unpopulated items too, so that views offer
        to expand them (which triggers population through rowCount())."""
        parentItem = parent.internalPointer() if parent.isValid() else self.rootItem
        return parentItem.childCount() > 0 or not parentItem.populated

    def setData(self, index, value, role):
        """Rename the schema/table/view at *index* to *value*.

        :return: True when renamed. Returns False for other roles, columns or
            items, for an unchanged name, or for a failed rename (the error is
            shown in a dialog).
        """
        if role != Qt.ItemDataRole.EditRole or index.column() != 0:
            return False

        item = index.internalPointer()
        new_value = str(value)

        if isinstance(item, (SchemaItem, TableItem)):
            obj = item.getItemData()

            # rename schema or table or view
            if new_value == obj.name:
                return False

            with OverrideCursor(Qt.CursorShape.WaitCursor):
                try:
                    obj.rename(new_value)
                    self._onDataChanged(index)
                except BaseError as e:
                    DlgDbError.showError(e, self.treeView)
                    return False
                else:
                    return True

        return False

    def removeRows(self, row, count, parent):
        """Remove *count* child rows of *parent*, starting at *row*.

        :return: True when rows were removed, False for an empty range.
        """
        if count <= 0:
            # beginRemoveRows() requires first <= last. An empty range (e.g.
            # refreshing a populated item without children) is a no-op.
            return False
        item = parent.internalPointer() if parent.isValid() else self.rootItem
        self.beginRemoveRows(parent, row, row + count - 1)
        for _ in range(count):
            # the following children shift up after each removal
            item.removeChild(row)
        self.endRemoveRows()
        return True

    def _refreshIndex(self, index, force=False):
        """Clear and repopulate the item at *index*.

        A populated item has its rows removed and is then repopulated. A
        never-populated item is populated only when *force* is set (this is
        what rowCount() does). On success, each new child's `changed` signal
        is connected to refreshItem() and dataChanged is emitted for *index*.
        Otherwise notPopulated is emitted. If a BaseError is raised, the item
        is left marked as not populated.
        """
        item = index.internalPointer() if index.isValid() else self.rootItem
        with OverrideCursor(Qt.CursorShape.WaitCursor):
            try:
                prevPopulated = item.populated
                if prevPopulated:
                    self.removeRows(0, item.childCount(), index)
                    item.populated = False
                if prevPopulated or force:
                    if item.populate():
                        for child in item.childItems:
                            child.changed.connect(partial(self.refreshItem, child))
                        self._onDataChanged(index)
                    else:
                        self.notPopulated.emit(index)

            except BaseError:
                item.populated = False

    def _onDataChanged(self, indexFrom, indexTo=None):
        """Emit dataChanged for the range *indexFrom*..*indexTo* (default:
        just *indexFrom*)."""
        if indexTo is None:
            indexTo = indexFrom
        self.dataChanged.emit(indexFrom, indexTo)

    QGIS_URI_MIME = "application/x-vnd.qgis.qgis.uri"

    def mimeTypes(self):
        return ["text/uri-list", self.QGIS_URI_MIME]

    def mimeData(self, indexes):
        """Encode the mime URIs of the dragged tables as QGIS_URI_MIME data.

        Invalid indexes and indexes that are not tables are skipped.
        """
        mimeData = QMimeData()
        encodedData = QByteArray()

        stream = QDataStream(encodedData, QIODevice.OpenModeFlag.WriteOnly)

        for index in indexes:
            if not index.isValid():
                continue
            if not isinstance(index.internalPointer(), TableItem):
                continue
            table = self.getItem(index)
            stream.writeQString(table.mimeUri())

        mimeData.setData(self.QGIS_URI_MIME, encodedData)
        return mimeData

    def dropMimeData(self, data, action, row, column, parent):
        """Handle files and QGIS layer URIs dropped on *parent*.

        Valid SpatiaLite database files are added as new SpatiaLite
        connections. Other files and QGIS layer URIs are imported into the
        database of *parent*, provided *parent* is a populated connection, a
        schema or a table and vector import is available.

        :return: True if at least one connection was added or one import
            was requested.
        """
        if action == Qt.DropAction.IgnoreAction:
            return True

        # vectors/tables to be imported must be dropped on connected db, schema or table
        canImportLayer = (
            isImportVectorAvail
            and parent.isValid()
            and (
                isinstance(parent.internalPointer(), (SchemaItem, TableItem))
                or (
                    isinstance(parent.internalPointer(), ConnectionItem)
                    and parent.internalPointer().populated
                )
            )
        )

        added = 0

        if data.hasUrls():
            for u in data.urls():
                filename = u.toLocalFile()
                if filename == "":
                    continue

                if self.hasSpatialiteSupport:
                    from .db_plugins.spatialite.connector import SpatiaLiteDBConnector

                    if SpatiaLiteDBConnector.isValidDatabase(filename):
                        # retrieve the SL plugin tree item using its path
                        index = self._rPath2Index(["spatialite"])
                        if not index.isValid():
                            continue
                        item = index.internalPointer()

                        conn_name = QFileInfo(filename).fileName()
                        uri = QgsDataSourceUri()
                        uri.setDatabase(filename)
                        item.getItemData().addConnection(conn_name, uri)
                        item.changed.emit()
                        added += 1
                        continue

                if canImportLayer:
                    if QgsRasterLayer.isValidRasterFileName(filename):
                        layerType = "raster"
                        providerKey = "gdal"
                    else:
                        layerType = "vector"
                        providerKey = "ogr"

                    layerName = QFileInfo(filename).completeBaseName()
                    if self.importLayer(
                        layerType, providerKey, layerName, filename, parent
                    ):
                        added += 1

        if data.hasFormat(self.QGIS_URI_MIME):
            for uri in QgsMimeDataUtils.decodeUriList(data):
                if canImportLayer:
                    if self.importLayer(
                        uri.layerType, uri.providerKey, uri.name, uri.uri, parent
                    ):
                        added += 1

        return added > 0

    def importLayer(self, layerType, providerKey, layerName, uriString, parent):
        """Open a layer and request its import into the database of *parent*.

        Only vector layers are supported. An invalid layer is reported with a
        warning message box.

        :return: True when the import was requested (importVector emitted).
        """
        if not isImportVectorAvail:
            return False

        if layerType == "raster":
            return False  # importing raster layers is not implemented yet

        inLayer = QgsVectorLayer(uriString, layerName, providerKey)

        if not inLayer.isValid():
            # invalid layer
            QMessageBox.warning(
                None,
                self.tr("Invalid layer"),
                self.tr("Unable to load the layer {0}").format(inLayer.name()),
            )
            return False

        # retrieve information about the new table's db and schema
        outItem = parent.internalPointer()
        outObj = outItem.getItemData()
        outDb = outObj.database()
        outSchema = None
        if isinstance(outItem, SchemaItem):
            outSchema = outObj
        elif isinstance(outItem, TableItem):
            outSchema = outObj.schema()

        # toIndex will point to the parent item of the new table
        toIndex = parent
        if isinstance(toIndex.internalPointer(), TableItem):
            toIndex = toIndex.parent()

        if inLayer.type() == inLayer.VectorLayer:
            # create the output uri
            schema = (
                outSchema.name
                if outDb.schemas() is not None and outSchema is not None
                else ""
            )
            pkCol = geomCol = ""

            # default pk and geom field name value
            if providerKey in ["postgres", "spatialite"]:
                inUri = QgsDataSourceUri(inLayer.source())
                pkCol = inUri.keyColumn()
                geomCol = inUri.geometryColumn()

            outUri = outDb.uri()
            outUri.setDataSource(schema, layerName, geomCol, "", pkCol)

            self.importVector.emit(inLayer, outDb, outUri, toIndex)
            return True

        return False

    def vectorImport(self, inLayer, outDb, outUri, parent):
        """Slot for importVector: show the import dialog for *inLayer*.

        *parent* is refreshed when the dialog is accepted. The source layer
        is scheduled for deletion in every case.
        """
        if not isImportVectorAvail:
            return False

        try:
            from .dlg_import_vector import DlgImportVector

            dlg = DlgImportVector(inLayer, outDb, outUri)
            QApplication.restoreOverrideCursor()
            if dlg.exec():
                self._refreshIndex(parent)
        finally:
            inLayer.deleteLater()
