File: test_test_client.py

package info (click to toggle)
litestar 2.19.0-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 12,500 kB
  • sloc: python: 70,169; makefile: 254; javascript: 105; sh: 60
file content (328 lines) | stat: -rw-r--r-- 12,240 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
from queue import Empty
from typing import TYPE_CHECKING, Callable, Dict, NoReturn, Optional, Union, cast

from _pytest.fixtures import FixtureRequest

from litestar import Controller, WebSocket, delete, head, patch, put, websocket
from litestar.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT
from litestar.testing import AsyncTestClient, WebSocketTestSession, create_async_test_client, create_test_client

if TYPE_CHECKING:
    from litestar.middleware.session.base import BaseBackendConfig
    from litestar.types import (
        AnyIOBackend,
        HTTPResponseBodyEvent,
        HTTPResponseStartEvent,
        Receive,
        Scope,
        Send,
    )

from typing import Any, Type

import pytest

from litestar import Litestar, Request, get, post
from litestar.stores.base import Store
from litestar.testing import TestClient
from litestar.utils.helpers import get_exception_group
from tests.helpers import maybe_async, maybe_async_cm

# Exception-group type obtained from litestar's helper; the startup-error test below expects
# failures raised while entering the client to be wrapped in it.
# NOTE(review): presumably the builtin ``ExceptionGroup`` or a backport depending on the
# Python version — confirm against ``litestar.utils.helpers.get_exception_group``.
_ExceptionGroup = get_exception_group()

# Either flavour of test client; the ``test_client_cls`` fixture hands tests one of these classes.
AnyTestClient = Union[TestClient, AsyncTestClient]


async def mock_asgi_app(scope: "Scope", receive: "Receive", send: "Send") -> None:
    """Do-nothing ASGI callable: ignores its arguments and never sends a message.

    In this module it is only used to construct a client in a test that never issues a request.
    """


@pytest.fixture(params=[AsyncTestClient, TestClient])
def test_client_cls(request: FixtureRequest) -> Type[AnyTestClient]:
    """Provide each test-client class in turn so dependent tests run against both the async and sync client."""
    client_cls = request.param
    return cast(Type[AnyTestClient], client_cls)


@pytest.mark.parametrize(
    "anyio_backend",
    [
        pytest.param("asyncio"),
        pytest.param("trio", marks=pytest.mark.xfail(reason="Known issue with trio backend", strict=False)),
    ],
)
@pytest.mark.parametrize("with_domain", [False, True])
async def test_test_client_set_session_data(
    with_domain: bool,
    anyio_backend: str,
    session_backend_config: "BaseBackendConfig",
    test_client_backend: "AnyIOBackend",
    test_client_cls: Type[AnyTestClient],
) -> None:
    """Session data injected through ``set_session_data`` is visible to a route handler."""
    expected = {"foo": "bar"}

    if with_domain:
        # Also cover a session config that carries an explicit ``domain``.
        session_backend_config.domain = "testserver.local"

    @get(path="/test")
    async def get_session_data(request: Request) -> Dict[str, Any]:
        return request.session

    app = Litestar(route_handlers=[get_session_data], middleware=[session_backend_config.middleware])
    client_instance = test_client_cls(app=app, session_config=session_backend_config, backend=test_client_backend)  # pyright: ignore

    async with maybe_async_cm(client_instance) as client:
        await maybe_async(client.set_session_data(expected))  # type: ignore[attr-defined]
        response = await maybe_async(client.get("/test"))  # type: ignore[attr-defined]
        assert response.json() == expected


@pytest.mark.parametrize(
    "anyio_backend",
    [
        pytest.param("asyncio"),
        pytest.param("trio", marks=pytest.mark.xfail(reason="Known issue with trio backend", strict=False)),
    ],
)
@pytest.mark.parametrize("with_domain", [True, False])
async def test_test_client_get_session_data(
    with_domain: bool,
    anyio_backend: str,
    session_backend_config: "BaseBackendConfig",
    test_client_backend: "AnyIOBackend",
    store: Store,
    test_client_cls: Type[AnyTestClient],
) -> None:
    """Session data written by a route handler can be read back via ``get_session_data``."""
    expected = {"foo": "bar"}

    if with_domain:
        # Also cover a session config that carries an explicit ``domain``.
        session_backend_config.domain = "testserver.local"

    @post(path="/test")
    async def set_session_data(request: Request) -> None:
        # Populate the session from inside a handler.
        request.session.update(expected)

    app = Litestar(
        route_handlers=[set_session_data],
        middleware=[session_backend_config.middleware],
        stores={"session": store},
    )
    client_instance = test_client_cls(app=app, session_config=session_backend_config, backend=test_client_backend)  # pyright: ignore

    async with maybe_async_cm(client_instance) as client:
        await maybe_async(client.post("/test"))  # type: ignore[attr-defined]
        stored = await maybe_async(client.get_session_data())  # type: ignore[attr-defined]
        assert stored == expected


