File: test_query_params.py

package info (click to toggle)
litestar 2.21.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 12,568 kB
  • sloc: python: 70,588; makefile: 254; javascript: 104; sh: 60
file content (234 lines) | stat: -rw-r--r-- 7,084 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
from datetime import datetime
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Tuple,
    Union,
)
from urllib.parse import urlencode

import pytest
from typing_extensions import Annotated

from litestar import MediaType, Request, get, post
from litestar.datastructures import MultiDict
from litestar.di import Provide
from litestar.params import Parameter
from litestar.status_codes import HTTP_200_OK, HTTP_400_BAD_REQUEST
from litestar.testing import create_test_client


@pytest.mark.parametrize(
    "params_dict,should_raise",
    [
        # Accepted: two brands.
        ({"page": 1, "pageSize": 1, "brands": ["Nike", "Adidas"]}, False),
        # Accepted: three brands, the declared max_items.
        ({"page": 1, "pageSize": 1, "brands": ["Nike", "Adidas", "Rebok"]}, False),
        # Rejected: no "brands" key at all.
        ({"page": 1, "pageSize": 1}, True),
        # Rejected: four brands, more than max_items=3.
        ({"page": 1, "pageSize": 1, "brands": ["Nike", "Adidas", "Rebok", "Polgat"]}, True),
        # Rejected: pageSize of 101 is above le=100.
        ({"page": 1, "pageSize": 101, "brands": ["Nike", "Adidas", "Rebok"]}, True),
        # Rejected: an empty list adds no "brands" pair to the urlencoded query string.
        ({"page": 1, "pageSize": 1, "brands": []}, True),
        # Accepted: optional from_date given as a timestamp.
        (
            {
                "page": 1,
                "pageSize": 1,
                "brands": ["Nike", "Adidas", "Rebok"],
                "from_date": datetime.now().timestamp(),
            },
            False,
        ),
        # Accepted: both optional dates given as timestamps.
        (
            {
                "page": 1,
                "pageSize": 1,
                "brands": ["Nike", "Adidas", "Rebok"],
                "from_date": datetime.now().timestamp(),
                "to_date": datetime.now().timestamp(),
            },
            False,
        ),
        # Accepted: a single brand, the declared min_items, plus both dates.
        (
            {
                "page": 1,
                "pageSize": 1,
                "brands": ["Nike"],
                "from_date": datetime.now().timestamp(),
                "to_date": datetime.now().timestamp(),
            },
            False,
        ),
    ],
)
def test_query_params(params_dict: dict, should_raise: bool) -> None:
    """Check that constrained query parameters are accepted or rejected.

    ``params_dict`` is urlencoded (sequences expanded into repeated keys) and
    sent to a GET handler. The response must be 400 when ``should_raise`` is
    true and 200 otherwise.
    """
    route = "/test"

    @get(path=route)
    def test_method(
        page: int,
        page_size: int = Parameter(query="pageSize", gt=0, le=100),
        brands: List[str] = Parameter(min_items=1, max_items=3),
        from_date: Optional[datetime] = None,
        to_date: Optional[datetime] = None,
    ) -> None:
        # Required values must arrive truthy.
        assert page
        assert page_size
        assert brands
        # Optional dates are either absent (None) or a truthy value.
        assert from_date is None or from_date
        assert to_date is None or to_date

    expected_status = HTTP_400_BAD_REQUEST if should_raise else HTTP_200_OK
    query_string = urlencode(params_dict, doseq=True)

    with create_test_client(test_method) as client:
        response = client.get(f"{route}?{query_string}")
        assert response.status_code == expected_status, response.json()


@pytest.mark.parametrize(
    "expected_type,provided_value,default,expected_response_code",
    [
        (Union[int, List[int]], [1, 2, 3], None, HTTP_200_OK),
        (Union[int, List[int]], [1], None, HTTP_200_OK),
    ],
)
def test_query_param_arrays(expected_type: Any, provided_value: Any, default: Any, expected_response_code: int) -> None:
    """Send a repeated ``param`` query key to a handler annotated with ``expected_type``.

    Two handlers are defined, one with a default for ``param`` and one without.
    The one without is used only when ``default`` is ``Ellipsis``.
    """
    route = "/test"

    @get(route)
    def test_method_with_default(param: Any = default) -> None:
        return None

    @get(route)
    def test_method_without_default(param: Any) -> None:
        return None

    # NOTE(review): every parametrized case above passes default=None, so the
    # Ellipsis branch is not exercised by the cases visible here.
    if default is ...:
        test_method = test_method_without_default
    else:
        test_method = test_method_with_default

    # Set the type annotation of 'param' in a way mypy can deal with
    test_method.fn.__annotations__["param"] = expected_type

    query_string = urlencode({"param": provided_value}, doseq=True)

    with create_test_client(test_method) as client:
        response = client.get(f"{route}?{query_string}")
        assert response.status_code == expected_response_code


