from dataclasses import dataclass
import datetime
from typing import cast, Callable, List, Optional
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from sqlite_utils.db import Database, Table


class Migrations:
    migrations_table = "_sqlite_migrations"

    @dataclass
    class _Migration:
        name: str
        fn: Callable

    @dataclass
    class _AppliedMigration:
        name: str
        applied_at: datetime.datetime

    def __init__(self, name: str):
        """
        :param name: The name of the migration set. This should be unique.
        """
        self.name = name
        self._migrations: List[Migrations._Migration] = []

    def __call__(self, *, name: Optional[str] = None) -> Callable:
        """
        :param name: The name to use for this migration - if not provided,
          the name of the function will be used
        """

        def inner(func: Callable) -> Callable:
            self._migrations.append(self._Migration(name or func.__name__, func))
            return func

        return inner

    def pending(self, db: "Database") -> List["Migrations._Migration"]:
        """
        Return a list of pending migrations.
        """
        self.ensure_migrations_table(db)
        already_applied = {
            r["name"]
            for r in db[self.migrations_table].rows_where(
                "migration_set = ?", [self.name]
            )
        }
        return [
            migration
            for migration in self._migrations
            if migration.name not in already_applied
        ]

    def applied(self, db: "Database") -> List["Migrations._AppliedMigration"]:
        """
        Return a list of applied migrations.
        """
        self.ensure_migrations_table(db)
        return [
            self._AppliedMigration(name=row["name"], applied_at=row["applied_at"])
            for row in db[self.migrations_table].rows_where(
                "migration_set = ?", [self.name]
            )
        ]

    def apply(self, db: "Database", *, stop_before: Optional[str] = None):
        """
        Apply any pending migrations to the database.
        """
        self.ensure_migrations_table(db)
        for migration in self.pending(db):
            name = migration.name
            if name == stop_before:
                return
            migration.fn(db)
            _table(db, self.migrations_table).insert(
                {
                    "migration_set": self.name,
                    "name": name,
                    "applied_at": str(datetime.datetime.now(datetime.timezone.utc)),
                }
            )

    def ensure_migrations_table(self, db: "Database"):
        """
        Ensure _sqlite_migrations table exists and has the correct schema
        """
        table = _table(db, self.migrations_table)
        if not table.exists():
            table.create(
                {
                    "migration_set": str,
                    "name": str,
                    "applied_at": str,
                },
                pk=("migration_set", "name"),
            )
        elif table.pks != ["migration_set", "name"]:
            # This has the old primary key scheme, upgrade it
            table.transform(pk=("migration_set", "name"))

    def __repr__(self):
        return "<Migrations '{}': [{}]>".format(
            self.name, ", ".join(m.name for m in self._migrations)
        )


def _table(db: "Database", name: str) -> "Table":
    # mypy workaround
    return cast("Table", db[name])
