#!/usr/bin/env python
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2022-2023 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#

import argparse
import logging
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Iterable, Pattern

import yaml

from synapse.config.homeserver import HomeServerConfig
from synapse.federation.transport.server import (
    TransportLayerServer,
    register_servlets as register_federation_servlets,
)
from synapse.http.server import HttpServer, ServletCallback
from synapse.rest import ClientRestResource
from synapse.rest.key.v2 import RemoteKey
from synapse.server import HomeServer
from synapse.storage import DataStore

# Module-level logger for this script. The name is a fixed string rather than
# `__name__`. NOTE(review): nothing in this file appears to log through it yet.
logger = logging.getLogger("generate_workers_map")


class MockHomeserver(HomeServer):
    """
    A `HomeServer` that is only instantiated so that its servlets can be enumerated.

    Args:
        config: the homeserver configuration to build the server from. The server
            name is taken from `config.server.server_name`.
        worker_app: the worker application name to pretend to be, or `None` to act
            as the main process.
    """

    DATASTORE_CLASS = DataStore

    def __init__(self, config: HomeServerConfig, worker_app: str | None) -> None:
        server_name = config.server.server_name
        super().__init__(server_name, config=config)

        # Override whatever worker app the config declared.
        # NOTE(review): if `self.config` is the same object as the `config` passed
        # in, this also changes the caller's config object -- confirm.
        self.config.worker.worker_app = worker_app


# Matches a Python named capturing group, `(?P<name>...)`. Capture group 1 is the
# inner pattern of the named group. The inner match is non-greedy and ends at the
# first `)`, so a named group whose body itself contains a `)` is only matched up
# to that first `)`.
GROUP_PATTERN = re.compile(r"\(\?P<[^>]+?>(.+?)\)")


@dataclass
class EndpointDescription:
    """
    Describes an endpoint and how it should be routed.

    One description is created per servlet registration and is stored as the value
    for every (method, path pattern) pair of that registration.
    """

    # The servlet class that handles this endpoint
    servlet_class: object

    # The category of this endpoint. Is read from the `CATEGORY` constant in the servlet
    # class; `None` when the servlet class does not define one.
    category: str | None

    # TODO:
    #  - does it need to be routed based on a stream writer config?
    #  - does it benefit from any optimised, but optional, routing?
    #  - what 'opinionated synapse worker class' (event_creator, synchrotron, etc) does
    #    it go in?


class EnumerationResource(HttpServer):
    """
    A stand-in HTTP server which, instead of serving anything, records every servlet
    registration made against it so that the full set of endpoints can be described.

    Attributes:
        registrations: maps (HTTP method, path regex source) to the description of
            the endpoint registered for that pair.
    """

    def __init__(self, is_worker: bool) -> None:
        # Whether registrations should be filtered as if running on a worker.
        self._is_worker = is_worker
        self.registrations: dict[tuple[str, str], EndpointDescription] = {}

    def register_paths(
        self,
        method: str,
        path_patterns: Iterable[Pattern],
        callback: ServletCallback,
        servlet_classname: str,
    ) -> None:
        """
        Record a registration of `callback` for `method` on each of `path_patterns`.

        Args:
            method: the HTTP method being registered.
            path_patterns: compiled regexes of the paths being registered.
            callback: bound method of the servlet handling the request (possibly
                wrapped, with the original reachable via `__wrapped__`).
            servlet_classname: unused here; accepted for interface compatibility.
        """
        # Federation servlet callbacks are wrapped; look through the wrapper.
        unwrapped = getattr(callback, "__wrapped__", callback)

        # The callback is a bound method, so its owner's class is the servlet class.
        servlet_class = unwrapped.__self__.__class__  # type: ignore

        if self._is_worker:
            denied_methods = getattr(servlet_class, "WORKERS_DENIED_METHODS", ())
            if method in denied_methods:
                # Calling this endpoint on a worker would cause an error, so act as
                # though it had never been registered.
                return

        description = EndpointDescription(
            servlet_class=servlet_class,
            category=getattr(servlet_class, "CATEGORY", None),
        )

        # All patterns of one registration share the same description object.
        self.registrations.update(
            ((method, pattern.pattern), description) for pattern in path_patterns
        )


