#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#

import re
from http import HTTPStatus
from typing import Awaitable, Callable, NoReturn

from twisted.internet.defer import Deferred
from twisted.web.resource import Resource

from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.config.server import parse_listener_def
from synapse.http.server import (
    DirectServeHtmlResource,
    DirectServeJsonResource,
    JsonResource,
    OptionsResource,
)
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict
from synapse.util.cancellation import cancellable
from synapse.util.clock import Clock
from synapse.util.duration import Duration

from tests import unittest
from tests.http.server._base import test_disconnect
from tests.server import (
    FakeChannel,
    FakeSite,
    get_clock,
    make_request,
    setup_test_homeserver,
)


class JsonResourceTests(unittest.TestCase):
    """Tests for request routing and error handling in `JsonResource`."""

    def setUp(self) -> None:
        reactor, clock = get_clock()
        self.reactor = reactor
        self.homeserver = setup_test_homeserver(
            cleanup_func=self.addCleanup,
            reactor=self.reactor,
            clock=clock,
        )

    def test_handler_for_request(self) -> None:
        """
        JsonResource.handler_for_request gives correctly decoded URL args to
        the callback, while Twisted will give the raw bytes of URL query
        arguments.
        """
        got_kwargs: dict[str, object] = {}

        def _callback(
            request: SynapseRequest, **kwargs: object
        ) -> tuple[int, dict[str, object]]:
            got_kwargs.update(kwargs)
            return 200, kwargs

        res = JsonResource(self.homeserver)
        res.register_paths(
            "GET",
            [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")],
            _callback,
            "test_servlet",
        )

        make_request(
            self.reactor,
            FakeSite(res, self.reactor),
            b"GET",
            b"/_matrix/foo/%E2%98%83?a=%E2%98%83",
        )

        # Only the named path group is passed as a kwarg, percent-decoded;
        # the query-string argument `a` is not.
        self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"})

    def test_callback_direct_exception(self) -> None:
        """
        If the web callback raises an uncaught exception, it will be translated
        into a 500.
        """

        def _callback(request: SynapseRequest, **kwargs: object) -> NoReturn:
            raise Exception("boo")

        res = JsonResource(self.homeserver)
        res.register_paths(
            "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
        )

        channel = make_request(
            self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
        )

        self.assertEqual(channel.code, 500)

    def test_callback_indirect_exception(self) -> None:
        """
        If the web callback raises an uncaught exception in a Deferred, it will
        be translated into a 500.
        """

        def _throw(*args: object) -> NoReturn:
            raise Exception("boo")

        def _callback(request: SynapseRequest, **kwargs: object) -> "Deferred[None]":
            d: "Deferred[None]" = Deferred()
            d.addCallback(_throw)
            # Fire the Deferred later so the failure arrives asynchronously
            # rather than being raised synchronously from the callback.
            self.reactor.callLater(0.5, d.callback, True)
            return make_deferred_yieldable(d)

        res = JsonResource(self.homeserver)
        res.register_paths(
            "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
        )

        channel = make_request(
            self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
        )

        self.assertEqual(channel.code, 500)

    def test_callback_synapseerror(self) -> None:
        """
        If the web callback raises a SynapseError, it returns the appropriate
        status code and message set in it.
        """

        def _callback(request: SynapseRequest, **kwargs: object) -> NoReturn:
            raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN)

        res = JsonResource(self.homeserver)
        res.register_paths(
            "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
        )

        channel = make_request(
            self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
        )

        self.assertEqual(channel.code, 403)
        self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")

    def test_no_handler(self) -> None:
        """
        If there is no handler to process the request, Synapse will return 404
        with an M_UNRECOGNIZED error.
        """

        def _callback(request: SynapseRequest, **kwargs: object) -> None:
            """
            Not ever actually called!
            """
            self.fail("shouldn't ever get here")

        res = JsonResource(self.homeserver)
        res.register_paths(
            "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
        )

        # `/_matrix/foobar` does not match the anchored `^/_matrix/foo$` pattern.
        channel = make_request(
            self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar"
        )

        self.assertEqual(channel.code, 404)
        self.assertEqual(channel.json_body["error"], "Unrecognized request")
        self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")

    def test_head_request(self) -> None:
        """
        A HEAD request to a path registered only for GET is handled by the GET
        callback: it returns the same status code, but no body.
        """

        def _callback(
            request: SynapseRequest, **kwargs: object
        ) -> tuple[int, dict[str, object]]:
            return 200, {"result": True}

        res = JsonResource(self.homeserver)
        res.register_paths(
            "GET",
            [re.compile("^/_matrix/foo$")],
            _callback,
            "test_servlet",
        )

        # The path was registered as GET, but this is a HEAD request.
        channel = make_request(
            self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/_matrix/foo"
        )

        self.assertEqual(channel.code, 200)
        self.assertNotIn("body", channel.result)

    def test_content_larger_than_content_length(self) -> None:
        """
        HTTP requests with content size exceeding Content-Length should be rejected with 400.
        """

        def _callback(
            request: SynapseRequest, **kwargs: object
        ) -> tuple[int, JsonDict]:
            return 200, {}

        res = JsonResource(self.homeserver)
        res.register_paths(
            "POST", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
        )

        channel = make_request(
            self.reactor,
            FakeSite(res, self.reactor),
            b"POST",
            b"/_matrix/foo",
            {},
            # Set the `Content-Length` value to be smaller than the actual content size
            custom_headers=[("Content-Length", "1")],
            # The request should disconnect early so don't await the result
            await_result=False,
        )

        self.reactor.advance(0.1)
        self.assertEqual(channel.code, 400)

    def test_content_smaller_than_content_length(self) -> None:
        """
        HTTP requests with content size smaller than Content-Length should be rejected with 400.
        """

        def _callback(
            request: SynapseRequest, **kwargs: object
        ) -> tuple[int, JsonDict]:
            return 200, {}

        res = JsonResource(self.homeserver)
        res.register_paths(
            "POST", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
        )

        channel = make_request(
            self.reactor,
            FakeSite(res, self.reactor),
            b"POST",
            b"/_matrix/foo",
            {},
            # Set the `Content-Length` value to be larger than the actual content size
            custom_headers=[("Content-Length", "10")],
            # The request should disconnect early so don't await the result
            await_result=False,
        )

        self.reactor.advance(0.1)
        self.assertEqual(channel.code, 400)


