File: __init__.py

package info (click to toggle)
python-azure 20230112%2Bgit-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 749,544 kB
  • sloc: python: 6,815,827; javascript: 287; makefile: 195; xml: 109; sh: 105
file content (218 lines) | stat: -rw-r--r-- 7,478 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
"""Stubs for aiohttp HTTP clients"""
from __future__ import absolute_import

import asyncio
import functools
import logging
import json

from aiohttp import ClientConnectionError, ClientResponse, RequestInfo, streams
from multidict import CIMultiDict, CIMultiDictProxy
from yarl import URL

from vcr.request import Request

log = logging.getLogger(__name__)


class MockStream(asyncio.StreamReader, streams.AsyncStreamReaderMixin):
    """In-memory stream reader used to expose already-known body bytes.

    Combines asyncio's ``StreamReader`` with aiohttp's
    ``AsyncStreamReaderMixin`` and adds no behavior of its own.
    """


class MockClientResponse(ClientResponse):
    """``ClientResponse`` whose payload accessors are served from ``_body``.

    The accessors below never touch a connection: they decode or return the
    bytes stored in ``self._body``, which callers assign after construction.
    """

    def __init__(self, method, url, request_info=None):
        """Initialise the parent with no writer, timer, traces or session."""
        event_loop = asyncio.get_event_loop()
        super().__init__(
            method=method,
            url=url,
            request_info=request_info,
            loop=event_loop,
            # Nothing below exists for a response that is not backed by a
            # real connection.
            writer=None,
            continue100=None,
            timer=None,
            traces=None,
            session=None,
        )
        # Lazily created by the ``content`` property.
        self._content = None

    async def json(self, *, encoding="utf-8", loads=json.loads, **kwargs):  # NOQA: E999
        """Parse the stored body as JSON; a blank body yields ``None``."""
        payload = self._body.strip()
        return loads(payload.decode(encoding)) if payload else None

    async def text(self, encoding="utf-8", errors="strict"):
        """Return the stored body decoded with *encoding*."""
        raw = self._body
        return raw.decode(encoding, errors=errors)

    async def read(self):
        """Return the stored body bytes unchanged."""
        return self._body

    def release(self):
        """Do nothing; overrides the parent's ``release``."""
        return None

    @property
    def content(self):
        """Stream over the stored body, built on first access and reused."""
        if self._content:
            return self._content

        self._content = stream = MockStream()
        stream.feed_data(self._body)
        stream.feed_eof()
        return stream


def build_response(vcr_request, vcr_response, history):
    """Rebuild a client response object from a recorded cassette entry.

    Args:
        vcr_request: Recorded request; its ``url``, ``method`` and ``headers``
            populate the response's ``request_info``.
        vcr_response: Recorded response dict with ``status`` (``code`` and
            ``message``), ``headers``, ``body`` and, optionally, ``url``.
        history: Iterable of earlier responses in the chain; a tuple snapshot
            of it is stored on the result.

    Returns:
        A ``MockClientResponse`` populated from the recorded data.
    """
    request_url = URL(vcr_request.url)
    request_info = RequestInfo(
        url=request_url,
        method=vcr_request.method,
        headers=CIMultiDictProxy(CIMultiDict(vcr_request.headers)),
        real_url=request_url,
    )

    # An entry without a "url" key would otherwise reach ``URL(None)``, which
    # raises TypeError. Fall back to the URL of the request being answered.
    recorded_url = vcr_response.get("url")
    response_url = request_url if recorded_url is None else URL(recorded_url)

    response = MockClientResponse(vcr_request.method, response_url, request_info=request_info)
    response.status = vcr_response["status"]["code"]
    # A recorded entry may have no body at all; treat that as empty bytes.
    response._body = vcr_response["body"].get("string", b"")
    response.reason = vcr_response["status"]["message"]
    response._headers = CIMultiDictProxy(CIMultiDict(vcr_response["headers"]))
    response._history = tuple(history)

    return response


def _serialize_headers(headers):
    """Serialize CIMultiDictProxy to a pickle-able dict because proxy
        objects forbid pickling:

        https://github.com/aio-libs/multidict/issues/340
    """
    # Mark strings as keys so 'istr' types don't show up in
    # the cassettes as comments.
    return {str(k): v for k, v in headers.items()}


