File: httpx_stubs.py

package info (click to toggle)
vcr.py 4.2.1-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 1,008 kB
  • sloc: python: 5,695; makefile: 179; sh: 1
file content (170 lines) | stat: -rw-r--r-- 5,578 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import functools
import logging
from unittest.mock import patch, MagicMock

import httpx
from vcr.request import Request as VcrRequest
from vcr.errors import CannotOverwriteExistingCassetteException
import inspect

_httpx_signature = inspect.signature(httpx.Client.request)

# The redirect switch of httpx.Client.request is looked up under
# "follow_redirects" first and under "allow_redirects" otherwise, so the
# stubs work with whichever of the two names the installed httpx exposes.
# A KeyError is raised at import time if neither parameter exists.
HTTPX_REDIRECT_PARAM = _httpx_signature.parameters[
    "follow_redirects"
    if "follow_redirects" in _httpx_signature.parameters
    else "allow_redirects"
]


_logger = logging.getLogger(__name__)


def _transform_headers(httpx_response):
    """
    Group the raw response headers by name.

    A header such as "Set-Cookie" may occur several times, so every header
    name maps to the list of all its values, in the order they appear in
    ``httpx_response.headers.raw``. Names and values are decoded as UTF-8.
    """

    grouped = {}
    for raw_name, raw_value in httpx_response.headers.raw:
        name = raw_name.decode("utf-8")
        grouped.setdefault(name, []).append(raw_value.decode("utf-8"))
    return grouped


def _to_serialized_response(httpx_response):
    """
    Convert an httpx response into the plain dict stored in the cassette.

    The dict holds the status code, the HTTP version, the headers grouped
    by name and the body as text. The body is decoded as UTF-8 with the
    "ignore" error handler, so bytes that are not valid UTF-8 are dropped.
    """
    serialized = {}
    serialized["status_code"] = httpx_response.status_code
    serialized["http_version"] = httpx_response.http_version
    serialized["headers"] = _transform_headers(httpx_response)
    serialized["content"] = httpx_response.content.decode("utf-8", "ignore")
    return serialized


def _from_serialized_headers(headers):
    """
    Flatten a mapping of header name -> list of values into the list of
    (name, value) tuples that httpx accepts as headers.

    A name with several values yields one tuple per value; a name with an
    empty list yields nothing.
    """

    return [(name, value) for name, values in headers.items() for value in values]


@patch("httpx.Response.close", MagicMock())
@patch("httpx.Response.read", MagicMock())
def _from_serialized_response(request, serialized_response, history=None):
    """
    Rebuild an httpx.Response from a cassette entry.

    ``request`` becomes the response's request and ``history`` (an empty
    list when omitted or empty) its redirect history. The stored text body
    is encoded back to bytes with the default str.encode() codec.

    httpx.Response.read and httpx.Response.close are replaced by mocks for
    the duration of the call.
    """
    body = serialized_response.get("content").encode()
    header_pairs = _from_serialized_headers(serialized_response.get("headers"))
    rebuilt = httpx.Response(
        status_code=serialized_response.get("status_code"),
        request=request,
        headers=header_pairs,
        content=body,
        history=history if history else [],
    )
    # Presumably needed because read() is mocked while the response is
    # constructed, so the body is assigned to the private attribute directly.
    rebuilt._content = body
    return rebuilt


def _make_vcr_request(httpx_request, **kwargs):
    """
    Build the vcr.py request that represents ``httpx_request``.

    The request body is read in full. When it is valid UTF-8 it is passed
    on as text; otherwise (e.g. a binary upload) the raw bytes are passed
    on unchanged instead of raising ``UnicodeDecodeError``.

    ``**kwargs`` is accepted so callers can forward the keyword arguments
    of the send call; it is not used here.
    """
    raw_body = httpx_request.read()
    try:
        body = raw_body.decode("utf-8")
    except UnicodeDecodeError:
        # Binary payload: keep the bytes as they are.
        # NOTE(review): relies on VcrRequest accepting a bytes body -
        # confirm against vcr.request.Request.
        body = raw_body
    uri = str(httpx_request.url)
    headers = dict(httpx_request.headers)
    return VcrRequest(httpx_request.method, uri, body, headers)


def _shared_vcr_send(cassette, real_send, *args, **kwargs):
    """
    Logic shared by the sync and async send wrappers.

    ``args`` are the positional arguments of the wrapped send call:
    ``args[0]`` is used as the client and ``args[1]`` as the httpx request.

    Returns a ``(vcr_request, response)`` tuple. ``response`` is the replayed
    response when the cassette can play one for the request, and ``None``
    when the request has to go to the real server.

    Raises CannotOverwriteExistingCassetteException when nothing can be
    replayed, the cassette is write protected and ``filter_request`` returns
    a truthy value for the request.
    """
    real_request = args[1]
    vcr_request = _make_vcr_request(real_request, **kwargs)

    if not cassette.can_play_response_for(vcr_request):
        if cassette.write_protected and cassette.filter_request(vcr_request):
            raise CannotOverwriteExistingCassetteException(cassette=cassette, failed_request=vcr_request)

        _logger.info("%s not in cassette, sending to real server", vcr_request)
        return vcr_request, None

    client = args[0]
    replayed = _play_responses(cassette, real_request, vcr_request, client, kwargs)
    return vcr_request, replayed