class OptionsResourceTests(unittest.TestCase):
    """Tests for `OptionsResource`: CORS preflight handling and delegation to children."""

    def setUp(self) -> None:
        reactor, clock = get_clock()
        self.reactor = reactor
        self.homeserver = setup_test_homeserver(
            cleanup_func=self.addCleanup,
            reactor=self.reactor,
            clock=clock,
        )

        class DummyResource(Resource):
            """Leaf resource that echoes the request path back as the body."""

            isLeaf = True

            def render(self, request: SynapseRequest) -> bytes:
                return request.path

        # Setup a resource with some children.
        self.resource = OptionsResource()
        self.resource.putChild(b"res", DummyResource())

    def _make_request(self, method: bytes, path: bytes) -> FakeChannel:
        """Create a request from the method/path and return a channel with the response."""
        # Create a site and query for the resource.
        site = SynapseSite(
            logger_name="test",
            site_tag="site_tag",
            config=parse_listener_def(
                0,
                {
                    "type": "http",
                    "port": 0,
                },
            ),
            resource=self.resource,
            server_version_string="1",
            max_request_body_size=4096,
            reactor=self.reactor,
            hs=self.homeserver,
        )

        # render the request and return the channel
        channel = make_request(self.reactor, site, method, path, shorthand=False)
        return channel

    def _check_cors_standard_headers(self, channel: FakeChannel) -> None:
        """Assert that the response carries the standard Matrix CORS headers."""
        # Ensure the correct CORS headers have been added
        # as per https://spec.matrix.org/v1.4/client-server-api/#web-browser-clients
        self.assertEqual(
            channel.headers.getRawHeaders(b"Access-Control-Allow-Origin"),
            [b"*"],
            "has correct CORS Origin header",
        )
        self.assertEqual(
            channel.headers.getRawHeaders(b"Access-Control-Allow-Methods"),
            [b"GET, HEAD, POST, PUT, DELETE, OPTIONS"],  # HEAD isn't in the spec
            "has correct CORS Methods header",
        )
        self.assertEqual(
            channel.headers.getRawHeaders(b"Access-Control-Allow-Headers"),
            [b"X-Requested-With, Content-Type, Authorization, Date"],
            "has correct CORS Headers header",
        )
        self.assertEqual(
            channel.headers.getRawHeaders(b"Access-Control-Expose-Headers"),
            [b"Synapse-Trace-Id, Server"],
            "has correct CORS Expose-Headers header",
        )

    def test_unknown_options_request(self) -> None:
        """An OPTIONS request to an unknown URL still returns 204 No Content."""
        channel = self._make_request(b"OPTIONS", b"/foo/")
        self.assertEqual(channel.code, 204)
        self.assertNotIn("body", channel.result)

        self._check_cors_standard_headers(channel)

    def test_known_options_request(self) -> None:
        """An OPTIONS request to a known URL still returns 204 No Content."""
        channel = self._make_request(b"OPTIONS", b"/res/")
        self.assertEqual(channel.code, 204)
        self.assertNotIn("body", channel.result)

        self._check_cors_standard_headers(channel)

    def test_unknown_request(self) -> None:
        """A non-OPTIONS request to an unknown URL should 404."""
        channel = self._make_request(b"GET", b"/foo/")
        self.assertEqual(channel.code, 404)

    def test_known_request(self) -> None:
        """A non-OPTIONS request to a known URL should query the proper resource."""
        channel = self._make_request(b"GET", b"/res/")
        self.assertEqual(channel.code, 200)
        self.assertEqual(channel.result["body"], b"/res/")