async def test_use_testclient_in_endpoint(
    test_client_backend: "AnyIOBackend", test_client_cls: Type[AnyTestClient]
) -> None:
    """A route handler can itself use a test client to call another app (adapted from Starlette's tests)."""

    @get("/")
    def mock_service_endpoint() -> dict:
        return {"mock": "example"}

    mock_service = Litestar(route_handlers=[mock_service_endpoint])

    @get("/")
    async def homepage() -> Any:
        # Call ``mock_service`` through a nested client and relay its JSON body.
        inner_client = test_client_cls(mock_service, backend=test_client_backend)
        inner_response = await maybe_async(inner_client.get("/"))
        return inner_response.json()  # type: ignore[union-attr]

    # NOTE(review): the outer client is built without ``backend=``, so it uses the client's default.
    outer_client = test_client_cls(Litestar(route_handlers=[homepage]))
    outer_response = await maybe_async(outer_client.get("/"))
    assert outer_response.json() == {"mock": "example"}  # type: ignore[union-attr]


def raise_error(app: Litestar) -> NoReturn:
    """Lifecycle hook that always fails; the ``app`` argument is ignored."""
    error = RuntimeError()
    raise error


async def test_error_handling_on_startup(
    test_client_backend: "AnyIOBackend", test_client_cls: Type[AnyTestClient]
) -> None:
    """A failing ``on_startup`` hook surfaces as an exception group when the client is entered."""
    with pytest.raises(_ExceptionGroup):
        failing_app = Litestar(on_startup=[raise_error])
        client_cm = maybe_async_cm(test_client_cls(failing_app, backend=test_client_backend))  # pyright: ignore
        async with client_cm:
            pass


async def test_error_handling_on_shutdown(
    test_client_backend: "AnyIOBackend", test_client_cls: Type[AnyTestClient]
) -> None:
    """A failing ``on_shutdown`` hook surfaces as a plain ``RuntimeError`` when the client exits."""
    with pytest.raises(RuntimeError):
        failing_app = Litestar(on_shutdown=[raise_error])
        client_cm = maybe_async_cm(test_client_cls(failing_app, backend=test_client_backend))  # pyright: ignore
        async with client_cm:
            pass


@pytest.mark.parametrize("method", ["get", "post", "put", "patch", "delete", "head", "options"])
async def test_client_interface(
    method: str, test_client_backend: "AnyIOBackend", test_client_cls: Type[AnyTestClient]
) -> None:
    """Every HTTP-verb helper on the client works against a bare ASGI app that always answers 200."""

    async def asgi_app(scope: "Scope", receive: "Receive", send: "Send") -> None:
        # Reply with an empty text/plain 200 regardless of the incoming scope.
        start_event: HTTPResponseStartEvent = {
            "type": "http.response.start",
            "status": HTTP_200_OK,
            "headers": [(b"content-type", b"text/plain")],
        }
        body_event: HTTPResponseBodyEvent = {"type": "http.response.body", "body": b"", "more_body": False}
        await send(start_event)
        await send(body_event)

    client = test_client_cls(asgi_app, backend=test_client_backend)
    # The parametrized verb names match the client's method names one-to-one.
    send_request = getattr(client, method)
    response = await maybe_async(send_request("/"))
    assert response.status_code == HTTP_200_OK  # type: ignore[union-attr]


def test_warns_problematic_domain(test_client_cls: Type[AnyTestClient]) -> None:
    """Building a client with base URL ``http://testserver`` emits a ``UserWarning``."""
    problematic_base_url = "http://testserver"
    with pytest.warns(UserWarning):
        test_client_cls(app=mock_asgi_app, base_url=problematic_base_url)


