import asyncio
import logging
from typing import Any, Dict, Generator, List, Optional, Tuple

from knot_resolver.constants import PROMETHEUS_LIB
from knot_resolver.controller.interface import KresID
from knot_resolver.controller.registered_workers import get_registered_workers_kresids
from knot_resolver.datamodel.config_schema import KresConfig
from knot_resolver.manager.config_store import ConfigStore, only_on_real_changes_update
from knot_resolver.utils import compat
from knot_resolver.utils.functional import Result

from .collect import collect_kresd_workers_metrics

logger = logging.getLogger(__name__)

if PROMETHEUS_LIB:
    from prometheus_client import exposition  # type: ignore
    from prometheus_client.bridge.graphite import GraphiteBridge  # type: ignore
    from prometheus_client.core import (
        REGISTRY,
        CounterMetricFamily,
        GaugeMetricFamily,  # type: ignore
        HistogramMetricFamily,
        Metric,
    )

    _graphite_bridge: Optional[GraphiteBridge] = None

    _metrics_collector: Optional["KresPrometheusMetricsCollector"] = None

    def _counter(name: str, description: str, label: Tuple[str, str], value: float) -> CounterMetricFamily:
        c = CounterMetricFamily(name, description, labels=(label[0],))
        c.add_metric((label[1],), value)  # type: ignore
        return c

    def _gauge(name: str, description: str, label: Tuple[str, str], value: float) -> GaugeMetricFamily:
        g = GaugeMetricFamily(name, description, labels=(label[0],))
        g.add_metric((label[1],), value)  # type: ignore
        return g

    def _histogram(
        name: str, description: str, label: Tuple[str, str], buckets: List[Tuple[str, int]], sum_value: float
    ) -> HistogramMetricFamily:
        h = HistogramMetricFamily(name, description, labels=(label[0],))
        h.add_metric((label[1],), buckets, sum_value=sum_value)  # type: ignore
        return h
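
    # for illustration (hypothetical values, not taken from a real run): the gauge helper call
    #   _gauge("resolver_metrics_loaded", "description", label=("instance_id", "worker1"), value=1)
    # ends up in the exposition output roughly as
    #   resolver_metrics_loaded{instance_id="worker1"} 1.0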

    def _parse_resolver_metrics(instance_id: "KresID", metrics: Any) -> Generator[Metric, None, None]:
        sid = str(instance_id)

        # response latency histogram
        bucket_names_in_resolver = ("1ms", "10ms", "50ms", "100ms", "250ms", "500ms", "1000ms", "1500ms", "slow")
        bucket_names_in_prometheus = ("0.001", "0.01", "0.05", "0.1", "0.25", "0.5", "1.0", "1.5", "+Inf")

        # Prometheus histogram buckets are cumulative, so sum the resolver's
        # per-bucket counts up to and including the requested bucket
        def _bucket_count(answer: Dict[str, int], duration: str) -> int:
            index = bucket_names_in_resolver.index(duration)
            return sum(int(answer[bucket_names_in_resolver[i]]) for i in range(index + 1))
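
        # e.g. (hypothetical counts): with answer = {"1ms": 5, "10ms": 3, "50ms": 2, ...},
        # _bucket_count(answer, "50ms") == 5 + 3 + 2 == 10, i.e. the cumulative number of
        # answers that took at most 50 ms, which is the form Prometheus buckets expect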

        yield _histogram(
            "resolver_response_latency",
            "Time it takes to respond to queries in seconds",
            label=("instance_id", sid),
            buckets=[
                (bnp, _bucket_count(metrics["answer"], duration))
                for bnp, duration in zip(bucket_names_in_prometheus, bucket_names_in_resolver)
            ],
            sum_value=metrics["answer"]["sum_ms"] / 1_000,
        )
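
        # in the exposition output this histogram appears roughly as follows (hypothetical values;
        # exact formatting depends on the prometheus_client version):
        #   resolver_response_latency_bucket{instance_id="worker1",le="0.001"} 5.0
        #   resolver_response_latency_bucket{instance_id="worker1",le="0.01"} 8.0
        #   ...
        #   resolver_response_latency_bucket{instance_id="worker1",le="+Inf"} 120.0
        #   resolver_response_latency_sum{instance_id="worker1"} 3.5
        #   resolver_response_latency_count{instance_id="worker1"} 120.0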

        # "request" metrics
        yield _counter(
            "resolver_request_total",
            "total number of DNS requests (including internal client requests)",
            label=("instance_id", sid),
            value=metrics["request"]["total"],
        )
        yield _counter(
            "resolver_request_total4",
            "total number of IPv4 DNS requests",
            label=("instance_id", sid),
            value=metrics["request"]["total4"],
        )
        yield _counter(
            "resolver_request_total6",
            "total number of IPv6 DNS requests",
            label=("instance_id", sid),
            value=metrics["request"]["total6"],
        )
        yield _counter(
            "resolver_request_internal",
            "number of internal requests generated by Knot Resolver (e.g. DNSSEC trust anchor updates)",
            label=("instance_id", sid),
            value=metrics["request"]["internal"],
        )
        yield _counter(
            "resolver_request_udp",
            "number of external requests received over plain UDP (RFC 1035)",
            label=("instance_id", sid),
            value=metrics["request"]["udp"],
        )
        yield _counter(
            "resolver_request_udp4",
            "number of external requests received over IPv4 plain UDP (RFC 1035)",
            label=("instance_id", sid),
            value=metrics["request"]["udp4"],
        )
        yield _counter(
            "resolver_request_udp6",
            "number of external requests received over IPv6 plain UDP (RFC 1035)",
            label=("instance_id", sid),
            value=metrics["request"]["udp6"],
        )
        yield _counter(
            "resolver_request_tcp",
            "number of external requests received over plain TCP (RFC 1035)",
            label=("instance_id", sid),
            value=metrics["request"]["tcp"],
        )
        yield _counter(
            "resolver_request_tcp4",
            "number of external requests received over IPv4 plain TCP (RFC 1035)",
            label=("instance_id", sid),
            value=metrics["request"]["tcp4"],
        )
        yield _counter(
            "resolver_request_tcp6",
            "number of external requests received over IPv6 plain TCP (RFC 1035)",
            label=("instance_id", sid),
            value=metrics["request"]["tcp6"],
        )
        yield _counter(
            "resolver_request_dot",
            "number of external requests received over DNS-over-TLS (RFC 7858)",
            label=("instance_id", sid),
            value=metrics["request"]["dot"],
        )
        yield _counter(
            "resolver_request_dot4",
            "number of external requests received over IPv4 DNS-over-TLS (RFC 7858)",
            label=("instance_id", sid),
            value=metrics["request"]["dot4"],
        )
        yield _counter(
            "resolver_request_dot6",
            "number of external requests received over IPv6 DNS-over-TLS (RFC 7858)",
            label=("instance_id", sid),
            value=metrics["request"]["dot6"],
        )
        yield _counter(
            "resolver_request_doh",
            "number of external requests received over DNS-over-HTTP (RFC 8484)",
            label=("instance_id", sid),
            value=metrics["request"]["doh"],
        )
        yield _counter(
            "resolver_request_doh4",
            "number of external requests received over IPv4 DNS-over-HTTP (RFC 8484)",
            label=("instance_id", sid),
            value=metrics["request"]["doh4"],
        )
        yield _counter(
            "resolver_request_doh6",
            "number of external requests received over IPv6 DNS-over-HTTP (RFC 8484)",
            label=("instance_id", sid),
            value=metrics["request"]["doh6"],
        )
        yield _counter(
            "resolver_request_xdp",
            "number of external requests received over plain UDP via an AF_XDP socket",
            label=("instance_id", sid),
            value=metrics["request"]["xdp"],
        )
        yield _counter(
            "resolver_request_xdp4",
            "number of external requests received over IPv4 plain UDP via an AF_XDP socket",
            label=("instance_id", sid),
            value=metrics["request"]["xdp4"],
        )
        yield _counter(
            "resolver_request_xdp6",
            "number of external requests received over IPv6 plain UDP via an AF_XDP socket",
            label=("instance_id", sid),
            value=metrics["request"]["xdp6"],
        )

        # "answer" metrics
        yield _counter(
            "resolver_answer_total",
            "total number of answered queries",
            label=("instance_id", sid),
            value=metrics["answer"]["total"],
        )
        yield _counter(
            "resolver_answer_cached",
            "number of queries answered from cache",
            label=("instance_id", sid),
            value=metrics["answer"]["cached"],
        )
        yield _counter(
            "resolver_answer_stale",
            "number of queries that utilized stale data",
            label=("instance_id", sid),
            value=metrics["answer"]["stale"],
        )
        yield _counter(
            "resolver_answer_rcode_noerror",
            "number of NOERROR answers",
            label=("instance_id", sid),
            value=metrics["answer"]["noerror"],
        )
        yield _counter(
            "resolver_answer_rcode_nodata",
            "number of NOERROR answers without any data",
            label=("instance_id", sid),
            value=metrics["answer"]["nodata"],
        )
        yield _counter(
            "resolver_answer_rcode_nxdomain",
            "number of NXDOMAIN answers",
            label=("instance_id", sid),
            value=metrics["answer"]["nxdomain"],
        )
        yield _counter(
            "resolver_answer_rcode_servfail",
            "number of SERVFAIL answers",
            label=("instance_id", sid),
            value=metrics["answer"]["servfail"],
        )
        yield _counter(
            "resolver_answer_flag_aa",
            "number of authoritative answers",
            label=("instance_id", sid),
            value=metrics["answer"]["aa"],
        )
        yield _counter(
            "resolver_answer_flag_tc",
            "number of truncated answers",
            label=("instance_id", sid),
            value=metrics["answer"]["tc"],
        )
        yield _counter(
            "resolver_answer_flag_ra",
            "number of answers with recursion available flag",
            label=("instance_id", sid),
            value=metrics["answer"]["ra"],
        )
        yield _counter(
            "resolver_answer_flag_rd",
            "number of recursion desired (in answer!)",
            label=("instance_id", sid),
            value=metrics["answer"]["rd"],
        )
        yield _counter(
            "resolver_answer_flag_ad",
            "number of authentic data (DNSSEC) answers",
            label=("instance_id", sid),
            value=metrics["answer"]["ad"],
        )
        yield _counter(
            "resolver_answer_flag_cd",
            "number of checking disabled (DNSSEC) answers",
            label=("instance_id", sid),
            value=metrics["answer"]["cd"],
        )
        yield _counter(
            "resolver_answer_flag_do",
            "number of DNSSEC answer OK",
            label=("instance_id", sid),
            value=metrics["answer"]["do"],
        )
        yield _counter(
            "resolver_answer_flag_edns0",
            "number of answers with EDNS0 present",
            label=("instance_id", sid),
            value=metrics["answer"]["edns0"],
        )

        # "query" metrics
        yield _counter(
            "resolver_query_edns",
            "number of queries with EDNS present",
            label=("instance_id", sid),
            value=metrics["query"]["edns"],
        )
        yield _counter(
            "resolver_query_dnssec",
            "number of queries with DNSSEC DO=1",
            label=("instance_id", sid),
            value=metrics["query"]["dnssec"],
        )

        # "predict" metrics (optional)
        if "predict" in metrics:
            if "epoch" in metrics["predict"]:
                yield _counter(
                    "resolver_predict_epoch",
                    "current prediction epoch (based on time of day and sampling window)",
                    label=("instance_id", sid),
                    value=metrics["predict"]["epoch"],
                )
            yield _counter(
                "resolver_predict_queue",
                "number of queued queries in current window",
                label=("instance_id", sid),
                value=metrics["predict"]["queue"],
            )
            yield _counter(
                "resolver_predict_learned",
                "number of learned queries in current window",
                label=("instance_id", sid),
                value=metrics["predict"]["learned"],
            )

    def _create_resolver_metrics_loaded_gauge(kresid: "KresID", loaded: bool) -> GaugeMetricFamily:
        return _gauge(
            "resolver_metrics_loaded",
            "0 if metrics from resolver instance were not loaded, otherwise 1",
            label=("instance_id", str(kresid)),
            value=int(loaded),
        )

    class KresPrometheusMetricsCollector:
        def __init__(self, config_store: ConfigStore) -> None:
            self._stats_raw: "Optional[Dict[KresID, object]]" = None  # most recently collected raw stats per worker
            self._config_store: ConfigStore = config_store
            self._collection_task: "Optional[asyncio.Task[None]]" = None  # collection task scheduled from collect()
            self._skip_immediate_collection: bool = False  # set after a manual collection to skip the next cycle

        def collect(self) -> Generator[Metric, None, None]:
            # schedule new stats collection
            self._trigger_stats_collection()

            # if we have no data yet, report every registered worker's metrics as not loaded and exit
            if self._stats_raw is None:
                for kresid in get_registered_workers_kresids():
                    yield _create_resolver_metrics_loaded_gauge(kresid, False)
                return

            # if we have data, parse it per registered worker
            for kresid in get_registered_workers_kresids():
                success = False
                try:
                    if kresid in self._stats_raw:
                        metrics = self._stats_raw[kresid]
                        yield from _parse_resolver_metrics(kresid, metrics)
                        success = True
                except KeyError as e:
                    logger.warning(
                        "Failed to load metrics from resolver instance %s: attempted to read missing statistic %s",
                        str(kresid),
                        str(e),
                    )

                yield _create_resolver_metrics_loaded_gauge(kresid, success)

        def describe(self) -> List[Metric]:
            # returning an empty description here prevents the collector registry from invoking collect() during registration
            return []

        async def collect_kresd_stats(self, _triggered_from_prometheus_library: bool = False) -> None:
            if self._skip_immediate_collection:
                # this happens when this function has already been called manually just before stat generation
                # and the prometheus library then triggers it again immediately afterwards from its own collection
                #
                # this workaround exists because async functions cannot simply be awaited from sync methods
                self._skip_immediate_collection = False
                return

            config = self._config_store.get()
            self._stats_raw = await collect_kresd_workers_metrics(config)

            # if this function was not called by the prometheus library and calling collect() is imminent,
            # we should block the next collection cycle as it would be useless
            if not _triggered_from_prometheus_library:
                self._skip_immediate_collection = True

        def _trigger_stats_collection(self) -> None:
            # we are running inside an event loop, but in a synchronous function and that sucks a lot
            # it means that we shouldn't block the event loop by performing a blocking stats collection
            # but it also means that we can't yield to the event loop as this function is synchronous
            # therefore we can only start a new task, but we can't wait for it
            # which causes the metrics to be delayed by one collection pass (not the best, but probably good enough)
            #
            # this issue can be prevented by calling the `collect_kresd_stats()` function manually before entering
            # the Prometheus library. We just have to prevent the library from invoking it again. See the mentioned
            # function for details

            if compat.asyncio.is_event_loop_running():
                # when running, we can schedule the new data collection
                if self._collection_task is not None and not self._collection_task.done():
                    logger.warning("Statistics collection task is still running. Skipping scheduling of a new one!")
                else:
                    self._collection_task = compat.asyncio.create_task(
                        self.collect_kresd_stats(_triggered_from_prometheus_library=True)
                    )

            else:
                # when not running, we can start a new loop (we are not in the manager's main thread)
                compat.asyncio.run(self.collect_kresd_stats(_triggered_from_prometheus_library=True))

    @only_on_real_changes_update(lambda c: c.monitoring.graphite)
    async def _init_graphite_bridge(config: KresConfig, force: bool = False) -> None:
        """
        Starts graphite bridge if required
        """
        global _graphite_bridge
        if config.monitoring.graphite.enable and _graphite_bridge is None:
            logger.info(
                "Starting Graphite metrics exporter for [%s]:%d",
                str(config.monitoring.graphite.host),
                int(config.monitoring.graphite.port),
            )
            _graphite_bridge = GraphiteBridge(
                (str(config.monitoring.graphite.host), int(config.monitoring.graphite.port))
            )
            _graphite_bridge.start(  # type: ignore
                interval=config.monitoring.graphite.interval.seconds(), prefix=str(config.monitoring.graphite.prefix)
            )
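
    # an illustrative configuration snippet that would enable the bridge (field names mirror the
    # attributes accessed above; the exact YAML layout is an assumption, not taken from the schema):
    #
    #   monitoring:
    #     graphite:
    #       enable: true
    #       host: 203.0.113.10
    #       port: 2003
    #       prefix: knot-resolver
    #       interval: 5s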

    async def _deny_turning_off_graphite_bridge(
        old_config: KresConfig, new_config: KresConfig, force: bool = False
    ) -> Result[None, str]:
        if old_config.monitoring.graphite.enable and not new_config.monitoring.graphite.enable:
            return Result.err(
                "You can't turn off graphite monitoring dynamically. If you really want this feature, please let the developers know."
            )

        if (
            old_config.monitoring.graphite.enable
            and new_config.monitoring.graphite.enable
            and old_config.monitoring.graphite != new_config.monitoring.graphite
        ):
            return Result.err("Changing graphite exporter configuration in runtime is not allowed.")

        return Result.ok(None)


async def init_prometheus(config_store: ConfigStore) -> None:
    """
    Initialize metrics collection. Must be called before any other function from this module.
    """
    if PROMETHEUS_LIB:
        # init and register metrics collector
        global _metrics_collector
        _metrics_collector = KresPrometheusMetricsCollector(config_store)
        REGISTRY.register(_metrics_collector)  # type: ignore

        # register the graphite bridge config verifier and initialization callback
        await config_store.register_verifier(_deny_turning_off_graphite_bridge)
        await config_store.register_on_change_callback(_init_graphite_bridge)


async def report_prometheus() -> Optional[bytes]:
    if PROMETHEUS_LIB:
        # manually trigger stats collection so that we do not have to wait for the next collection pass
        if _metrics_collector is not None:
            await _metrics_collector.collect_kresd_stats()
        else:
            raise RuntimeError("Function invoked before initializing the module!")
        return exposition.generate_latest()  # type: ignore
    return None
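

# a minimal usage sketch (hypothetical caller code, not part of this module):
#
#   store = ConfigStore(config)
#   await init_prometheus(store)
#   ...
#   payload = await report_prometheus()  # Prometheus text-format bytes, or None without prometheus_client
#   if payload is not None:
#       send_http_response(payload)  # hypothetical helper; served as "text/plain; version=0.0.4"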