class WrapHtmlRequestHandlerTests(unittest.TestCase):
    """Tests for how `DirectServeHtmlResource` wraps its async request handlers."""

    class TestResource(DirectServeHtmlResource):
        """HTML resource whose GET handler delegates to a per-test `callback`."""

        # Each test assigns this before issuing a request. The `None` default
        # ensures the assertion in `_async_render_GET` fires (instead of an
        # AttributeError) if a test forgets to set it.
        callback: Callable[..., Awaitable[None]] | None = None

        async def _async_render_GET(self, request: SynapseRequest) -> None:
            assert self.callback is not None
            await self.callback(request)

    def setUp(self) -> None:
        reactor, clock = get_clock()
        self.reactor = reactor
        self.clock = clock

    def test_good_response(self) -> None:
        """A callback that writes and finishes the request produces a 200 with its body."""

        async def callback(request: SynapseRequest) -> None:
            request.write(b"response")
            request.finish()

        res = WrapHtmlRequestHandlerTests.TestResource(clock=self.clock)
        res.callback = callback

        channel = make_request(
            self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
        )

        self.assertEqual(channel.code, 200)
        body = channel.result["body"]
        self.assertEqual(body, b"response")

    def test_redirect_exception(self) -> None:
        """
        If the callback raises a RedirectException, it is turned into a 30x
        with the right location.
        """

        async def callback(request: SynapseRequest, **kwargs: object) -> NoReturn:
            raise RedirectException(b"/look/an/eagle", 301)

        res = WrapHtmlRequestHandlerTests.TestResource(clock=self.clock)
        res.callback = callback

        channel = make_request(
            self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
        )

        self.assertEqual(channel.code, 301)
        location_headers = channel.headers.getRawHeaders(b"Location", [])
        self.assertEqual(location_headers, [b"/look/an/eagle"])

    def test_redirect_exception_with_cookie(self) -> None:
        """
        If the callback raises a RedirectException which sets a cookie, that is
        returned too
        """

        async def callback(request: SynapseRequest, **kwargs: object) -> NoReturn:
            e = RedirectException(b"/no/over/there", 304)
            e.cookies.append(b"session=yespls")
            raise e

        res = WrapHtmlRequestHandlerTests.TestResource(clock=self.clock)
        res.callback = callback

        channel = make_request(
            self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
        )

        self.assertEqual(channel.code, 304)
        headers = channel.headers
        location_headers = headers.getRawHeaders(b"Location", [])
        self.assertEqual(location_headers, [b"/no/over/there"])
        cookies_headers = headers.getRawHeaders(b"Set-Cookie", [])
        self.assertEqual(cookies_headers, [b"session=yespls"])

    def test_head_request(self) -> None:
        """A head request should work by being turned into a GET request."""

        async def callback(request: SynapseRequest) -> None:
            request.write(b"response")
            request.finish()

        res = WrapHtmlRequestHandlerTests.TestResource(clock=self.clock)
        res.callback = callback

        channel = make_request(
            self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/path"
        )

        self.assertEqual(channel.code, 200)
        # The GET handler wrote a body, but it must not be sent for HEAD.
        self.assertNotIn("body", channel.result)


