# Copyright 2020 by Federico Caselli
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Inspect utilities for falcon applications."""

from __future__ import annotations

from functools import partial
import inspect
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

from falcon import app_helpers
from falcon.app import App
from falcon.routing import CompiledRouter
from falcon.routing.compiled import CompiledRouterNode


def inspect_app(app: App) -> AppInfo:
    """Inspects an application.

    Gathers the routes, static routes, sinks, error handlers and middleware
    of the given application into a single :class:`.AppInfo` object.

    Args:
        app (falcon.App): The application to inspect. Works with both
            :class:`falcon.App` and :class:`falcon.asgi.App`.

    Returns:
        AppInfo: The information regarding the application. Call
        :meth:`~.AppInfo.to_string` on the result to obtain a human-friendly
        representation.
    """
    # Each section is collected in the same order as the original
    # implementation before being bundled together.
    route_infos = inspect_routes(app)
    static_route_infos = inspect_static_routes(app)
    sink_infos = inspect_sinks(app)
    handler_infos = inspect_error_handlers(app)
    middleware_info = inspect_middleware(app)
    return AppInfo(
        routes=route_infos,
        middleware=middleware_info,
        static_routes=static_route_infos,
        sinks=sink_infos,
        error_handlers=handler_infos,
        asgi=app._ASGI,
    )


def inspect_routes(app: App) -> List[RouteInfo]:
    """Inspects the routes of an application.

    The application's router is described by the inspection function
    registered for its exact class (see :func:`.register_router`).

    Args:
        app (falcon.App): The application to inspect. Works with both
            :class:`falcon.App` and :class:`falcon.asgi.App`.

    Returns:
        List[RouteInfo]: A list of route descriptions for the application.

    Raises:
        TypeError: No inspection function is registered for the class of
            the application's router.
    """
    router = app._router
    router_class = type(router)

    inspector = _supported_routers.get(router_class)
    if inspector is not None:
        return inspector(router)

    raise TypeError(
        'Unsupported router class {}. Use "register_router" '
        'to register a function that can inspect the router '
        'used by the provided application'.format(router_class)
    )


def register_router(
    router_class: Type,
) -> Callable[..., Callable[..., List[RouteInfo]]]:
    """Register a function to inspect a particular router.

    This decorator registers a new function for a custom router
    class, so that it can be inspected with the function
    :func:`.inspect_routes`.
    An inspection function takes the router instance used by the
    application and returns a list of :class:`.RouteInfo`. Eg::

        @register_router(MyRouterClass)
        def inspect_my_router(router):
            return [RouteInfo('foo', 'bar', '/path/to/foo.py:42', [])]

    Args:
        router_class (Type): The router class to register. If
            already registered an error will be raised.

    Raises:
        ValueError: Raised by the returned decorator when it is applied and
            a function is already registered for ``router_class``.
    """

    def _register(
        inspector: Callable[..., List[RouteInfo]],
    ) -> Callable[..., List[RouteInfo]]:
        if router_class not in _supported_routers:
            _supported_routers[router_class] = inspector
            # The decorated function is returned untouched, so it can still
            # be called directly.
            return inspector
        raise ValueError(
            'Another function is already registered for the router {}'.format(
                router_class
            )
        )

    return _register


# Registry mapping a router class to the function that inspects it; populated
# through the ``register_router`` decorator and consulted by ``inspect_routes``.
_supported_routers: Dict[Type, Callable[..., Any]] = dict()


def inspect_static_routes(app: App) -> 'List[StaticRouteInfo]':
    """Inspects the static routes of an application.

    Args:
        app (falcon.App): The application to inspect. Works with both
            :class:`falcon.App` and :class:`falcon.asgi.App`.

    Returns:
        List[StaticRouteInfo]: A list of static routes that have
        been added to the application, in registration order.
    """
    # Each entry of ``app._static_routes`` is a 3-tuple whose first element
    # is the static route object; the remaining two are not needed here.
    return [
        StaticRouteInfo(
            static_route._prefix,
            static_route._directory,
            static_route._fallback_filename,
        )
        for static_route, _, _ in app._static_routes
    ]


def inspect_sinks(app: App) -> 'List[SinkInfo]':
    """Inspects the sinks of an application.

    Args:
        app (falcon.App): The application to inspect. Works with both
            :class:`falcon.App` and :class:`falcon.asgi.App`.

    Returns:
        List[SinkInfo]: A list of sinks used by the application.
    """
    sink_infos = []
    for compiled_prefix, sink_target, _ in app._sinks:
        source_info, sink_name = _get_source_info_and_name(sink_target)
        assert source_info
        # The prefix is stored compiled; report its original pattern text.
        sink_infos.append(SinkInfo(compiled_prefix.pattern, sink_name, source_info))
    return sink_infos


def inspect_error_handlers(app: App) -> 'List[ErrorHandlerInfo]':
    """Inspects the error handlers of an application.

    Args:
        app (falcon.App): The application to inspect. Works with both
            :class:`falcon.App` and :class:`falcon.asgi.App`.

    Returns:
        List[ErrorHandlerInfo]: A list of error handlers used by the
        application, one per registered exception type.
    """
    handler_infos = []
    for exc_type, handler in app._error_handlers.items():
        source_info, handler_name = _get_source_info_and_name(handler)
        assert source_info
        handler_infos.append(
            ErrorHandlerInfo(
                exc_type.__name__,
                handler_name,
                source_info,
                _is_internal(handler),
            )
        )
    return handler_infos


def inspect_middleware(app: App) -> 'MiddlewareInfo':
    """Inspects the middleware components of an application.

    Args:
        app (falcon.App): The application to inspect. Works with both
            :class:`falcon.App` and :class:`falcon.asgi.App`.

    Returns:
        MiddlewareInfo: Information about the app's middleware components.
    """
    # NOTE: Middleware is prepared here as *independent* (second argument
    # ``True``) regardless of the app setting; the resulting stacks are
    # iterated method by method below.
    stacks = app_helpers.prepare_middleware(
        app._unprepared_middleware, True, app._ASGI
    )

    tree_items = []
    for stack in stacks:
        current = []
        for method in stack:
            _, name = _get_source_info_and_name(method)
            cls = type(method.__self__)  # type: ignore[union-attr]
            _, cls_name = _get_source_info_and_name(cls)
            current.append(MiddlewareTreeItemInfo(name, cls_name))
        tree_items.append(current)
    middleware_tree = MiddlewareTreeInfo(*tree_items)

    middleware_classes = []
    for component in app._unprepared_middleware:
        # Prepare each component on its own to find which of its methods
        # are used; only the first entry of each non-empty stack is read.
        component_stacks = app_helpers.prepare_middleware(
            [component], True, app._ASGI
        )
        class_source_info, cls_name = _get_source_info_and_name(type(component))
        methods = []
        for component_stack in component_stacks:
            if component_stack:
                real_func = component_stack[0]
                source_info = _get_source_info(real_func)
                assert source_info
                methods.append(MiddlewareMethodInfo(real_func.__name__, source_info))  # type: ignore[union-attr]
        assert class_source_info
        middleware_classes.append(
            MiddlewareClassInfo(cls_name, class_source_info, methods)
        )

    return MiddlewareInfo(
        middleware_tree, middleware_classes, app._independent_middleware
    )


@register_router(CompiledRouter)
def inspect_compiled_router(router: CompiledRouter) -> 'List[RouteInfo]':
    """Walk an instance of :class:`~.CompiledRouter` to return a list of defined routes.

    Default route inspector for CompiledRouter. Nodes are visited depth-first:
    a node that carries a resource is reported before any node below it.

    Args:
        router (CompiledRouter): The router to inspect.

    Returns:
        List[RouteInfo]: A list of :class:`~.RouteInfo`.
    """

    def _method_infos(node: CompiledRouterNode) -> List[RouteMethodInfo]:
        # Describe every responder registered in the node's method map.
        infos: List[RouteMethodInfo] = []
        if not node.method_map:
            return infos
        for http_method, responder in node.method_map.items():
            # Responders wrapped in functools.partial are reported through
            # the wrapped function instead of the partial object.
            target = responder.func if isinstance(responder, partial) else responder
            source_info = _get_source_info(target)
            is_internal = _is_internal(target)
            assert source_info, (
                'This is for type checking only, as here source '
                'info will always be a string'
            )
            infos.append(
                RouteMethodInfo(http_method, source_info, target.__name__, is_internal)
            )
        return infos

    def _walk(
        nodes: List[CompiledRouterNode], prefix: str
    ) -> Iterable[RouteInfo]:
        # Lazily yield one RouteInfo per node that has a resource attached.
        for node in nodes:
            path = prefix + '/' + node.raw_segment
            if node.resource is not None:
                methods = _method_infos(node)
                source_info, class_name = _get_source_info_and_name(node.resource)
                assert source_info
                yield RouteInfo(path, class_name, source_info, methods)
            if node.children:
                yield from _walk(node.children, path)

    return list(_walk(router._roots, ''))


# ------------------------------------------------------------------------
# Inspection classes
# ------------------------------------------------------------------------


class _Traversable:
    # Key used by the visitors' ``process`` method to pick the matching
    # ``visit_<name>`` method; concrete info classes override it.
    __visit_name__ = 'N/A'

    def to_string(self, verbose: bool = False, internal: bool = False) -> str:
        """Return a string representation of this class.

        Args:
            verbose (bool, optional): Adds more information. Defaults to False.
            internal (bool, optional): Also include internal route methods
                and error handlers added by the framework. Defaults to
                ``False``.

        Returns:
            str: string representation of this class.
        """
        visitor = StringVisitor(verbose, internal)
        return visitor.process(self)

    def __repr__(self) -> str:
        # repr() is the default (non-verbose, no internals) textual form.
        return self.to_string()


class RouteMethodInfo(_Traversable):
    """Describes a responder method.

    Args:
        method (str): The HTTP method of this responder.
        source_info (str): The source path of this function.
        function_name (str): Name of the function.
        internal (bool): Whether or not this was a default responder added
            by the framework.

    """

    __visit_name__ = 'route_method'

    suffix: str
    """The suffix of this route function.

    This is set to an empty string when the function has no suffix.
    """

    def __init__(
        self, method: str, source_info: str, function_name: str, internal: bool
    ):
        self.method = method
        self.source_info = source_info
        self.function_name = function_name
        self.internal = internal
        # NOTE(CaselIT): internal falcon names do not start with 'on'
        # and do not have suffix
        # The split below assumes responder names of the form
        # ``on_<method>[_<suffix>]``. Requiring the full ``on_`` prefix keeps
        # names such as ``one_two_three`` from being given a bogus suffix.
        if function_name.startswith('on_'):
            self.suffix = '_'.join(function_name.split('_')[2:])
        else:
            self.suffix = ''


class RouteInfo(_Traversable):
    """Describes a route.

    Args:
        path (str): The path of this route.
        class_name (str): The class name of the responder of this route.
        source_info (str): The source path where this responder was defined.
        methods (List[RouteMethodInfo]): List of methods defined in the route.
    """

    __visit_name__ = 'route'

    def __init__(
        self,
        path: str,
        class_name: str,
        source_info: str,
        methods: List[RouteMethodInfo],
    ):
        # Plain value holder; arguments are stored unchanged.
        self.path, self.class_name = path, class_name
        self.source_info, self.methods = source_info, methods


class StaticRouteInfo(_Traversable):
    """Describes a static route.

    Args:
        prefix (str): The prefix of the static route.
        directory (str): The directory for the static route.
        fallback_filename (str or None): Fallback filename to serve.
    """

    __visit_name__ = 'static_route'

    def __init__(self, prefix: str, directory: str, fallback_filename: Optional[str]):
        self.prefix = prefix
        self.directory = directory
        # May be None (no fallback file).
        self.fallback_filename = fallback_filename


class SinkInfo(_Traversable):
    """Describes a sink.

    Args:
        prefix (str): The prefix of the sink.
        name (str): The name of the sink function or class.
        source_info (str): The source path where this sink was defined.
    """

    __visit_name__ = 'sink'

    def __init__(self, prefix: str, name: str, source_info: str):
        # Plain value holder; arguments are stored unchanged.
        self.source_info = source_info
        self.name = name
        self.prefix = prefix


class ErrorHandlerInfo(_Traversable):
    """Describes an error handler.

    Args:
        error (str): The name of the error type.
        name (str): The name of the handler.
        source_info (str): The source path where this error handler was defined.
        internal (bool): Whether or not this is a default error handler added by
            the framework.
    """

    __visit_name__ = 'error_handler'

    def __init__(self, error: str, name: str, source_info: str, internal: bool):
        self.error = error
        self.name = name
        self.source_info = source_info
        # Used to hide framework-provided handlers unless explicitly requested.
        self.internal = internal


class MiddlewareMethodInfo(_Traversable):
    """Describes a middleware method.

    Args:
        function_name (str): Name of the method.
        source_info (str): The source path of the method.
    """

    __visit_name__ = 'middleware_method'

    def __init__(self, function_name: str, source_info: str):
        self.source_info = source_info
        self.function_name = function_name
        # Always False: the attribute exists only so that code reading
        # ``.internal`` (as it does for RouteMethodInfo) works here too.
        self.internal = False


class MiddlewareClassInfo(_Traversable):
    """Describes a middleware class.

    Args:
        name (str): The name of the middleware class.
        source_info (str): The source path where the middleware was defined.
        methods (List[MiddlewareMethodInfo]): List of methods defined by the
            middleware class.
    """

    __visit_name__ = 'middleware_class'

    def __init__(
        self, name: str, source_info: str, methods: List[MiddlewareMethodInfo]
    ):
        # Plain value holder; arguments are stored unchanged.
        self.name, self.source_info, self.methods = name, source_info, methods


class MiddlewareTreeItemInfo(_Traversable):
    """Describes a middleware tree entry.

    Args:
        name (str): The name of the method.
        class_name (str): The class name of the method.
    """

    __visit_name__ = 'middleware_tree_item'

    # Arrow drawn in front of each entry, keyed by middleware method name.
    _symbols = dict(
        process_request='→',
        process_resource='↣',
        process_response='↢',
    )

    def __init__(self, name: str, class_name: str):
        self.class_name = class_name
        self.name = name


class MiddlewareTreeInfo(_Traversable):
    """Describes the middleware methods used by the app.

    Args:
        request (List[MiddlewareTreeItemInfo]): The `process_request` methods.
        resource (List[MiddlewareTreeItemInfo]): The `process_resource` methods.
        response (List[MiddlewareTreeItemInfo]): The `process_response` methods.
    """

    __visit_name__ = 'middleware_tree'

    def __init__(
        self,
        request: List[MiddlewareTreeItemInfo],
        resource: List[MiddlewareTreeItemInfo],
        response: List[MiddlewareTreeItemInfo],
    ):
        # One list per processing phase, stored unchanged.
        self.request, self.resource, self.response = request, resource, response


class MiddlewareInfo(_Traversable):
    """Describes the middleware of the app.

    Args:
        middleware_tree (MiddlewareTreeInfo): The middleware tree of the app.
        middleware_classes (List[MiddlewareClassInfo]): The middleware classes of
            the app.
        independent (bool): Whether or not the middleware components are executed
            independently.

    """

    __visit_name__ = 'middleware'

    independent_text: str
    """Text created from the `independent` arg."""

    def __init__(
        self,
        middleware_tree: MiddlewareTreeInfo,
        middleware_classes: List[MiddlewareClassInfo],
        independent: bool,
    ):
        self.middleware_tree = middleware_tree
        self.middleware_classes = middleware_classes
        self.independent = independent

        # Human-readable label derived from ``independent``.
        if independent:
            self.independent_text = 'Middleware are independent'
        else:
            self.independent_text = 'Middleware are dependent'


class AppInfo(_Traversable):
    """Describes an application.

    Args:
        routes (List[RouteInfo]): The routes of the application.
        middleware (MiddlewareInfo): The middleware information in the application.
        static_routes (List[StaticRouteInfo]): The static routes of this application.
        sinks (List[SinkInfo]): The sinks of this application.
        error_handlers (List[ErrorHandlerInfo]): The error handlers of this application.
        asgi (bool): Whether or not this is an ASGI application.
    """

    __visit_name__ = 'app'

    def __init__(
        self,
        routes: List[RouteInfo],
        middleware: MiddlewareInfo,
        static_routes: List[StaticRouteInfo],
        sinks: List[SinkInfo],
        error_handlers: List[ErrorHandlerInfo],
        asgi: bool,
    ):
        # Plain value holder; every argument is stored unchanged.
        self.asgi = asgi
        self.routes = routes
        self.middleware = middleware
        self.static_routes = static_routes
        self.sinks = sinks
        self.error_handlers = error_handlers

    def to_string(
        self, verbose: bool = False, internal: bool = False, name: str = ''
    ) -> str:
        """Return a string representation of this class.

        Args:
            verbose (bool, optional): Adds more information. Defaults to False.
            internal (bool, optional): Also include internal falcon route methods
                and error handlers. Defaults to ``False``.
            name (str, optional): The name of the application, to be output at the
                beginning of the text. Defaults to ``'Falcon App'``.

        Returns:
            str: A string representation of the application.
        """
        visitor = StringVisitor(verbose, internal, name)
        return visitor.process(self)


# ------------------------------------------------------------------------
# Visitor classes
# ------------------------------------------------------------------------


class InspectVisitor:
    """Base visitor class that implements the `process` method.

    Subclasses must implement ``visit_<name>`` methods for each supported class.
    """

    def process(self, instance: _Traversable) -> str:
        """Process the instance, by calling the appropriate visit method.

        Uses the `__visit_name__` attribute of the `instance` to obtain the method
        to use.

        Args:
            instance (_Traversable): The instance to process.

        Returns:
            str: Whatever the matching ``visit_<name>`` method returns.

        Raises:
            RuntimeError: This visitor has no ``visit_<name>`` method for the
                instance, or the instance has no ``__visit_name__``.
        """
        # NOTE: Only the method lookup is guarded. An AttributeError raised
        # *inside* a visit method is a real bug and must propagate instead of
        # being misreported as an unsupported type.
        try:
            visit = getattr(self, 'visit_{}'.format(instance.__visit_name__))
        except AttributeError as e:
            raise RuntimeError(
                'This visitor does not support {}'.format(type(instance))
            ) from e
        return visit(instance)


class StringVisitor(InspectVisitor):
    """Visitor that returns a string representation of the info class.

    This is used automatically by calling ``to_string()`` on the info class.
    It can also be used directly by calling
    ``StringVisitor().process(info_instance)``.

    Args:
        verbose (bool, optional): Adds more information. Defaults to ``False``.
        internal (bool, optional): Also include internal route methods
            and error handlers added by the framework. Defaults to ``False``.
        name (str, optional): The name of the application, to be output at the
            beginning of the text. An empty name (the default) is rendered as
            ``'Falcon App'``.
    """

    def __init__(
        self, verbose: bool = False, internal: bool = False, name: str = ''
    ) -> None:
        self.verbose = verbose
        self.internal = internal
        self.name = name
        # Current indentation, in spaces; adjusted by the visit methods.
        self.indent = 0

    @property
    def tab(self) -> str:
        """Return the current indentation as a string of spaces."""
        return ' ' * self.indent

    def visit_route_method(self, route_method: RouteMethodInfo) -> str:
        """Visit a RouteMethodInfo instance. Usually called by `process`."""
        text = '{0.method} - {0.function_name}'.format(route_method)
        if self.verbose:
            text += ' ({0.source_info})'.format(route_method)
        return text

    def _methods_to_string(
        self, methods: Union[List[RouteMethodInfo], List[MiddlewareMethodInfo]]
    ) -> str:
        """Return a string from the list of methods.

        Methods are drawn as tree branches indented three spaces past the
        current tab; the last one uses a closing branch. Internal methods are
        dropped unless ``self.internal`` is set. Returns an empty string when
        nothing is left to show.
        """
        tab = self.tab + ' ' * 3
        filtered_methods = _filter_internal(methods, self.internal)
        if not filtered_methods:
            return ''
        text_list = [self.process(m) for m in filtered_methods]
        method_text = ['{}├── {}'.format(tab, m) for m in text_list[:-1]]
        method_text += ['{}└── {}'.format(tab, m) for m in text_list[-1:]]
        return '\n'.join(method_text)

    def visit_route(self, route: RouteInfo) -> str:
        """Visit a RouteInfo instance. Usually called by `process`."""
        text = '{0}⇒ {1.path} - {1.class_name}'.format(self.tab, route)
        if self.verbose:
            text += ' ({0.source_info})'.format(route)

        method_text = self._methods_to_string(route.methods)
        if not method_text:
            return text

        return '{}:\n{}'.format(text, method_text)

    def visit_static_route(self, static_route: StaticRouteInfo) -> str:
        """Visit a StaticRouteInfo instance. Usually called by `process`."""
        text = '{0}↦ {1.prefix} {1.directory}'.format(self.tab, static_route)
        # The fallback filename is only shown when one is set.
        if static_route.fallback_filename:
            text += ' [{0.fallback_filename}]'.format(static_route)
        return text

    def visit_sink(self, sink: SinkInfo) -> str:
        """Visit a SinkInfo instance. Usually called by `process`."""
        text = '{0}⇥ {1.prefix} {1.name}'.format(self.tab, sink)
        if self.verbose:
            text += ' ({0.source_info})'.format(sink)
        return text

    def visit_error_handler(self, error_handler: ErrorHandlerInfo) -> str:
        """Visit an ErrorHandlerInfo instance. Usually called by `process`."""
        text = '{0}⇜ {1.error} {1.name}'.format(self.tab, error_handler)
        if self.verbose:
            text += ' ({0.source_info})'.format(error_handler)
        return text

    def visit_middleware_method(self, middleware_method: MiddlewareMethodInfo) -> str:
        """Visit a MiddlewareMethodInfo instance. Usually called by `process`."""
        text = '{0.function_name}'.format(middleware_method)
        if self.verbose:
            text += ' ({0.source_info})'.format(middleware_method)
        return text

    def visit_middleware_class(self, middleware_class: MiddlewareClassInfo) -> str:
        """Visit a MiddlewareClassInfo instance. Usually called by `process`."""
        text = '{0}↣ {1.name}'.format(self.tab, middleware_class)
        if self.verbose:
            text += ' ({0.source_info})'.format(middleware_class)

        method_text = self._methods_to_string(middleware_class.methods)
        if not method_text:
            return text

        return '{}:\n{}'.format(text, method_text)

    def visit_middleware_tree_item(self, mti: MiddlewareTreeItemInfo) -> str:
        """Visit a MiddlewareTreeItemInfo instance. Usually called by `process`."""
        # Method names missing from ``_symbols`` fall back to the request arrow.
        symbol = mti._symbols.get(mti.name, '→')
        return '{0}{1} {2.class_name}.{2.name}'.format(self.tab, symbol, mti)

    def visit_middleware_tree(self, m_tree: MiddlewareTreeInfo) -> str:
        """Visit a MiddlewareTreeInfo instance. Usually called by `process`.

        Request and resource methods are printed with increasing indentation,
        followed by a ``Process route responder`` line, then the response
        methods with decreasing indentation. Returns an empty string when the
        tree has no methods at all.
        """
        before = len(m_tree.request) + len(m_tree.resource)
        after = len(m_tree.response)

        if before + after == 0:
            return ''

        # Spaces added/removed per nesting level.
        each = 2
        initial = self.indent
        # Each response item is dedented by ``each`` before being printed.
        # With more response than request/resource items, shift everything
        # right up-front so that the last response item is not printed left
        # of the starting indentation.
        if after > before:
            self.indent += each * (after - before)

        text = []
        for r in m_tree.request:
            text.append(self.process(r))
            self.indent += each
        # Blank separator after the request section, if there was one.
        if text:
            text.append('')
        for r in m_tree.resource:
            text.append(self.process(r))
            self.indent += each

        # Blank line before the responder line, unless the request section
        # already ended with one and there were no resource methods.
        if m_tree.resource or not text:
            text.append('')
        self.indent += each
        text.append('{}├── Process route responder'.format(self.tab))
        self.indent -= each
        if m_tree.response:
            text.append('')

        for r in m_tree.response:
            self.indent -= each
            text.append(self.process(r))

        # Restore the indentation in effect on entry.
        self.indent = initial
        return '\n'.join(text)

    def visit_middleware(self, middleware: MiddlewareInfo) -> str:
        """Visit a MiddlewareInfo instance. Usually called by `process`."""
        text = self.process(middleware.middleware_tree)
        # Middleware classes are only listed in verbose mode, indented by
        # four extra spaces.
        if self.verbose:
            self.indent += 4
            m_text = '\n'.join(self.process(m) for m in middleware.middleware_classes)
            self.indent -= 4
            if m_text:
                text += '\n{}- Middleware classes:\n{}'.format(self.tab, m_text)

        return text

    def visit_app(self, app: AppInfo) -> str:
        """Visit an AppInfo instance. Usually called by `process`."""

        type_ = 'ASGI' if app.asgi else 'WSGI'
        # NOTE(review): the indentation is set to 4 here for all sections and
        # is not restored before returning.
        self.indent = 4
        text = '{} ({})'.format(self.name or 'Falcon App', type_)

        if app.routes:
            routes = '\n'.join(self.process(r) for r in app.routes)
            text += '\n• Routes:\n{}'.format(routes)

        middleware_text = self.process(app.middleware)
        if middleware_text:
            text += '\n• Middleware ({}):\n{}'.format(
                app.middleware.independent_text, middleware_text
            )

        if app.static_routes:
            static_routes = '\n'.join(self.process(sr) for sr in app.static_routes)
            text += '\n• Static routes:\n{}'.format(static_routes)

        if app.sinks:
            sinks = '\n'.join(self.process(s) for s in app.sinks)
            text += '\n• Sinks:\n{}'.format(sinks)

        # Framework-provided error handlers are hidden unless requested.
        errors = _filter_internal(app.error_handlers, self.internal)
        if errors:
            errs = '\n'.join(self.process(e) for e in errors)
            text += '\n• Error handlers:\n{}'.format(errs)

        return text


# ------------------------------------------------------------------------
# Helper functions
# ------------------------------------------------------------------------


def _get_source_info(
    obj: Any, default: Optional[str] = '[unknown file]'
) -> Optional[str]:
    """Try to get the definition file and line of obj.

    Args:
        obj: The object (function, method, class, ...) to locate.
        default: Value returned when the source cannot be located.

    Returns:
        A ``'<file>:<line>'`` string, where ``<line>`` is the 1-based line
        number of the definition, or ``default`` on error.
    """
    try:
        source_file = inspect.getsourcefile(obj)
        # NOTE: inspect.findsource() returns a 0-based line *index*; convert
        # it to the 1-based line number used by tracebacks and editors.
        _, line_index = inspect.findsource(obj)
    except Exception:
        # NOTE(vytas): If Falcon is cythonized, all default
        # responders coming from cythonized modules will
        # appear as built-in functions, and raise a
        # TypeError when trying to locate the source file.
        return default
    return '{}:{}'.format(source_file, line_index + 1)


def _get_source_info_and_name(obj: Any) -> Tuple[Optional[str], str]:
    """Attempt to get the definition file and line of obj and its name.

    When ``obj`` itself cannot be located (e.g. it is a class instance), the
    lookup is retried on ``type(obj)``, whose failure yields the default
    ``'[unknown file]'`` marker. Likewise, the name falls back to the type's
    ``__name__`` and finally to ``'[unknown]'``.
    """
    info = _get_source_info(obj, None)
    if info is None:
        info = _get_source_info(type(obj))

    obj_name = getattr(obj, '__name__', None)
    if obj_name is None:
        obj_name = getattr(type(obj), '__name__', '[unknown]')
    return info, cast(str, obj_name)


def _is_internal(obj: Any) -> bool:
    """Check if the module of the object is a falcon module.

    Returns ``True`` only when the object's module name starts with
    ``'falcon.'``; objects whose module cannot be determined are treated
    as not internal.
    """
    owner = inspect.getmodule(obj)
    if owner is None:
        return False
    return owner.__name__.startswith('falcon.')


def _filter_internal(
    iterable: Union[
        Iterable[RouteMethodInfo],
        Iterable[ErrorHandlerInfo],
        Iterable[MiddlewareMethodInfo],
    ],
    return_internal: bool,
) -> Union[Iterable[_Traversable], List[_Traversable]]:
    """Filter the internal elements of an iterable.

    When ``return_internal`` is true the input is returned as-is (same
    object); otherwise a new list without the elements whose ``internal``
    attribute is truthy is returned.
    """
    if not return_internal:
        return [item for item in iterable if not item.internal]
    return iterable
