1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265
|
import json
import re
from typing import Optional, Union, Any
from re import Pattern
import httpx
from httpx import QueryParams
from pytest_httpx._httpx_internals import _proxy_url
from pytest_httpx._options import _HTTPXMockOptions
def _url_match(
    url_to_match: Union[Pattern[str], httpx.URL],
    received: httpx.URL,
    params: Optional[dict[str, Union[str | list[str]]]],
) -> bool:
    """Tell whether ``received`` satisfies ``url_to_match``.

    A compiled regex only needs to match at the start of the full received URL
    (``re.Pattern.match`` semantics).

    An ``httpx.URL`` must satisfy two conditions:
    - it equals the received URL once the query strings are stripped from both;
    - its query parameters equal the received ones when both are compared as dicts,
      so parameter order is irrelevant.

    When ``params`` is None, the expected parameters are read from ``url_to_match`` itself.
    """
    if isinstance(url_to_match, re.Pattern):
        return url_to_match.match(str(received)) is not None

    actual_query = to_params_dict(received.params)
    expected_query = to_params_dict(url_to_match.params) if params is None else params
    if actual_query != expected_query:
        return False

    # Everything except the query string must be identical.
    return url_to_match.copy_with(query=None) == received.copy_with(query=None)
def to_params_dict(params: QueryParams) -> dict[str, Union[str | list[str]]]:
    """Flatten query parameters into a plain dict.

    A parameter that appears once maps to its single string value. A parameter
    that is repeated maps to the list of all its values, in order.
    """

    def collapse(values: list[str]) -> Union[str | list[str]]:
        # Unwrap a lone value; keep the whole list for repeated parameters.
        return values if len(values) > 1 else values[0]

    return {key: collapse(params.get_list(key)) for key in params}
