File: request_mock.py

package info (click to toggle)
python-stripe 12.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 12,864 kB
  • sloc: python: 157,573; makefile: 13; sh: 9
file content (256 lines) | stat: -rw-r--r-- 8,501 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
import json

import stripe
from stripe import util
from stripe._stripe_response import StripeResponse, StripeStreamResponse


class RequestMock(object):
    """Patch ``APIRequestor`` so tests can stub responses and assert on calls.

    Construction installs patches (through the supplied ``mocker``) over
    ``APIRequestor.__init__``, ``request``, ``request_async`` and
    ``request_stream``. A request whose ``(method, url)`` has a stub
    registered via :meth:`stub_request` / :meth:`stub_request_stream` is
    answered from the stub; every other request falls through to the real,
    unpatched implementation.
    """

    def __init__(self, mocker):
        """Install the patches.

        :param mocker: presumably a pytest-mock fixture -- this class uses its
            ``patch`` method and its ``ANY`` attribute.
        """
        self._mocker = mocker

        # Keep the unpatched methods so non-stubbed requests can still reach
        # the real implementation from inside the patched side effects.
        self._real_request = stripe.api_requestor.APIRequestor.request
        self._real_request_async = (
            stripe.api_requestor.APIRequestor.request_async
        )
        self._real_request_stream = (
            stripe.api_requestor.APIRequestor.request_stream
        )
        self._stub_request_handler = StubRequestHandler()

        # The constructor is patched with itself as side effect: behavior is
        # unchanged, but the keyword arguments are recorded for
        # assert_api_base / assert_api_version.
        self.constructor_patcher = self._mocker.patch(
            "stripe.api_requestor.APIRequestor.__init__",
            side_effect=stripe.api_requestor.APIRequestor.__init__,
            autospec=True,
        )

        self.request_patcher = self._mocker.patch(
            "stripe.api_requestor.APIRequestor.request",
            side_effect=self._patched_request,
            autospec=True,
        )

        self.request_async_patcher = self._mocker.patch(
            "stripe.api_requestor.APIRequestor.request_async",
            side_effect=self._patched_request_async,
            autospec=True,
        )

        self.request_stream_patcher = self._mocker.patch(
            "stripe.api_requestor.APIRequestor.request_stream",
            side_effect=self._patched_request_stream,
            autospec=True,
        )

    def _patched_request(self, requestor, method, url, *args, **kwargs):
        """Side effect for ``APIRequestor.request``.

        Returns ``(stub_response, stripe.api_key)`` when a non-streaming stub
        is available for ``(method, url)``; otherwise delegates to the real
        ``request`` with the original arguments.
        """
        response = self._stub_request_handler.get_response(
            method, url, expect_stream=False
        )
        if response is not None:
            return response, stripe.api_key

        return self._real_request(requestor, method, url, *args, **kwargs)

    async def _patched_request_async(
        self, requestor, method, url, *args, **kwargs
    ):
        """Async side effect for ``APIRequestor.request_async``.

        Returns ``(stub_response, stripe.api_key)`` when a non-streaming stub
        is available for ``(method, url)``; otherwise awaits the real
        ``request_async`` and returns its result.
        """
        response = self._stub_request_handler.get_response(
            method, url, expect_stream=False
        )
        if response is not None:
            return response, stripe.api_key

        # The real method is a coroutine function: it must be awaited here,
        # otherwise the caller would receive an unawaited coroutine object
        # instead of the response.
        return await self._real_request_async(
            requestor, method, url, *args, **kwargs
        )

    def _patched_request_stream(self, requestor, method, url, *args, **kwargs):
        """Side effect for ``APIRequestor.request_stream``.

        Returns ``(stub_response, stripe.api_key)`` when a streaming stub is
        available for ``(method, url)``; otherwise delegates to the real
        ``request_stream``.
        """
        response = self._stub_request_handler.get_response(
            method, url, expect_stream=True
        )
        if response is not None:
            return response, stripe.api_key

        return self._real_request_stream(
            requestor, method, url, *args, **kwargs
        )

    def stub_request(self, method, url, rbody={}, rcode=200, rheaders={}):
        """Register a canned non-streaming response for ``(method, url)``."""
        self._stub_request_handler.register(
            method, url, rbody, rcode, rheaders, is_streaming=False
        )

    def stub_request_stream(
        self, method, url, rbody={}, rcode=200, rheaders={}
    ):
        """Register a canned streaming response for ``(method, url)``."""
        self._stub_request_handler.register(
            method, url, rbody, rcode, rheaders, is_streaming=True
        )

    def _assert_constructor_kwarg(self, kwarg, description, expected):
        """Check a keyword argument of the last ``APIRequestor`` construction.

        Only keyword arguments are inspected; a value passed positionally is
        reported as "not provided".

        :param kwarg: keyword name, e.g. ``"api_base"``.
        :param description: human-readable name used in the failure message.
        :param expected: value the keyword is expected to have.
        :raises AssertionError: if ``APIRequestor`` was never constructed, the
            keyword was not passed, or its value differs from ``expected``.
        """
        call_args = self.constructor_patcher.call_args
        if call_args is None:
            # Without this guard, indexing None would raise a TypeError
            # instead of a readable assertion failure.
            raise AssertionError(
                "Expected APIRequestor to have been constructed with "
                "%s='%s'. APIRequestor was never constructed."
                % (kwarg, expected)
            )

        call_kwargs = call_args[1]
        if kwarg not in call_kwargs:
            raise AssertionError(
                "Expected APIRequestor to have been constructed with "
                "%s='%s'. No %s was provided." % (kwarg, expected, description)
            )

        actual = call_kwargs[kwarg]
        if actual != expected:
            raise AssertionError(
                "Expected APIRequestor to have been constructed with "
                "%s='%s'. Constructed with %s='%s' "
                "instead." % (kwarg, expected, kwarg, actual)
            )

    def assert_api_base(self, expected_api_base):
        """Assert the last ``APIRequestor`` was built with this ``api_base``.

        Only checks ``api_base`` passed as a keyword argument, not
        positionally.
        """
        self._assert_constructor_kwarg(
            "api_base", "API base", expected_api_base
        )

    def assert_api_version(self, expected_api_version):
        """Assert the last ``APIRequestor`` was built with this ``api_version``.

        Only checks ``api_version`` passed as a keyword argument, not
        positionally.
        """
        self._assert_constructor_kwarg(
            "api_version", "API version", expected_api_version
        )

    def assert_requested(
        self,
        method,
        url,
        params=None,
        headers=None,
        api_mode=None,
        _usage=None,
    ):
        """Assert the most recent ``request`` call matched the given values.

        Omitted (or falsy) ``params``/``headers``/``api_mode``/``_usage`` match
        anything.
        """
        self.assert_requested_internal(
            self.request_patcher,
            method,
            url,
            params,
            headers,
            api_mode,
            _usage,
        )

    def assert_requested_stream(
        self,
        method,
        url,
        params=None,
        headers=None,
        api_mode=None,
        _usage=None,
    ):
        """Assert the most recent ``request_stream`` call matched the values.

        Omitted (or falsy) ``params``/``headers``/``api_mode``/``_usage`` match
        anything.
        """
        self.assert_requested_internal(
            self.request_stream_patcher,
            method,
            url,
            params,
            headers,
            api_mode,
            _usage,
        )

    def assert_requested_internal(
        self, patcher, method, url, params, headers, api_mode, usage
    ):
        """Assert ``patcher``'s last call matches one supported call shape.

        Falsy ``params``/``headers``/``api_mode``/``usage`` are replaced by
        ``mocker.ANY``.

        :raises AssertionError: the error from the last attempted signature
            when no shape matches.
        """
        any_value = self._mocker.ANY
        params = params or any_value
        headers = headers or any_value
        api_mode = api_mode or any_value
        usage = usage or any_value

        # Sadly, ANY does not match a missing optional argument, so we
        # check all the possible signatures of the request method.
        possible_called_args = [
            (any_value, method, url),
            (any_value, method, url, params),
            (any_value, method, url, params, headers),
            (any_value, method, url, params, headers, api_mode),
        ]
        possible_called_kwargs = [{}, {"_usage": usage}]

        last_error = None
        for args in possible_called_args:
            for kwargs in possible_called_kwargs:
                try:
                    patcher.assert_called_with(*args, **kwargs)
                except AssertionError as e:
                    last_error = e
                else:
                    # One matching shape is enough; stop searching.
                    return

        raise last_error

    def assert_no_request(self):
        """Assert ``APIRequestor.request`` was never called."""
        if self.request_patcher.call_count != 0:
            msg = (
                "Expected 'request' to not have been called. "
                "Called %s times." % (self.request_patcher.call_count)
            )
            raise AssertionError(msg)

    def assert_no_request_stream(self):
        """Assert ``APIRequestor.request_stream`` was never called."""
        if self.request_stream_patcher.call_count != 0:
            msg = (
                "Expected 'request_stream' to not have been called. "
                "Called %s times." % (self.request_stream_patcher.call_count)
            )
            raise AssertionError(msg)

    def reset_mock(self):
        """Clear the recorded calls on the sync, async and stream patchers.

        The constructor patcher's call history is not cleared.
        """
        self.request_patcher.reset_mock()
        self.request_async_patcher.reset_mock()
        self.request_stream_patcher.reset_mock()