def _record_responses(cassette, vcr_request, real_response):
    """
    Append a real response, and every response in its history, to the
    cassette.

    Each response in ``real_response.history`` is stored under a request
    built from its own ``request``. If there is any history, the final
    response is stored under a request built from ``real_response.request``
    instead of the ``vcr_request`` passed in.

    Returns ``real_response`` unchanged.
    """
    redirects = real_response.history
    for hop in redirects:
        hop_request = _make_vcr_request(hop.request)
        cassette.append(hop_request, _to_serialized_response(hop))

    # After a redirect, record the final response against the request that
    # holds the final redirect target.
    final_request = _make_vcr_request(real_response.request) if real_response.history else vcr_request

    cassette.append(final_request, _to_serialized_response(real_response))
    return real_response


def _play_responses(cassette, request, vcr_request, client, kwargs):
    """
    Replay the recorded response for ``vcr_request`` and, when redirects are
    enabled, the recorded responses of the redirect chain that follows.

    The redirect switch is read from ``kwargs`` under the name held by
    ``HTTPX_REDIRECT_PARAM`` and falls back to that parameter's default.
    While the current response has a 3xx status and a "location" header,
    the response is added to the history, its cookies are extracted into
    ``client.cookies`` and the next recorded response is played.

    Returns the last response that was played.
    """
    history = []

    allow_redirects = kwargs.get(
        HTTPX_REDIRECT_PARAM.name,
        HTTPX_REDIRECT_PARAM.default,
    )
    vcr_response = cassette.play_response(vcr_request)
    response = _from_serialized_response(request, vcr_response)

    while allow_redirects and 300 <= response.status_code <= 399:
        next_url = response.headers.get("location")
        if not next_url:
            # 3xx without a Location header: nothing to follow.
            break

        # Probe request for the redirect target: always a body-less GET.
        # NOTE(review): it carries the headers of the *response* being
        # followed, not request headers - confirm this is intended.
        vcr_request = VcrRequest("GET", next_url, None, dict(response.headers))
        # Replace the probe with the first request of the first entry the
        # cassette returns - presumably its closest recorded match.
        vcr_request = cassette.find_requests_with_most_matches(vcr_request)[0][0]

        history.append(response)
        # add cookies from response to session cookie store
        client.cookies.extract_cookies(response)

        vcr_response = cassette.play_response(vcr_request)
        # NOTE(review): the request handed over here is the VcrRequest, not
        # an httpx request as in the first call above - confirm callers of
        # response.request cope with that.
        response = _from_serialized_response(vcr_request, vcr_response, history)

    return response


async def _async_vcr_send(cassette, real_send, *args, **kwargs):
    """
    Async send: replay from the cassette when possible, otherwise await the
    real send and record its response.

    On replay the response's cookies are extracted into the cookie store of
    the client (``args[0]``) before the response is returned.
    """
    vcr_request, replayed = _shared_vcr_send(cassette, real_send, *args, **kwargs)
    if not replayed:
        live_response = await real_send(*args, **kwargs)
        return _record_responses(cassette, vcr_request, live_response)

    # add cookies from response to session cookie store
    client = args[0]
    client.cookies.extract_cookies(replayed)
    return replayed


def async_vcr_send(cassette, real_send):
    """
    Return a replacement for ``real_send`` that routes every call through
    the cassette-aware async send.

    The replacement is a plain function carrying the metadata of
    ``real_send``; calling it returns the coroutine to await.
    """

    def _inner_send(*args, **kwargs):
        return _async_vcr_send(cassette, real_send, *args, **kwargs)

    return functools.update_wrapper(_inner_send, real_send)


def _sync_vcr_send(cassette, real_send, *args, **kwargs):
    """
    Sync send: replay from the cassette when possible, otherwise call the
    real send and record its response.

    On replay the response's cookies are extracted into the cookie store of
    the client (``args[0]``) before the response is returned.
    """
    vcr_request, replayed = _shared_vcr_send(cassette, real_send, *args, **kwargs)
    if not replayed:
        live_response = real_send(*args, **kwargs)
        return _record_responses(cassette, vcr_request, live_response)

    # add cookies from response to session cookie store
    client = args[0]
    client.cookies.extract_cookies(replayed)
    return replayed


def sync_vcr_send(cassette, real_send):
    """
    Return a replacement for ``real_send`` that routes every call through
    the cassette-aware sync send.

    The replacement is a plain function carrying the metadata of
    ``real_send``.
    """

    def _inner_send(*args, **kwargs):
        return _sync_vcr_send(cassette, real_send, *args, **kwargs)

    return functools.update_wrapper(_inner_send, real_send)