class _RequestMatcher:
    """Decide whether an outgoing request matches a registered expectation.

    Every criterion (``url``, ``method``, ``proxy_url`` and the ``match_*`` ones)
    left as ``None`` accepts any request. :meth:`match` requires every provided
    criterion to hold.

    ``nb_calls`` starts at 0 and is read by :meth:`should_have_matched` and
    :meth:`__str__`. It is presumably incremented by the code that serves matched
    requests, which is not visible in this module.
    """

    def __init__(
        self,
        options: _HTTPXMockOptions,
        url: Optional[Union[str, Pattern[str], httpx.URL]] = None,
        method: Optional[str] = None,
        proxy_url: Optional[Union[str, Pattern[str], httpx.URL]] = None,
        match_headers: Optional[dict[str, Any]] = None,
        match_content: Optional[bytes] = None,
        match_json: Optional[Any] = None,
        match_data: Optional[dict[str, Any]] = None,
        match_files: Optional[Any] = None,
        match_extensions: Optional[dict[str, Any]] = None,
        match_params: Optional[dict[str, Union[str | list[str]]]] = None,
        is_optional: Optional[bool] = None,
        is_reusable: Optional[bool] = None,
    ):
        """Store the matching criteria and validate that they are consistent.

        :param options: mock-wide options that provide the defaults for
            ``is_optional`` and ``is_reusable``.
        :param url: exact URL (a ``str`` is converted to ``httpx.URL``) or a regex.
        :param method: HTTP method, upper-cased before being compared to
            ``request.method``.
        :param proxy_url: exact URL (a ``str`` is converted) or a regex, matched
            against the proxy of the real transport.
        :param match_headers: header name -> expected value. Names are compared
            case-insensitively.
        :param match_content: exact raw request body.
        :param match_json: expected JSON-decoded request body.
        :param match_data: multipart form fields. Only valid together with
            ``match_files``.
        :param match_files: multipart files.
        :param match_extensions: expected ``request.extensions`` entries.
        :param match_params: expected query parameters. Requires a non-regex
            ``url`` that has no query string of its own.
        :param is_optional: when true, :meth:`should_have_matched` never reports
            this matcher. Defaults to
            ``not options.assert_all_responses_were_requested``.
        :param is_reusable: defaults to
            ``options.can_send_already_matched_responses``.
        :raises ValueError: if the criteria conflict or are incomplete.
        """
        self._options = options
        self.nb_calls = 0
        self.url = httpx.URL(url) if url and isinstance(url, str) else url
        self.method = method.upper() if method else method
        self.headers = match_headers
        self.content = match_content
        self.json = match_json
        self.data = match_data
        self.files = match_files
        self.params = match_params
        self.proxy_url = (
            httpx.URL(proxy_url)
            if proxy_url and isinstance(proxy_url, str)
            else proxy_url
        )
        self.extensions = match_extensions
        self.is_optional = (
            not options.assert_all_responses_were_requested
            if is_optional is None
            else is_optional
        )
        self.is_reusable = (
            options.can_send_already_matched_responses
            if is_reusable is None
            else is_reusable
        )
        if self._is_matching_body_more_than_one_way():
            raise ValueError(
                "Only one way of matching against the body can be provided. "
                "If you want to match against the JSON decoded representation, use match_json. "
                "If you want to match against the multipart representation, use match_files (and match_data). "
                "Otherwise, use match_content."
            )
        if self.params and not self.url:
            raise ValueError("URL must be provided when match_params is used.")
        if self.params and isinstance(self.url, re.Pattern):
            raise ValueError(
                "match_params cannot be used in addition to regex URL. Request this feature via https://github.com/Colin-b/pytest_httpx/issues/new?title=Regex%20URL%20should%20allow%20match_params&body=Hi,%20I%20need%20a%20regex%20to%20match%20the%20non%20query%20part%20of%20the%20URL%20only"
            )
        if self._is_matching_params_more_than_one_way():
            raise ValueError(
                "Provided URL must not contain any query parameter when match_params is used."
            )
        if self.data and not self.files:
            raise ValueError(
                "match_data is meant to be used for multipart matching (in conjunction with match_files)."
                "Use match_content to match url encoded data."
            )

    def _body_matching_ways(self) -> int:
        """Count the body criteria (content, json, files) that are set."""
        return sum(
            criterion is not None
            for criterion in (self.content, self.json, self.files)
        )

    def expect_body(self) -> bool:
        """Return True if exactly one body criterion (content, json or files) is set."""
        return self._body_matching_ways() == 1

    def _is_matching_body_more_than_one_way(self) -> bool:
        """Return True if more than one mutually exclusive body criterion is set."""
        return self._body_matching_ways() > 1

    def _is_matching_params_more_than_one_way(self) -> bool:
        """Return True if query parameters come from both ``match_params`` and the URL."""
        url_has_params = (
            bool(self.url.params)
            if (self.url and isinstance(self.url, httpx.URL))
            else False
        )
        matching_ways = [
            self.params is not None,
            url_has_params,
        ]
        return sum(matching_ways) > 1

    def match(
        self,
        real_transport: Union[httpx.HTTPTransport, httpx.AsyncHTTPTransport],
        request: httpx.Request,
    ) -> bool:
        """Return True if ``request``, sent through ``real_transport``, meets every criterion.

        The checks short-circuit in the order listed.
        """
        return (
            self._url_match(request)
            and self._method_match(request)
            and self._headers_match(request)
            and self._content_match(request)
            and self._proxy_match(real_transport)
            and self._extensions_match(request)
        )

    def _url_match(self, request: httpx.Request) -> bool:
        if not self.url:
            return True
        # The bare name below resolves to the module-level _url_match function,
        # not to this method: class scope does not enclose method bodies.
        return _url_match(self.url, request.url, self.params)

    def _method_match(self, request: httpx.Request) -> bool:
        if not self.method:
            return True
        return request.method == self.method

    def _headers_match(self, request: httpx.Request) -> bool:
        """Check every expected header against the request headers.

        Repeated request headers are combined into one ", "-separated value.
        HTTP field names are case-insensitive (RFC 9110 §5.1), so names are
        lower-cased on both sides. Values are compared exactly.
        """
        if not self.headers:
            return True

        encoding = request.headers.encoding
        request_headers: dict[bytes, bytes] = {}
        # Can be cleaned based on the outcome of https://github.com/encode/httpx/discussions/2841
        for raw_name, raw_value in request.headers.raw:
            name = raw_name.lower()
            if name in request_headers:
                request_headers[name] += b", " + raw_value
            else:
                request_headers[name] = raw_value

        return all(
            request_headers.get(header_name.lower().encode(encoding))
            == header_value.encode(encoding)
            for header_name, header_value in self.headers.items()
        )

    def _content_match(self, request: httpx.Request) -> bool:
        """Compare the request body with whichever body criterion is set."""
        if self.content is not None:
            return request.content == self.content

        if self.json is not None:
            try:
                # httpx._content.encode_json hard codes utf-8 encoding.
                return json.loads(request.content.decode("utf-8")) == self.json
            except (json.decoder.JSONDecodeError, UnicodeDecodeError):
                # A body that is not valid UTF-8, or not valid JSON, does not
                # match. It must not crash the matching process.
                return False

        if self.files:
            if not (
                boundary_matched := re.match(b"^--([0-9a-f]*)\r\n", request.content)
            ):
                return False
            # Re-use the request's boundary so the rebuilt body can be compared byte for byte.
            boundary = boundary_matched.group(1)
            # Prevent internal httpx changes from impacting users not matching on files
            from httpx._multipart import MultipartStream

            multipart_content = b"".join(
                MultipartStream(self.data or {}, self.files, boundary)
            )
            return request.content == multipart_content

        return True

    def _proxy_match(
        self, real_transport: Union[httpx.HTTPTransport, httpx.AsyncHTTPTransport]
    ) -> bool:
        """Return True if no proxy criterion is set, or if the transport's proxy URL matches it."""
        if not self.proxy_url:
            return True

        # _proxy_url presumably returns a falsy value when the transport uses no proxy.
        if real_proxy_url := _proxy_url(real_transport):
            return _url_match(self.proxy_url, real_proxy_url, params=None)

        return False

    def _extensions_match(self, request: httpx.Request) -> bool:
        if not self.extensions:
            return True

        return all(
            request.extensions.get(extension_name) == extension_value
            for extension_name, extension_value in self.extensions.items()
        )

    def should_have_matched(self) -> bool:
        """Return True if the matcher did not serve its purpose."""
        return not self.is_optional and not self.nb_calls

    def __str__(self) -> str:
        if self.is_reusable:
            matcher_description = f"Match {self.method or 'every'} request"
        else:
            matcher_description = "Already matched" if self.nb_calls else "Match"
            matcher_description += f" {self.method or 'any'} request"

        if self.url:
            matcher_description += f" on {self.url}"
        if extra_description := self._extra_description():
            matcher_description += f" with {extra_description}"
        return matcher_description

    def _extra_description(self) -> str:
        """Describe every criterion other than method and URL, joined with " and "."""
        extra_description = []

        if self.params:
            extra_description.append(f"{self.params} query parameters")
        if self.headers:
            extra_description.append(f"{self.headers} headers")
        if self.content is not None:
            extra_description.append(f"{self.content} body")
        if self.json is not None:
            extra_description.append(f"{self.json} json body")
        if self.data is not None:
            extra_description.append(f"{self.data} multipart data")
        if self.files is not None:
            extra_description.append(f"{self.files} files")
        if self.proxy_url:
            extra_description.append(f"{self.proxy_url} proxy URL")
        if self.extensions:
            extra_description.append(f"{self.extensions} extensions")

        return " and ".join(extra_description)
|