class StubRequestHandler(object):
    """Registry of canned responses keyed by ``(method, url)``.

    A stub is removed when it is served, so each registration answers at
    most one request.
    """

    def __init__(self):
        # (method, url) -> (rbody, rcode, rheaders, is_streaming)
        self._entries = {}

    def register(
        self, method, url, rbody={}, rcode=200, rheaders={}, is_streaming=False
    ):
        """Register a stub for ``(method, url)``, replacing any existing one.

        :param rbody: a ``str`` used verbatim, or any other value which is
            serialized with ``json.dumps`` when the stub is served.
        :param rcode: status code passed to the response object.
        :param rheaders: headers passed to the response object.
        :param is_streaming: whether the stub answers streaming requests.
        """
        self._entries[(method, url)] = (rbody, rcode, rheaders, is_streaming)

    def get_response(self, method, url, expect_stream=False):
        """Serve and remove the stub for ``(method, url)`` if its mode matches.

        :returns: a ``StripeStreamResponse`` for streaming stubs, a
            ``StripeResponse`` otherwise, or ``None`` when no stub is
            registered or the stub's streaming mode differs from
            ``expect_stream``. A mismatched stub is left in place so the
            request of the right kind can still use it.
        """
        key = (method, url)
        if key not in self._entries:
            return None

        rbody, rcode, rheaders, is_streaming = self._entries[key]
        if expect_stream != is_streaming:
            # Do not consume a stub that was registered for the other kind
            # of request (streaming vs. non-streaming).
            return None
        del self._entries[key]

        if not isinstance(rbody, str):
            rbody = json.dumps(rbody)
        if is_streaming:
            return StripeStreamResponse(
                util.io.BytesIO(rbody.encode("utf-8")), rcode, rheaders
            )
        return StripeResponse(rbody, rcode, rheaders)