class CancellableDirectServeJsonResource(DirectServeJsonResource):
    """JSON test resource with a cancellable GET and a non-cancellable POST.

    Both handlers wait one second on the supplied clock, then respond with
    200 and ``{"result": True}``. Only the GET handler carries
    `@cancellable`.
    """

    def __init__(self, clock: Clock):
        super().__init__(clock=clock)
        self.clock = clock

    async def _respond_after_delay(self) -> tuple[int, JsonDict]:
        """Sleep one second on `self.clock`, then build the success response."""
        await self.clock.sleep(Duration(seconds=1))
        return HTTPStatus.OK, {"result": True}

    @cancellable
    async def _async_render_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]:
        return await self._respond_after_delay()

    async def _async_render_POST(self, request: SynapseRequest) -> tuple[int, JsonDict]:
        return await self._respond_after_delay()


class CancellableDirectServeHtmlResource(DirectServeHtmlResource):
    """HTML test resource with a cancellable GET and a non-cancellable POST.

    Both handlers wait one second on the supplied clock, then respond with
    200 and the body ``b"ok"``. Errors are rendered with the simple
    ``"{code} {msg}"`` template so tests can match them exactly.
    """

    ERROR_TEMPLATE = "{code} {msg}"

    def __init__(self, clock: Clock):
        super().__init__(clock=clock)
        self.clock = clock

    async def _respond_after_delay(self) -> tuple[int, bytes]:
        """Sleep one second on `self.clock`, then build the success response."""
        await self.clock.sleep(Duration(seconds=1))
        return HTTPStatus.OK, b"ok"

    @cancellable
    async def _async_render_GET(self, request: SynapseRequest) -> tuple[int, bytes]:
        return await self._respond_after_delay()

    async def _async_render_POST(self, request: SynapseRequest) -> tuple[int, bytes]:
        return await self._respond_after_delay()


class DirectServeJsonResourceCancellationTests(unittest.TestCase):
    """Checks what `DirectServeJsonResource` does when a client disconnects mid-request."""

    def setUp(self) -> None:
        self.reactor, fake_clock = get_clock()
        self.resource = CancellableDirectServeJsonResource(fake_clock)
        self.site = FakeSite(self.resource, self.reactor)

    def test_cancellable_disconnect(self) -> None:
        """A handler marked `@cancellable` is cancelled when the client goes away."""
        pending = make_request(
            self.reactor, self.site, "GET", "/sleep", await_result=False
        )
        cancelled_body = {"error": "Request cancelled", "errcode": Codes.UNKNOWN}
        test_disconnect(
            self.reactor,
            pending,
            expect_cancellation=True,
            expected_body=cancelled_body,
        )

    def test_uncancellable_disconnect(self) -> None:
        """A handler without `@cancellable` runs to completion despite the disconnect."""
        pending = make_request(
            self.reactor, self.site, "POST", "/sleep", await_result=False
        )
        completed_body = {"result": True}
        test_disconnect(
            self.reactor,
            pending,
            expect_cancellation=False,
            expected_body=completed_body,
        )


class DirectServeHtmlResourceCancellationTests(unittest.TestCase):
    """Checks what `DirectServeHtmlResource` does when a client disconnects mid-request."""

    def setUp(self) -> None:
        self.reactor, fake_clock = get_clock()
        self.resource = CancellableDirectServeHtmlResource(fake_clock)
        self.site = FakeSite(self.resource, self.reactor)

    def test_cancellable_disconnect(self) -> None:
        """A handler marked `@cancellable` is cancelled when the client goes away."""
        pending = make_request(
            self.reactor, self.site, "GET", "/sleep", await_result=False
        )
        cancelled_body = b"499 Request cancelled"
        test_disconnect(
            self.reactor,
            pending,
            expect_cancellation=True,
            expected_body=cancelled_body,
        )

    def test_uncancellable_disconnect(self) -> None:
        """A handler without `@cancellable` runs to completion despite the disconnect."""
        pending = make_request(
            self.reactor, self.site, "POST", "/sleep", await_result=False
        )
        completed_body = b"ok"
        test_disconnect(
            self.reactor,
            pending,
            expect_cancellation=False,
            expected_body=completed_body,
        )