def get_registered_paths_for_hs(
    hs: HomeServer,
) -> dict[tuple[str, str], EndpointDescription]:
    """
    Collect every endpoint the given homeserver registers.

    Client REST servlets, federation servlets and the remote key resource are each
    registered against an `EnumerationResource`, in that order.

    Returns:
        Dict from (method, path pattern) to EndpointDescription.
    """

    # The homeserver is treated as a worker whenever a worker app is configured.
    running_as_worker = hs.config.worker.worker_app is not None
    enumerator = EnumerationResource(is_worker=running_as_worker)

    ClientRestResource.register_servlets(enumerator, hs)

    # `federation_server.register_servlets` can't be used here; calling the
    # module-level function with the federation server's own attributes does the
    # same thing, only registering against our enumerator.
    federation_server = TransportLayerServer(hs)
    register_federation_servlets(
        federation_server.hs,
        resource=enumerator,
        ratelimiter=federation_server.ratelimiter,
        authenticator=federation_server.authenticator,
        servlet_groups=federation_server.servlet_groups,
    )

    # The key server endpoints are registered separately again.
    key_resource = RemoteKey(hs)
    key_resource.register(enumerator)

    return enumerator.registrations


def get_registered_paths_for_default(
    worker_app: str | None, base_config: HomeServerConfig
) -> dict[tuple[str, str], EndpointDescription]:
    """
    Build a throwaway homeserver for the given worker application (or the main
    process when `worker_app` is `None`) from a base configuration, and return
    what it registers:

        Dict from (method, path) to EndpointDescription

    TODO Don't require passing in a config
    """

    mock_hs = MockHomeserver(base_config, worker_app)

    # TODO We only do this to avoid an error, but don't need the database etc
    mock_hs.setup()

    # NOTE: the homeserver is deliberately not shut down afterwards. A more robust
    # implementation would shutdown/cleanup each server to avoid resource buildup,
    # but `shutdown` is `async` and calling it would add complexity here. This is a
    # short-lived, one-off utility script, so the simpler approach is sufficient and
    # resource buildup should not become a problem.
    return get_registered_paths_for_hs(mock_hs)


def elide_http_methods_if_unconflicting(
    registrations: dict[tuple[str, str], EndpointDescription],
    all_possible_registrations: dict[tuple[str, str], EndpointDescription],
) -> dict[tuple[str, str], EndpointDescription]:
    """
    Elides HTTP methods (by replacing them with `*`) if all possible registered methods
    can be handled by the worker whose registration map is `registrations`.

    i.e. the only endpoints left with methods (other than `*`) should be the ones where
    the worker can't handle all possible methods for that path.

    Args:
        registrations: (method, path) -> description for the worker in question.
        all_possible_registrations: (method, path) -> description across all
            processes. Must contain every path that appears in `registrations`.

    Returns:
        A new dict keyed by (`*` or method, path). Neither input is modified.

    Raises:
        KeyError: if a path in `registrations` does not appear in
            `all_possible_registrations`.
    """

    def paths_to_methods_dict(
        methods_and_paths: Iterable[tuple[str, str]],
    ) -> dict[str, set[str]]:
        """
        Given (method, path) pairs, produces a dict from path to set of methods
        available at that path.
        """
        result: dict[str, set[str]] = {}
        for method, path in methods_and_paths:
            result.setdefault(path, set()).add(method)
        return result

    all_possible_reg_methods = paths_to_methods_dict(all_possible_registrations)
    reg_methods = paths_to_methods_dict(registrations)

    output: dict[tuple[str, str], EndpointDescription] = {}

    for path, handleable_methods in reg_methods.items():
        if handleable_methods == all_possible_reg_methods[path]:
            # Pick the representative method deterministically. Iteration order of a
            # set of strings varies between interpreter runs, so taking "the first"
            # element could select a different servlet (and category) on each run.
            # TODO This assumes that all methods have the same servlet.
            #      I suppose that's possibly dubious?
            representative_method = min(handleable_methods)
            output[("*", path)] = registrations[(representative_method, path)]
        else:
            # Sorted so that the output's insertion order is reproducible too.
            for method in sorted(handleable_methods):
                output[(method, path)] = registrations[(method, path)]

    return output