def play_responses(cassette, vcr_request):
    """Replay the recorded response for *vcr_request*, following redirects.

    Args:
        cassette: Cassette providing ``play_response`` and
            ``find_requests_with_most_matches``.
        vcr_request: The request whose recorded response should be replayed.

    Returns:
        The response built for the last hop of the chain. Every intermediate
        3xx response that was followed is collected in the history passed to
        ``build_response`` for the following hops.
    """
    history = []
    vcr_response = cassette.play_response(vcr_request)
    response = build_response(vcr_request, vcr_response, history)

    # If we're following redirects, continue playing until we reach
    # our final destination.
    while 300 <= response.status <= 399:
        # Not every 3xx response redirects (304 Not Modified, for instance,
        # carries no Location header); such a response ends the chain
        # instead of raising KeyError.
        if "location" not in response.headers:
            break

        next_url = URL(response.headers["location"])
        if not next_url.is_absolute():
            # Resolve the relative reference against the current URL. Unlike
            # with_path(), join() keeps a query string in the Location value
            # and handles path-relative references.
            next_url = URL(response.url).join(next_url)

        # Make a stub VCR request that we can then use to look up the recorded
        # VCR request saved to the cassette. This feels a little hacky and
        # may have edge cases based on the headers we're providing (e.g. if
        # there's a matcher that is used to filter by headers).
        vcr_request = Request("GET", str(next_url), None, _serialize_headers(response.request_info.headers))
        # NOTE(review): presumably the first entry is the closest recorded
        # request; an empty result would raise IndexError here.
        vcr_request = cassette.find_requests_with_most_matches(vcr_request)[0][0]

        # Tack on the response we saw from the redirect into the history
        # list that is added on to the final response.
        history.append(response)
        vcr_response = cassette.play_response(vcr_request)
        response = build_response(vcr_request, vcr_response, history)

    return response


async def record_response(cassette, vcr_request, response):
    """Append a single request/response pair to *cassette*.

    The response body is read in full. When the read succeeds, the bytes are
    stored under ``"string"`` and ``response.content`` is replaced with a
    ``MockStream`` holding the same bytes.
    """
    try:
        payload = await response.read()
        # Presumably installed so callers can still stream the body after it
        # was read here.
        replay = response.content = MockStream()
        replay.feed_data(payload)
        replay.feed_eof()
        body = {"string": payload}
    except ClientConnectionError:
        # aiohttp raises a ClientConnectionError on reads when there is no
        # body, which tells us not to write one.
        body = {}

    status = {"code": response.status, "message": response.reason}
    cassette.append(
        vcr_request,
        {
            "status": status,
            "headers": _serialize_headers(response.headers),
            "body": body,
            "url": str(response.url),
        },
    )


async def record_responses(cassette, vcr_request, response):
    """Record *response* together with every response in its history.

    aiohttp follows redirects by default, so each request/response pair that
    was implicitly followed on the way to the final destination is written
    to the cassette before the final pair.
    """

    def bodyless_request(info):
        # Requests rebuilt from request_info carry no data because they
        # belong to a redirect chain.
        return Request(info.method, str(info.url), None, _serialize_headers(info.headers))

    for hop in response.history:
        await record_response(cassette, bodyless_request(hop.request_info), hop)

    # With a redirect history, the request paired with the final response is
    # the one attached to `response`, not the one the caller started with.
    final_request = bodyless_request(response.request_info) if response.history else vcr_request
    await record_response(cassette, final_request, response)


def vcr_request(cassette, real_request):
    """Wrap *real_request* so that requests are played from or recorded to *cassette*.

    Args:
        cassette: Cassette consulted for playback and appended to when
            recording.
        real_request: The original coroutine method being replaced; it is
            awaited as ``real_request(self, method, url, **kwargs)`` when the
            cassette cannot answer.

    Returns:
        A coroutine function with the signature
        ``new_request(self, method, url, **kwargs)``.
    """

    @functools.wraps(real_request)
    async def new_request(self, method, url, **kwargs):
        headers = kwargs.get("headers")
        auth = kwargs.get("auth")
        headers = self._prepare_headers(headers)
        data = kwargs.get("data", kwargs.get("json"))
        params = kwargs.get("params")

        if auth is not None:
            headers["AUTHORIZATION"] = auth.encode()

        request_url = URL(url)
        if params:
            # Stringify into a fresh dict instead of writing the converted
            # values back into the caller's mapping. The copy replaces
            # "params" in our local kwargs so the real request still sends
            # exactly the values used to build the recorded URL.
            params = {k: str(v) for k, v in params.items()}
            kwargs["params"] = params
            request_url = request_url.with_query(params)

        vcr_request = Request(method, str(request_url), data, headers)

        if cassette.can_play_response_for(vcr_request):
            return play_responses(cassette, vcr_request)

        if cassette.write_protected and cassette.filter_request(vcr_request):
            # The cassette may not be extended: hand back a synthetic 599
            # response whose body explains why nothing matched.
            response = MockClientResponse(method, URL(url))
            response.status = 599
            msg = (
                "No match for the request {!r} was found. Can't overwrite "
                "existing cassette {!r} in your current record mode {!r}."
            )
            msg = msg.format(vcr_request, cassette._path, cassette.record_mode)
            response._body = msg.encode()
            response.close()
            return response

        log.info("%s not in cassette, sending to real server", vcr_request)

        response = await real_request(self, method, url, **kwargs)  # NOQA: E999
        await record_responses(cassette, vcr_request, response)
        return response

    return new_request