File: mock_requests_backend.py

package info (click to toggle)
django-anymail 13.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 2,480 kB
  • sloc: python: 27,832; makefile: 132; javascript: 33; sh: 9
file content (295 lines) | stat: -rw-r--r-- 10,440 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import json
from io import BytesIO
from unittest.mock import patch

import requests
from django.core import mail
from django.test import SimpleTestCase

from anymail.exceptions import AnymailAPIError

from .utils import AnymailTestMixin

# Sentinel meaning "argument was not supplied". Used as a default in
# set_mock_response so that None stays available as an explicit value
# (e.g., content_type=None produces a response without a Content-Type header).
UNSET = object()


class RequestsBackendMockAPITestCase(AnymailTestMixin, SimpleTestCase):
    """TestCase that mocks API calls through requests

    setUp patches ``requests.Session.request`` with ``autospec=True`` and
    installs a default mock response. Because of autospec, the mock records
    the Session instance (``self``) as the first positional argument of
    every call; the helpers below account for that when inspecting calls.
    """

    DEFAULT_RAW_RESPONSE = b"""{"subclass": "should override"}"""
    DEFAULT_CONTENT_TYPE = None  # e.g., "application/json"
    DEFAULT_STATUS_CODE = 200  # most APIs use '200 OK' for success

    # Parameters of requests.Session.request, in positional order
    # (not counting the leading ``self``).
    _SESSION_REQUEST_PARAMS = (
        "method",
        "url",
        "params",
        "data",
        "headers",
        "cookies",
        "files",
        "auth",
        "timeout",
        "allow_redirects",
        "proxies",
        "hooks",
        "stream",
        "verify",
        "cert",
        "json",
    )

    # Session.request parameters that requests.Request() does not accept.
    _SESSION_ONLY_PARAMS = (
        "timeout",
        "allow_redirects",
        "proxies",
        "stream",
        "verify",
        "cert",
    )

    class MockResponse(requests.Response):
        """requests.request return value mock sufficient for testing"""

        def __init__(
            self,
            status_code=200,
            raw=b"RESPONSE",
            content_type=None,
            encoding="utf-8",
            reason=None,
            test_case=None,
        ):
            """Build a canned response.

            :param status_code: HTTP status code to report
            :param raw: response body, as bytes
            :param content_type: Content-Type header value; None omits the header
            :param encoding: encoding reported for the body
            :param reason: HTTP reason phrase; defaults to "OK" for 2xx
                status codes and "ERROR" otherwise
            :param test_case: the owning test case, consulted by ``url``
            """
            super().__init__()
            self.status_code = status_code
            self.encoding = encoding
            self.reason = reason or ("OK" if 200 <= status_code < 300 else "ERROR")
            self.raw = BytesIO(raw)
            if content_type is not None:
                self.headers["Content-Type"] = content_type
            self.test_case = test_case

        @property
        def url(self):
            """The url of the most recent mocked API call (None if not given)."""
            return self.test_case.get_api_call_arg("url", required=False)

        @url.setter
        def url(self, url):
            # Assigning None is tolerated (presumably so the base class
            # initializer can reset the attribute); anything else is an error,
            # because the url is always derived from the mocked call.
            if url is not None:
                raise ValueError("MockResponse can't handle url assignment")

    def setUp(self):
        """Patch requests.Session.request and install the default response."""
        super().setUp()
        self.patch_request = patch("requests.Session.request", autospec=True)
        self.mock_request = self.patch_request.start()
        self.addCleanup(self.patch_request.stop)
        self.set_mock_response()

    def set_mock_response(
        self,
        status_code=UNSET,
        raw=UNSET,
        json_data=UNSET,
        encoding="utf-8",
        content_type=UNSET,
        reason=None,
    ):
        """Set (and return) the response the mocked request will return.

        Arguments left UNSET fall back to the DEFAULT_* class attributes.
        Provide either ``raw`` (bytes) or ``json_data`` (serialized with
        json.dumps and encoded using ``encoding``), not both. With
        ``json_data``, the content type defaults to "application/json".
        """
        if status_code is UNSET:
            status_code = self.DEFAULT_STATUS_CODE
        if json_data is not UNSET:
            assert raw is UNSET, "provide json_data or raw, not both"
            raw = json.dumps(json_data).encode(encoding)
            if content_type is UNSET:
                content_type = "application/json"
        if raw is UNSET:
            raw = self.DEFAULT_RAW_RESPONSE
        if content_type is UNSET:
            content_type = self.DEFAULT_CONTENT_TYPE
        mock_response = self.MockResponse(
            status_code,
            raw=raw,
            content_type=content_type,
            encoding=encoding,
            reason=reason,
            test_case=self,
        )
        self.mock_request.return_value = mock_response
        return mock_response

    def assert_esp_called(self, url, method="POST"):
        """Verifies the (mock) ESP API was called on endpoint.

        url can be partial, and is just checked against the end of the url requested

        Pass None for url or method to skip that check.
        """
        # This assumes the last (or only) call to requests.Session.request
        # is the API call of interest.
        if self.mock_request.call_args is None:
            raise AssertionError("No ESP API was called")
        if method is not None:
            actual_method = self.get_api_call_arg("method")
            if actual_method != method:
                self.fail(
                    "API was not called using %s. (%s was used instead.)"
                    % (method, actual_method)
                )
        if url is not None:
            actual_url = self.get_api_call_arg("url")
            if not actual_url.endswith(url):
                self.fail(
                    "API was not called at %s\n(It was called at %s)"
                    % (url, actual_url)
                )

    def get_api_call_arg(self, kwarg, required=True):
        """Returns an argument passed to the mock ESP API.

        Looks at the most recent call: first among its keyword arguments,
        then among its positional arguments.

        Fails test if API wasn't called, or if ``required`` is true and the
        argument was not passed. Returns None for a missing optional argument.
        """
        if self.mock_request.call_args is None:
            raise AssertionError("API was not called")
        (args, kwargs) = self.mock_request.call_args
        try:
            return kwargs[kwarg]
        except KeyError:
            pass

        try:
            # positional arg?
            pos = self._SESSION_REQUEST_PARAMS.index(kwarg)
            # The mock is autospecced on the unbound method, so args[0] is the
            # Session instance (self); the request params start at args[1].
            return args[pos + 1]
        except (ValueError, IndexError):
            pass

        if required:
            self.fail("API was called without required arg '%s'" % kwarg)
        return None

    def get_api_call_params(self, required=True):
        """Returns the query params sent to the mock ESP API."""
        return self.get_api_call_arg("params", required)

    def get_api_call_data(self, required=True):
        """Returns the raw data sent to the mock ESP API."""
        return self.get_api_call_arg("data", required)

    def get_api_call_json(self, required=True):
        """Returns the data sent to the mock ESP API, json-parsed"""
        # could be either the data param (as json str)
        # or the json param (needing formatting)
        value = self.get_api_call_arg("data", required=False)
        if value is not None:
            return json.loads(value)
        else:
            return self.get_api_call_arg("json", required)

    def get_api_call_headers(self, required=True):
        """Returns the headers sent to the mock ESP API"""
        return self.get_api_call_arg("headers", required)

    def get_api_call_files(self, required=True):
        """Returns the files sent to the mock ESP API"""
        return self.get_api_call_arg("files", required)

    def get_api_call_auth(self, required=True):
        """Returns the auth sent to the mock ESP API"""
        return self.get_api_call_arg("auth", required)

    def get_api_prepared_request(self):
        """Returns the PreparedRequest that would have been sent

        Built from the arguments of the most recent mocked call, without
        modifying the mock's recorded call. Fails test if API wasn't called.
        """
        if self.mock_request.call_args is None:
            raise AssertionError("API was not called")
        (args, kwargs) = self.mock_request.call_args
        # Work on a fresh dict: popping from the recorded kwargs would change
        # what later get_api_call_arg() lookups and call assertions see.
        # args[0] is the Session instance (autospec), so skip it.
        request_kwargs = dict(zip(self._SESSION_REQUEST_PARAMS, args[1:]))
        request_kwargs.update(kwargs)
        for session_only in self._SESSION_ONLY_PARAMS:
            request_kwargs.pop(session_only, None)
        request = requests.Request(**request_kwargs)
        return request.prepare()

    def assert_esp_not_called(self, msg=None):
        """Fails test (with optional msg) if the mock ESP API was called."""
        if self.mock_request.called:
            raise AssertionError(msg or "ESP API was called and shouldn't have been")