def simplify_path_regexes(
    registrations: dict[tuple[str, str], EndpointDescription],
) -> dict[tuple[str, str], EndpointDescription]:
    """
    Return a copy of the endpoint descriptions dict in which every path regex has
    been simplified, so that Python-specific regex extensions are not used (and
    needlessly specific detail is dropped).

    Every named capturing group (e.g. `(?P<blah>xyz)`) in a path is replaced with
    `.*`, which is available in more common regex dialects. If two paths simplify
    to the same string for the same method, the later entry wins.
    """

    simplified: dict[tuple[str, str], EndpointDescription] = {}

    for (method, path), description in registrations.items():
        # TODO it's hard to choose between two substitutions here: `.*` (used
        #      below) is a vague simplification, whereas substituting the group's
        #      own inner pattern (`\1`) would keep more detail.
        simple_path = GROUP_PATTERN.sub(r".*", path)
        simplified[(method, simple_path)] = description

    return simplified


def main() -> None:
    """
    Script entry point.

    Loads the Synapse configuration given by `--config-path`, enumerates the
    endpoints registered by the main process and by a generic worker, elides HTTP
    methods wherever the worker can handle every method for a path, and prints the
    worker's endpoints grouped by category.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Lists the HTTP endpoints that a Synapse generic worker can handle,"
            " grouped by category, in documentation page style."
        )
    )
    # NOTE: `-v` is accepted but not acted upon yet (see the logging TODO below).
    parser.add_argument("-v", action="store_true")
    parser.add_argument(
        "--config-path",
        type=argparse.FileType("r"),
        required=True,
        help="Synapse configuration file",
    )

    args = parser.parse_args()

    # TODO
    # logging.basicConfig(**logging_config)

    # Load, process and sanity-check the config. argparse has already opened the
    # file for us; make sure it is closed again once it has been read.
    with args.config_path as config_file:
        hs_config = yaml.safe_load(config_file)

    config = HomeServerConfig()
    config.parse_config_dict(hs_config, "", "")

    # The main process must be enumerated before the worker: both use the same
    # config object, and building the worker homeserver sets the worker app on it.
    master_paths = get_registered_paths_for_default(None, config)
    worker_paths = get_registered_paths_for_default(
        "synapse.app.generic_worker", config
    )

    # Every (method, path) that any process can serve.
    all_paths = {**master_paths, **worker_paths}

    # Only the worker's view is printed, so only the worker's map is elided.
    elided_worker_paths = elide_http_methods_if_unconflicting(worker_paths, all_paths)

    # TODO SSO endpoints (pick_idp etc) NOT REGISTERED BY THIS SCRIPT

    categories_to_methods_and_paths: dict[
        str | None, dict[tuple[str, str], EndpointDescription]
    ] = defaultdict(dict)

    for (method, path), desc in elided_worker_paths.items():
        categories_to_methods_and_paths[desc.category][method, path] = desc

    for category, contents in categories_to_methods_and_paths.items():
        print_category(category, contents)


def print_category(
    category_name: str | None,
    elided_worker_paths: dict[tuple[str, str], EndpointDescription],
) -> None:
    """
    Prints out a category, in documentation page style.

    The heading is printed first, then the sorted paths that accept any method
    (`*`), a blank line, the sorted method-specific entries, and a final blank line.

    Example:
    ```
    # Category name
    /path/xyz

    GET /path/abc
    ```

    Args:
        category_name: heading for the category; a placeholder heading is printed
            when this is `None` or empty.
        elided_worker_paths: the (method, path) -> description entries belonging to
            this category. Paths are simplified before printing.
    """

    if category_name:
        print(f"# {category_name}")
    else:
        print("# (Uncategorised requests)")

    # Simplify once and reuse the result for both listings below.
    simplified_paths = simplify_path_regexes(elided_worker_paths)

    for ln in sorted(p for m, p in simplified_paths if m == "*"):
        print(ln)
    print()
    # Methods are left-aligned and padded to six characters so the paths line up.
    for ln in sorted(f"{m:6} {p}" for m, p in simplified_paths if m != "*"):
        print(ln)
    print()


# Only run the generator when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