def test_query_kwarg() -> None:
    """Check the ``query`` kwarg alongside individually declared list parameters.

    The request carries ``a`` twice and ``b`` once. The handler asserts that
    ``query`` is a ``MultiDict`` holding every value and that both ``a`` and
    ``b`` arrive as lists.
    """
    route = "/test"

    @get(route)
    def handler(a: List[str], b: List[str], query: MultiDict) -> None:
        # The raw query mapping keeps every value for every key.
        assert isinstance(query, MultiDict)
        assert {k: query.getall(k) for k in query} == {"a": ["foo", "bar"], "b": ["qux"]}
        # Both parameters are lists, including the single-valued "b".
        assert isinstance(a, list)
        assert isinstance(b, list)
        assert a == ["foo", "bar"]
        assert b == ["qux"]

    query_string = urlencode({"a": ["foo", "bar"], "b": "qux"}, doseq=True)

    with create_test_client(handler) as client:
        response = client.get(f"{route}?{query_string}")
        assert response.status_code == HTTP_200_OK, response.json()


@pytest.mark.parametrize(
    "values",
    (
        (("first", "x@test.com"), ("second", "aaa")),
        (("first", "&@A.ac"), ("second", "aaa")),
        (("first", "a@A.ac&"), ("second", "aaa")),
        (("first", "a@A&.ac"), ("second", "aaa")),
    ),
)
def test_query_parsing_of_escaped_values(values: Tuple[Tuple[str, str], Tuple[str, str]]) -> None:
    """Check that values containing ``@`` and ``&`` survive the query-string round trip.

    The handler records the injected ``first``/``second`` arguments and the
    request's ``query_params``. Both must equal the values that were sent.
    """
    # https://github.com/litestar-org/litestar/issues/915

    captured: Dict[str, Any] = {}

    @get(path="/handler")
    def handler(request: Request, first: str, second: str) -> None:
        captured["first"] = first
        captured["second"] = second
        captured["query"] = request.query_params

    sent = dict(values)
    keys = ("first", "second")

    with create_test_client(handler) as client:
        response = client.get("/handler", params=sent)
        assert response.status_code == HTTP_200_OK
        # Values injected as handler arguments.
        for key in keys:
            assert captured[key] == sent[key]
        # Values read from the request's query params.
        for key in keys:
            assert captured["query"].get(key) == sent[key]


def test_query_param_dependency_with_alias() -> None:
    """Check that a dependency can declare an aliased query parameter.

    The dependency reads ``page_size`` from the ``pageSize`` query key, and the
    handler returns the provided value as text.
    """

    async def qp_dependency(page_size: int = Parameter(query="pageSize", gt=0, le=100)) -> int:
        return page_size

    @get("/", media_type=MediaType.TEXT)
    def handler(page_size_dep: int) -> str:
        return str(page_size_dep)

    dependencies = {"page_size_dep": Provide(qp_dependency)}

    with create_test_client(handler, dependencies=dependencies) as client:
        response = client.get("/?pageSize=1")
        assert response.status_code == HTTP_200_OK, response.text
        assert response.text == "1"


def test_query_params_with_post() -> None:
    """Check that a POST with an empty JSON body and no query string gets a 400 response.

    The client is created with ``raise_server_exceptions=True``, presumably so
    that an unhandled server-side error fails the test rather than being
    returned as a response.
    """
    # https://github.com/litestar-org/litestar/issues/3734

    @post()
    async def handler(data: str, secret: Annotated[str, Parameter(query="x-secret")]) -> None:
        return None

    with create_test_client([handler], raise_server_exceptions=True) as client:
        response = client.post("/", json={})
        assert response.status_code == HTTP_400_BAD_REQUEST