File: _request_matcher.py

package info (click to toggle)
pytest-httpx 0.36.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 436 kB
  • sloc: python: 4,734; makefile: 3
file content (265 lines) | stat: -rw-r--r-- 10,022 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
import json
import re
from typing import Optional, Union, Any
from re import Pattern

import httpx
from httpx import QueryParams

from pytest_httpx._httpx_internals import _proxy_url
from pytest_httpx._options import _HTTPXMockOptions


def _url_match(
    url_to_match: Union[Pattern[str], httpx.URL],
    received: httpx.URL,
    params: Optional[dict[str, Union[str | list[str]]]],
) -> bool:
    """
    Tell whether a received URL satisfies an expected URL.

    :param url_to_match: Expected URL. A compiled regex is applied with ``Pattern.match``
    (anchored at the start) to the string form of the whole received URL.
    :param received: URL of the request under examination.
    :param params: Expected query parameters. When None, the query parameters embedded in
    url_to_match are expected instead. Not used when url_to_match is a regex.
    :return: True when both the query parameters and the rest of the URL are equal.
    """
    if isinstance(url_to_match, re.Pattern):
        return url_to_match.match(str(received)) is not None

    # Query parameters are compared as dicts, so their order in the URL is irrelevant.
    actual_params = to_params_dict(received.params)
    expected_params = (
        to_params_dict(url_to_match.params) if params is None else params
    )

    # The remainder of the URL is compared with any query string removed on both sides.
    stripped_received = received.copy_with(query=None)
    stripped_expected = url_to_match.copy_with(query=None)

    params_equal = actual_params == expected_params
    return params_equal and (stripped_expected == stripped_received)


def to_params_dict(params: QueryParams) -> dict[str, Union[str | list[str]]]:
    """
    Turn query parameters into a plain dict suitable for order-insensitive comparison.

    A parameter that appears once maps to its value as a string. A parameter that appears
    several times maps to the list of all its values, as returned by ``get_list``.
    """
    converted = {}
    for name in params:
        all_values = params.get_list(name)
        if len(all_values) > 1:
            converted[name] = all_values
        else:
            converted[name] = all_values[0]
    return converted


