1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
|
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
import json
from http import HTTPStatus
from io import BytesIO
from typing import Tuple, Union
from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
parse_json_value_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util.cancellation import cancellable
from tests import unittest
def make_request(content: Union[bytes, JsonDict]) -> Mock:
    """Return a `Mock` that quacks enough like a request for the JSON parsers.

    A dict is serialised to JSON and encoded as UTF-8; bytes are used as the
    body verbatim. The mock is specced to only `method`, `uri` and `content`,
    where `content` is a `BytesIO` wrapping the body.
    """
    if isinstance(content, dict):
        body = json.dumps(content).encode("utf8")
    else:
        body = content

    fake_request = Mock(spec=["method", "uri", "content"])
    fake_request.method = b"STUB_METHOD"
    fake_request.uri = b"/test_stub_uri"
    fake_request.content = BytesIO(body)
    return fake_request
class TestServletUtils(unittest.TestCase):
    """Unit tests for the JSON body parsing helpers in `synapse.http.servlet`."""

    def test_parse_json_value(self) -> None:
        """Basic tests for parse_json_value_from_request."""
        # Test round-tripping.
        obj = {"foo": 1}
        result1 = parse_json_value_from_request(make_request(obj))
        self.assertEqual(result1, obj)

        # Results don't have to be objects.
        result2 = parse_json_value_from_request(make_request(b'["foo"]'))
        self.assertEqual(result2, ["foo"])

        # Test empty: rejected by default, returned as None when explicitly allowed.
        with self.assertRaises(SynapseError):
            parse_json_value_from_request(make_request(b""))

        result3 = parse_json_value_from_request(
            make_request(b""), allow_empty_body=True
        )
        self.assertIsNone(result3)

        # Invalid UTF-8.
        with self.assertRaises(SynapseError):
            parse_json_value_from_request(make_request(b"\xff\x00"))

        # Invalid JSON.
        with self.assertRaises(SynapseError):
            parse_json_value_from_request(make_request(b"foo"))

        # Non-standard JSON constants must be rejected too, not just Infinity.
        with self.assertRaises(SynapseError):
            parse_json_value_from_request(make_request(b'{"foo": Infinity}'))
        with self.assertRaises(SynapseError):
            parse_json_value_from_request(make_request(b'{"foo": NaN}'))

    def test_parse_json_object(self) -> None:
        """Basic tests for parse_json_object_from_request."""
        # Test round-tripping.
        obj = {"foo": 1}
        result1 = parse_json_object_from_request(make_request(obj))
        self.assertEqual(result1, obj)

        # Test empty: an error by default, an empty dict when explicitly allowed.
        with self.assertRaises(SynapseError):
            parse_json_object_from_request(make_request(b""))

        result = parse_json_object_from_request(
            make_request(b""), allow_empty_body=True
        )
        self.assertEqual(result, {})

        # Test not an object: neither arrays nor bare scalars are accepted.
        with self.assertRaises(SynapseError):
            parse_json_object_from_request(make_request(b'["foo"]'))
        with self.assertRaises(SynapseError):
            parse_json_object_from_request(make_request(b'"foo"'))
class CancellableRestServlet(RestServlet):
    """A `RestServlet` on `/sleep` with one cancellable and one uncancellable handler.

    Both handlers wait one second on the homeserver clock and then reply
    `200 {"result": True}`. Only `on_GET` carries the `@cancellable` flag.
    """

    PATTERNS = client_patterns("/sleep$")

    def __init__(self, hs: HomeServer) -> None:
        super().__init__()
        self.clock = hs.get_clock()

    async def _sleep_then_reply(self) -> Tuple[int, JsonDict]:
        """Wait one second on the clock, then build the success response."""
        await self.clock.sleep(1.0)
        return HTTPStatus.OK, {"result": True}

    @cancellable
    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
        # Flagged `@cancellable`, so the wait may be cut short on disconnect.
        return await self._sleep_then_reply()

    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
        # Deliberately NOT flagged `@cancellable`.
        return await self._sleep_then_reply()
class TestRestServletCancellation(unittest.HomeserverTestCase):
    """Tests for how `RestServlet` handlers react to a client disconnecting."""

    servlets = [
        lambda hs, http_server: CancellableRestServlet(hs).register(http_server)
    ]

    def _disconnect_mid_request(
        self, method: str, expect_cancellation: bool, expected_body: JsonDict
    ) -> None:
        """Start a `/sleep` request with `method`, then simulate a disconnect.

        Args:
            method: HTTP method to send to `/sleep`.
            expect_cancellation: whether the handler should be cancelled.
            expected_body: the JSON body the channel should end up with.
        """
        channel = self.make_request(method, "/sleep", await_result=False)

        # Kept as a function-local import, matching the original layout.
        from tests.http.server._base import test_disconnect

        test_disconnect(
            self.reactor,
            channel,
            expect_cancellation=expect_cancellation,
            expected_body=expected_body,
        )

    def test_cancellable_disconnect(self) -> None:
        """Test that handlers with the `@cancellable` flag can be cancelled."""
        self._disconnect_mid_request(
            "GET",
            expect_cancellation=True,
            expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
        )

    def test_uncancellable_disconnect(self) -> None:
        """Test that handlers without the `@cancellable` flag cannot be cancelled."""
        self._disconnect_mid_request(
            "POST",
            expect_cancellation=False,
            expected_body={"result": True},
        )
|