class SessionSharingTestCases(RequestsBackendMockAPITestCase):
    """Common test cases for requests backend connection sharing.

    Instantiate for each ESP by:
    - subclassing
    - adding or overriding any tests as appropriate
    """

    def __init__(self, methodName="runTest"):
        """Swap in a no-op test method when instantiated on this base class."""
        if self.__class__ is SessionSharingTestCases:
            # don't run these tests on the abstract base implementation
            methodName = "runNoTestsInBaseClass"
        super().__init__(methodName)

    def runNoTestsInBaseClass(self):
        """Placeholder "test" used when the base class itself is collected."""
        pass

    def setUp(self):
        """Additionally patch requests.Session.close, to count session closes."""
        super().setUp()
        self.patch_close = patch("requests.Session.close", autospec=True)
        self.mock_close = self.patch_close.start()
        self.addCleanup(self.patch_close.stop)

    def test_connection_sharing(self):
        """RequestsBackend reuses one requests session when sending multiple messages"""
        datatuple = (
            ("Subject 1", "Body 1", "from@example.com", ["to@example.com"]),
            ("Subject 2", "Body 2", "from@example.com", ["to@example.com"]),
        )
        mail.send_mass_mail(datatuple)
        self.assertEqual(self.mock_request.call_count, 2)
        first_call, second_call = self.mock_request.call_args_list
        # call[0] is the positional-args tuple; because the mock is
        # autospecced, its first item (self) is the Session used for the call
        session1 = first_call[0][0]
        session2 = second_call[0][0]
        self.assertIs(session1, session2)
        self.assertEqual(self.mock_close.call_count, 1)

    def test_caller_managed_connections(self):
        """Calling code can create a long-lived connection that it opens and closes"""
        connection = mail.get_connection()
        connection.open()
        mail.send_mail(
            "Subject 1",
            "body",
            "from@example.com",
            ["to@example.com"],
            connection=connection,
        )
        # call_args[0] is the positional-args tuple; its first item is the Session
        session1 = self.mock_request.call_args[0][0]
        self.assertEqual(self.mock_close.call_count, 0)  # shouldn't be closed yet

        mail.send_mail(
            "Subject 2",
            "body",
            "from@example.com",
            ["to@example.com"],
            connection=connection,
        )
        self.assertEqual(self.mock_close.call_count, 0)  # still shouldn't be closed
        session2 = self.mock_request.call_args[0][0]
        self.assertIs(session1, session2)  # should have reused same session

        connection.close()
        self.assertEqual(self.mock_close.call_count, 1)

    def test_session_closed_after_exception(self):
        """The session is closed even when sending raises an API error"""
        self.set_mock_response(status_code=500)
        with self.assertRaises(AnymailAPIError):
            mail.send_mail("Subject", "Message", "from@example.com", ["to@example.com"])
        self.assertEqual(self.mock_close.call_count, 1)

    def test_session_closed_after_fail_silently_exception(self):
        """The session is closed when an API error is suppressed by fail_silently"""
        self.set_mock_response(status_code=500)
        sent = mail.send_mail(
            "Subject",
            "Message",
            "from@example.com",
            ["to@example.com"],
            fail_silently=True,
        )
        self.assertEqual(sent, 0)
        self.assertEqual(self.mock_close.call_count, 1)

    def test_caller_managed_session_closed_after_exception(self):
        """A caller-opened connection stays open after an API error until closed"""
        connection = mail.get_connection()
        connection.open()
        self.set_mock_response(status_code=500)
        with self.assertRaises(AnymailAPIError):
            mail.send_mail(
                "Subject",
                "Message",
                "from@example.com",
                ["to@example.com"],
                connection=connection,
            )
        self.assertEqual(self.mock_close.call_count, 0)  # wait for us to close it

        connection.close()
        self.assertEqual(self.mock_close.call_count, 1)