import asyncio
import logging
import os
import shutil
import sys
from datetime import datetime
from pathlib import Path
from typing import Any

import click

if sys.version_info >= (3, 11):
    import tomllib
else:
    import tomli as tomllib

from beanie.migrations import template
from beanie.migrations.database import DBHandler
from beanie.migrations.models import RunningDirections, RunningMode
from beanie.migrations.runner import MigrationNode

# Log bare messages (no level/logger prefix) at INFO level and above.
logging.basicConfig(level=logging.INFO, format="%(message)s")


class MigrationSettings:
    """Settings for a single migration run.

    Every field except ``use_transaction`` is resolved from the first truthy
    value found, in this order:

    1. the keyword arguments passed to the constructor,
    2. ``BEANIE_*`` / ``beanie_*`` environment variables
       (see :meth:`get_env_value`),
    3. the ``[tool.beanie.migrations]`` table of ``./pyproject.toml``
       (see :meth:`get_from_toml`).

    Because resolution is truthiness-based, falsy values (``0``, ``False``,
    ``""``) fall through to the next source.
    """

    # String values (typically from environment variables) that mean "off".
    # Without this, any non-empty string such as "false" would be truthy and
    # silently enable the option.
    _FALSE_STRINGS = frozenset({"", "0", "false", "no", "off"})

    def __init__(self, **kwargs: Any):
        self.direction = (
            self._resolve("direction", kwargs) or RunningDirections.FORWARD
        )
        self.distance = int(self._resolve("distance", kwargs) or 0)
        # NOTE(review): when no source provides these, str(None) yields the
        # literal string "None" instead of failing early.
        self.connection_uri = str(self._resolve("connection_uri", kwargs))
        self.database_name = str(self._resolve("database_name", kwargs))
        # Path(None) raises TypeError if no source provides a path.
        self.path = Path(self._resolve("path", kwargs))
        self.allow_index_dropping = self._to_bool(
            self._resolve("allow_index_dropping", kwargs)
        )
        # Only taken from kwargs; never read from env vars or pyproject.toml.
        self.use_transaction = bool(kwargs.get("use_transaction"))

    def _resolve(self, field_name: str, kwargs: dict) -> Any:
        """Return the first truthy value for *field_name* (kwargs, env, TOML)."""
        return (
            kwargs.get(field_name)
            or self.get_env_value(field_name)
            or self.get_from_toml(field_name)
        )

    @classmethod
    def _to_bool(cls, value: Any) -> bool:
        """Coerce *value* to bool, reading strings like "false"/"0" as False."""
        if isinstance(value, str):
            return value.strip().lower() not in cls._FALSE_STRINGS
        return bool(value)

    @staticmethod
    def get_env_value(field_name: str) -> Any:
        """Look up *field_name* in the process environment.

        ``connection_uri`` and ``database_name`` accept several alias
        variable names; any other field maps to ``BEANIE_<FIELD>`` (checked
        first) or ``beanie_<field>``. Returns the first truthy value, or the
        last lookup's (falsy) result when none is set.
        """
        if field_name == "connection_uri":
            value = (
                os.environ.get("BEANIE_URI")
                or os.environ.get("BEANIE_CONNECTION_URI")
                or os.environ.get("BEANIE_CONNECTION_STRING")
                or os.environ.get("BEANIE_MONGODB_DSN")
                or os.environ.get("BEANIE_MONGODB_URI")
                or os.environ.get("beanie_uri")
                or os.environ.get("beanie_connection_uri")
                or os.environ.get("beanie_connection_string")
                or os.environ.get("beanie_mongodb_dsn")
                or os.environ.get("beanie_mongodb_uri")
            )
        elif field_name == "database_name":
            value = (
                os.environ.get("BEANIE_DB")
                or os.environ.get("BEANIE_DB_NAME")
                or os.environ.get("BEANIE_DATABASE_NAME")
                or os.environ.get("beanie_db")
                or os.environ.get("beanie_db_name")
                or os.environ.get("beanie_database_name")
            )
        else:
            value = os.environ.get(
                f"BEANIE_{field_name.upper()}"
            ) or os.environ.get(f"beanie_{field_name.lower()}")
        return value

    @staticmethod
    def get_from_toml(field_name: str) -> Any:
        """Read *field_name* from ``[tool.beanie.migrations]`` in ./pyproject.toml.

        Returns ``None`` when the file does not exist or the key is absent.
        The file is re-read on every call.
        """
        path = Path("pyproject.toml")
        if path.is_file():
            with path.open("rb") as f:
                toml_data = tomllib.load(f)
            val = (
                toml_data.get("tool", {})
                .get("beanie", {})
                .get("migrations", {})
            )
        else:
            val = {}
        return val.get(field_name)