@pytest.mark.parametrize("method", ["get", "post", "put", "patch", "delete", "head", "options"])
async def test_client_interface_context_manager(
    method: str, test_client_backend: "AnyIOBackend", test_client_cls: Type[AnyTestClient]
) -> None:
    """Each HTTP verb, issued through a client used as a context manager, yields the expected status."""

    class MockController(Controller):
        @get("/")
        def mock_service_endpoint_get(self) -> dict:
            return {"mock": "example"}

        @post("/")
        def mock_service_endpoint_post(self) -> dict:
            return {"mock": "example"}

        @put("/")
        def mock_service_endpoint_put(self) -> None: ...

        @patch("/")
        def mock_service_endpoint_patch(self) -> None: ...

        @delete("/")
        def mock_service_endpoint_delete(self) -> None: ...

        @head("/")
        def mock_service_endpoint_head(self) -> None: ...

    # Status code each verb is expected to produce against ``MockController``.
    expected_status = {
        "get": HTTP_200_OK,
        "post": HTTP_201_CREATED,
        "put": HTTP_200_OK,
        "patch": HTTP_200_OK,
        "delete": HTTP_204_NO_CONTENT,
        "head": HTTP_200_OK,
        "options": HTTP_204_NO_CONTENT,
    }

    mock_service = Litestar(route_handlers=[MockController])
    async with maybe_async_cm(test_client_cls(mock_service, backend=test_client_backend)) as client:  # pyright: ignore
        response = await maybe_async(getattr(client, method)("/"))
        assert response.status_code == expected_status[method]  # pyright: ignore


@pytest.mark.parametrize("block,timeout", [(False, None), (False, 0.001), (True, 0.001)])
@pytest.mark.parametrize(
    "receive_method",
    [
        WebSocketTestSession.receive,
        WebSocketTestSession.receive_json,
        WebSocketTestSession.receive_text,
        WebSocketTestSession.receive_bytes,
    ],
)
def test_websocket_test_session_block_timeout(
    receive_method: Callable[..., Any], block: bool, timeout: Optional[float], anyio_backend: "AnyIOBackend"
) -> None:
    """Every receive variant raises ``queue.Empty`` when the server never sends a message."""

    @websocket()
    async def handler(socket: WebSocket) -> None:
        # Accept the connection but send nothing.
        await socket.accept()

    with pytest.raises(Empty):
        with create_test_client(handler, backend=anyio_backend) as client:
            with client.websocket_connect("/") as ws:
                receive_method(ws, timeout=timeout, block=block)


def test_websocket_accept_timeout(anyio_backend: "AnyIOBackend") -> None:
    """Connecting to a handler that never accepts raises ``queue.Empty`` (client timeout is 0.1)."""

    @websocket()
    async def handler(socket: WebSocket) -> None:
        # Deliberately return without calling ``accept()``.
        pass

    with create_test_client(handler, backend=anyio_backend, timeout=0.1) as client:
        with pytest.raises(Empty):
            with client.websocket_connect("/"):
                pass


@pytest.mark.parametrize("block,timeout", [(False, None), (False, 0.001), (True, 0.001)])
@pytest.mark.parametrize(
    "receive_method",
    [
        WebSocketTestSession.receive,
        WebSocketTestSession.receive_json,
        WebSocketTestSession.receive_text,
        WebSocketTestSession.receive_bytes,
    ],
)
async def test_websocket_test_session_block_timeout_async(
    receive_method: Callable[..., Any], block: bool, timeout: Optional[float], anyio_backend: "AnyIOBackend"
) -> None:
    """Async-client counterpart: every receive variant raises ``queue.Empty`` on a silent server."""

    @websocket()
    async def handler(socket: WebSocket) -> None:
        # Accept the connection but send nothing.
        await socket.accept()

    with pytest.raises(Empty):
        async with create_async_test_client(handler, backend=anyio_backend) as client:
            session = await client.websocket_connect("/")
            with session as ws:
                receive_method(ws, timeout=timeout, block=block)


async def test_websocket_accept_timeout_async(anyio_backend: "AnyIOBackend") -> None:
    """Async-client counterpart: a handler that never accepts leads to ``queue.Empty`` (timeout 0.1)."""

    @websocket()
    async def handler(socket: WebSocket) -> None:
        # Deliberately return without calling ``accept()``.
        pass

    async with create_async_test_client(handler, backend=anyio_backend, timeout=0.1) as client:
        with pytest.raises(Empty):
            session = await client.websocket_connect("/")
            with session:
                pass


async def test_websocket_connect_async(anyio_backend: "AnyIOBackend") -> None:
    """A JSON message sent over an async-client WebSocket session is echoed back unchanged."""

    @websocket()
    async def handler(socket: WebSocket) -> None:
        # Echo server: accept, read one JSON message, send it back, then close.
        await socket.accept()
        data = await socket.receive_json()
        await socket.send_json(data)
        await socket.close()

    async with create_async_test_client(handler, backend=anyio_backend, timeout=0.1) as client:
        # ``subprotocols`` takes a sequence of protocol names. A bare ``"wamp"`` is itself a
        # sequence of 1-char strings ("w", "a", "m", "p"), so wrap the single name in a list.
        with await client.websocket_connect("/", subprotocols=["wamp"]) as ws:
            ws.send_json({"data": "123"})
            data = ws.receive_json()
            assert data == {"data": "123"}