class _RequestMatcher:
    """
    Describe what a request (and the transport sending it) must look like to be matched.

    Criteria left to None are not checked. Conflicting combinations of criteria are
    rejected with a ValueError when the matcher is built.
    """

    def __init__(
        self,
        options: _HTTPXMockOptions,
        url: Optional[Union[str, Pattern[str], httpx.URL]] = None,
        method: Optional[str] = None,
        proxy_url: Optional[Union[str, Pattern[str], httpx.URL]] = None,
        match_headers: Optional[dict[str, Any]] = None,
        match_content: Optional[bytes] = None,
        match_json: Optional[Any] = None,
        match_data: Optional[dict[str, Any]] = None,
        match_files: Optional[Any] = None,
        match_extensions: Optional[dict[str, Any]] = None,
        match_params: Optional[dict[str, Union[str | list[str]]]] = None,
        is_optional: Optional[bool] = None,
        is_reusable: Optional[bool] = None,
    ):
        """
        :param options: Mock options providing the defaults for is_optional and is_reusable.
        :param url: Expected URL. A non-empty str is converted to httpx.URL; a compiled
        regex is matched against the string form of the whole received URL.
        :param method: Expected HTTP method, stored upper-cased.
        :param proxy_url: Expected proxy URL (same conversion rules as url).
        :param match_headers: Expected header values, keyed by header name.
        :param match_content: Expected raw request body.
        :param match_json: Expected JSON-decoded request body.
        :param match_data: Expected multipart form fields (only valid with match_files).
        :param match_files: Expected multipart files.
        :param match_extensions: Expected request extensions.
        :param match_params: Expected query parameters. Requires a non-regex url that
        carries no query parameters itself.
        :param is_optional: Whether should_have_matched may ignore this matcher when it was
        never used. Defaults to ``not options.assert_all_responses_were_requested``.
        :param is_reusable: Defaults to ``options.can_send_already_matched_responses``.
        :raises ValueError: If criteria are combined in an unsupported way.
        """
        self._options = options
        # Initialised here; not incremented within this class (presumably by the caller).
        self.nb_calls = 0
        self.url = httpx.URL(url) if url and isinstance(url, str) else url
        self.method = method.upper() if method else method
        self.headers = match_headers
        self.content = match_content
        self.json = match_json
        self.data = match_data
        self.files = match_files
        self.params = match_params
        self.proxy_url = (
            httpx.URL(proxy_url)
            if proxy_url and isinstance(proxy_url, str)
            else proxy_url
        )
        self.extensions = match_extensions
        self.is_optional = (
            not options.assert_all_responses_were_requested
            if is_optional is None
            else is_optional
        )
        self.is_reusable = (
            options.can_send_already_matched_responses
            if is_reusable is None
            else is_reusable
        )
        if self._is_matching_body_more_than_one_way():
            raise ValueError(
                "Only one way of matching against the body can be provided. "
                "If you want to match against the JSON decoded representation, use match_json. "
                "If you want to match against the multipart representation, use match_files (and match_data). "
                "Otherwise, use match_content."
            )
        if self.params and not self.url:
            raise ValueError("URL must be provided when match_params is used.")
        if self.params and isinstance(self.url, re.Pattern):
            raise ValueError(
                "match_params cannot be used in addition to regex URL. Request this feature via https://github.com/Colin-b/pytest_httpx/issues/new?title=Regex%20URL%20should%20allow%20match_params&body=Hi,%20I%20need%20a%20regex%20to%20match%20the%20non%20query%20part%20of%20the%20URL%20only"
            )
        if self._is_matching_params_more_than_one_way():
            raise ValueError(
                "Provided URL must not contain any query parameter when match_params is used."
            )
        if self.data and not self.files:
            raise ValueError(
                # Trailing space required: the two literals are concatenated implicitly.
                "match_data is meant to be used for multipart matching (in conjunction with match_files). "
                "Use match_content to match url encoded data."
            )

    def expect_body(self) -> bool:
        """Return True if exactly one body criterion (content, json or files) was provided."""
        matching_ways = [
            self.content is not None,
            self.json is not None,
            self.files is not None,
        ]
        return sum(matching_ways) == 1

    def _is_matching_body_more_than_one_way(self) -> bool:
        """Return True if more than one body criterion (content, json, files) was provided."""
        matching_ways = [
            self.content is not None,
            self.json is not None,
            self.files is not None,
        ]
        return sum(matching_ways) > 1

    def _is_matching_params_more_than_one_way(self) -> bool:
        """Return True if match_params was given AND the (non-regex) URL carries query parameters."""
        url_has_params = (
            bool(self.url.params)
            if (self.url and isinstance(self.url, httpx.URL))
            else False
        )
        matching_ways = [
            self.params is not None,
            url_has_params,
        ]
        return sum(matching_ways) > 1

    def match(
        self,
        real_transport: Union[httpx.HTTPTransport, httpx.AsyncHTTPTransport],
        request: httpx.Request,
    ) -> bool:
        """Return True if request, sent through real_transport, satisfies every criterion."""
        return (
            self._url_match(request)
            and self._method_match(request)
            and self._headers_match(request)
            and self._content_match(request)
            and self._proxy_match(real_transport)
            and self._extensions_match(request)
        )

    def _url_match(self, request: httpx.Request) -> bool:
        """Check the request URL (and match_params, if any); always True when no URL is expected."""
        if not self.url:
            return True

        return _url_match(self.url, request.url, self.params)

    def _method_match(self, request: httpx.Request) -> bool:
        """Check the request method; always True when no method is expected."""
        if not self.method:
            return True

        return request.method == self.method

    def _headers_match(self, request: httpx.Request) -> bool:
        """
        Check that every expected header has exactly the expected value.

        Values of repeated raw header names are joined with ", " before comparison. Names
        and values are compared as bytes encoded with the request headers' encoding, so
        this code performs no case folding of header names.
        """
        if not self.headers:
            return True

        encoding = request.headers.encoding
        request_headers = {}
        # Can be cleaned based on the outcome of https://github.com/encode/httpx/discussions/2841
        for raw_name, raw_value in request.headers.raw:
            if raw_name in request_headers:
                request_headers[raw_name] += b", " + raw_value
            else:
                request_headers[raw_name] = raw_value

        return all(
            request_headers.get(header_name.encode(encoding))
            == header_value.encode(encoding)
            for header_name, header_value in self.headers.items()
        )

    def _content_match(self, request: httpx.Request) -> bool:
        """Check the request body against whichever body criterion was provided (if any)."""
        if self.content is not None:
            return request.content == self.content

        if self.json is not None:
            try:
                # httpx._content.encode_json hard codes utf-8 encoding.
                return json.loads(request.content.decode("utf-8")) == self.json
            except (json.decoder.JSONDecodeError, UnicodeDecodeError):
                # A body that is not valid UTF-8 (e.g. binary content) or not valid JSON
                # cannot be the expected JSON: report a mismatch instead of raising.
                return False

        if self.files:
            if not (
                boundary_matched := re.match(b"^--([0-9a-f]*)\r\n", request.content)
            ):
                return False
            # Ensure we re-use the same boundary for comparison
            boundary = boundary_matched.group(1)
            # Prevent internal httpx changes from impacting users not matching on files
            from httpx._multipart import MultipartStream

            multipart_content = b"".join(
                MultipartStream(self.data or {}, self.files, boundary)
            )
            return request.content == multipart_content

        return True

    def _proxy_match(
        self, real_transport: Union[httpx.HTTPTransport, httpx.AsyncHTTPTransport]
    ) -> bool:
        """Check the transport's proxy URL; False if a proxy is expected but none is found."""
        if not self.proxy_url:
            return True

        if real_proxy_url := _proxy_url(real_transport):
            return _url_match(self.proxy_url, real_proxy_url, params=None)

        return False

    def _extensions_match(self, request: httpx.Request) -> bool:
        """Check that every expected extension is present in the request with an equal value."""
        if not self.extensions:
            return True

        return all(
            request.extensions.get(extension_name) == extension_value
            for extension_name, extension_value in self.extensions.items()
        )

    def should_have_matched(self) -> bool:
        """Return True if the matcher did not serve its purpose."""
        return not self.is_optional and not self.nb_calls

    def __str__(self) -> str:
        """Return a human-readable description of what this matcher expects."""
        if self.is_reusable:
            matcher_description = f"Match {self.method or 'every'} request"
        else:
            matcher_description = "Already matched" if self.nb_calls else "Match"
            matcher_description += f" {self.method or 'any'} request"
        if self.url:
            matcher_description += f" on {self.url}"
        if extra_description := self._extra_description():
            matcher_description += f" with {extra_description}"
        return matcher_description

    def _extra_description(self) -> str:
        """Describe the non-URL, non-method criteria, joined with " and " (empty if none)."""
        extra_description = []

        if self.params:
            extra_description.append(f"{self.params} query parameters")
        if self.headers:
            extra_description.append(f"{self.headers} headers")
        if self.content is not None:
            extra_description.append(f"{self.content} body")
        if self.json is not None:
            extra_description.append(f"{self.json} json body")
        if self.data is not None:
            extra_description.append(f"{self.data} multipart data")
        if self.files is not None:
            extra_description.append(f"{self.files} files")
        if self.proxy_url:
            extra_description.append(f"{self.proxy_url} proxy URL")
        if self.extensions:
            extra_description.append(f"{self.extensions} extensions")

        return " and ".join(extra_description)