@click.group()
def migrations():
    # Root command group; subcommands attach via @migrations.command().
    # Intentionally no docstring: click would expose it as --help text.
    ...


async def run_migrate(settings: MigrationSettings):
    """Connect to the database and run the migrations described by *settings*.

    The client registered through ``DBHandler.set_db`` is closed afterwards,
    whether the migrations succeed or raise.
    """
    DBHandler.set_db(settings.connection_uri, settings.database_name)
    try:
        root = await MigrationNode.build(settings.path)
        mode = RunningMode(
            direction=settings.direction, distance=settings.distance
        )
        await root.run(
            mode=mode,
            allow_index_dropping=settings.allow_index_dropping,
            use_transaction=settings.use_transaction,
        )
    finally:
        # Cleanup in `finally` so a failing build/run does not leak the
        # client. Assumes close() is awaitable, as the original code did.
        client = DBHandler.get_cli()
        if client:
            await client.close()


@migrations.command()
@click.option(
    "--forward",
    "direction",
    required=False,
    flag_value="FORWARD",
    help="Roll the migrations forward. This is default",
)
@click.option(
    "--backward",
    "direction",
    required=False,
    flag_value="BACKWARD",
    help="Roll the migrations backward",
)
@click.option(
    "-d",
    "--distance",
    required=False,
    help="How many migrations should be done since the current? "
    "0 - all the migrations. Default is 0",
)
@click.option(
    "-uri",
    "--connection-uri",
    required=False,
    type=str,
    help="MongoDB connection URI",
)
@click.option(
    "-db", "--database_name", required=False, type=str, help="DataBase name"
)
@click.option(
    "-p",
    "--path",
    required=False,
    type=str,
    help="Path to the migrations directory",
)
@click.option(
    "--allow-index-dropping/--forbid-index-dropping",
    required=False,
    default=False,
    help="if allow-index-dropping is set, Beanie will drop indexes from your collection",
)
@click.option(
    "--use-transaction/--no-use-transaction",
    required=False,
    default=True,
    help="Enable or disable the use of transactions during migration. "
    "When enabled (--use-transaction), Beanie uses transactions for migration, "
    "which necessitates a replica set. When disabled (--no-use-transaction), "
    "migrations occur without transactions.",
)
def migrate(
    direction,
    distance,
    connection_uri,
    database_name,
    path,
    allow_index_dropping,
    use_transaction,
):
    # No docstring on purpose: click would turn it into --help text.
    #
    # Forward only the CLI values that are truthy; anything left out is
    # resolved by MigrationSettings from env vars / pyproject.toml.
    # NOTE(review): --forbid-index-dropping produces False, which is filtered
    # out here, so it cannot override a truthy env/pyproject value.
    cli_values = {
        "direction": direction,
        "distance": distance,
        "connection_uri": connection_uri,
        "database_name": database_name,
        "path": path,
        "allow_index_dropping": allow_index_dropping,
    }
    settings_kwargs = {
        key: value for key, value in cli_values.items() if value
    }
    # The transaction flag always carries a value, so it is always passed.
    settings_kwargs["use_transaction"] = use_transaction

    asyncio.run(run_migrate(MigrationSettings(**settings_kwargs)))


@migrations.command()
@click.option("-n", "--name", required=True, type=str, help="Migration name")
@click.option(
    "-p",
    "--path",
    required=True,
    type=str,
    help="Path to the migrations directory",
)
def new_migration(name, path):
    # No docstring on purpose: click would turn it into --help text.
    #
    # Copy the bundled migration template into `path` as
    # "<YYYYmmddHHMMSS>_<name>.py", timestamped with naive local time.
    stamp = f"{datetime.now():%Y%m%d%H%M%S}"
    destination = Path(path).joinpath(f"{stamp}_{name}.py")
    shutil.copy(template.__file__, destination)


if __name__ == "__main__":
    # Allow running this module directly as a script; dispatches to the
    # click command group defined above.
    migrations()
