1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295
|
import json
from io import BytesIO
from unittest.mock import patch
import requests
from django.core import mail
from django.test import SimpleTestCase
from anymail.exceptions import AnymailAPIError
from .utils import AnymailTestMixin
# Sentinel meaning "argument not supplied". A dedicated object is used instead
# of None so that helpers comparing with `is UNSET` can still accept an
# explicit None (e.g., a content type of None).
UNSET = object()
class RequestsBackendMockAPITestCase(AnymailTestMixin, SimpleTestCase):
    """TestCase that mocks API calls through requests.

    ``setUp`` patches ``requests.Session.request`` (with ``autospec=True``)
    and installs a default mock response. Tests can then inspect the most
    recent recorded call through the ``get_api_call_*`` helpers and the
    ``assert_esp_*`` assertions.
    """

    DEFAULT_RAW_RESPONSE = b"""{"subclass": "should override"}"""
    DEFAULT_CONTENT_TYPE = None  # e.g., "application/json"
    DEFAULT_STATUS_CODE = 200  # most APIs use '200 OK' for success

    # Parameter order of requests.Session.request (after `self`);
    # used to resolve arguments that were passed positionally.
    _SESSION_REQUEST_PARAMS = (
        "method",
        "url",
        "params",
        "data",
        "headers",
        "cookies",
        "files",
        "auth",
        "timeout",
        "allow_redirects",
        "proxies",
        "hooks",
        "stream",
        "verify",
        "cert",
        "json",
    )

    # Parameters accepted by requests.Session.request
    # but not by the requests.Request constructor.
    _SESSION_ONLY_PARAMS = frozenset(
        ("timeout", "allow_redirects", "proxies", "stream", "verify", "cert")
    )

    class MockResponse(requests.Response):
        """requests.request return value mock sufficient for testing"""

        def __init__(
            self,
            status_code=200,
            raw=b"RESPONSE",
            content_type=None,
            encoding="utf-8",
            reason=None,
            test_case=None,
        ):
            """Build a canned response.

            :param status_code: HTTP status code to report
            :param raw: response body, as bytes
            :param content_type: Content-Type header value
                (the header is omitted when None)
            :param encoding: encoding reported for the body
            :param reason: HTTP reason phrase; when falsy, defaults to "OK"
                for 2xx status codes and "ERROR" otherwise
            :param test_case: owning test case, consulted by the ``url``
                property to report the url of the mocked API call
            """
            super().__init__()
            self.status_code = status_code
            self.encoding = encoding
            self.reason = reason or ("OK" if 200 <= status_code < 300 else "ERROR")
            self.raw = BytesIO(raw)
            if content_type is not None:
                self.headers["Content-Type"] = content_type
            self.test_case = test_case

        @property
        def url(self):
            """The url passed to the mocked API call (None if not supplied)."""
            return self.test_case.get_api_call_arg("url", required=False)

        @url.setter
        def url(self, url):
            # Assigning None is tolerated as a no-op;
            # any real url assignment is rejected.
            if url is not None:
                raise ValueError("MockResponse can't handle url assignment")

    def setUp(self):
        """Patch requests.Session.request and install the default response."""
        super().setUp()
        self.patch_request = patch("requests.Session.request", autospec=True)
        self.mock_request = self.patch_request.start()
        self.addCleanup(self.patch_request.stop)
        self.set_mock_response()

    def set_mock_response(
        self,
        status_code=UNSET,
        raw=UNSET,
        json_data=UNSET,
        encoding="utf-8",
        content_type=UNSET,
        reason=None,
    ):
        """Set (and return) the response the mocked API call will return.

        :param status_code: HTTP status (default: ``DEFAULT_STATUS_CODE``)
        :param raw: body bytes (default: ``DEFAULT_RAW_RESPONSE``)
        :param json_data: object to serialize as the JSON body; mutually
            exclusive with ``raw``, and makes the content type default
            to "application/json"
        :param encoding: encoding for the body
        :param content_type: Content-Type header value
            (default: ``DEFAULT_CONTENT_TYPE``)
        :param reason: HTTP reason phrase (see ``MockResponse``)
        :raises AssertionError: if both ``raw`` and ``json_data`` are given
        """
        if status_code is UNSET:
            status_code = self.DEFAULT_STATUS_CODE
        if json_data is not UNSET:
            # Explicit raise (rather than an assert statement),
            # so the check still runs under `python -O`.
            if raw is not UNSET:
                raise AssertionError("provide json_data or raw, not both")
            raw = json.dumps(json_data).encode(encoding)
            if content_type is UNSET:
                content_type = "application/json"
        if raw is UNSET:
            raw = self.DEFAULT_RAW_RESPONSE
        if content_type is UNSET:
            content_type = self.DEFAULT_CONTENT_TYPE
        mock_response = self.MockResponse(
            status_code,
            raw=raw,
            content_type=content_type,
            encoding=encoding,
            reason=reason,
            test_case=self,
        )
        self.mock_request.return_value = mock_response
        return mock_response

    def assert_esp_called(self, url, method="POST"):
        """Verifies the (mock) ESP API was called on endpoint.

        url can be partial, and is just checked against the end of the url
        requested. Pass None for url or method to skip that check.
        """
        # This assumes the last (or only) call to requests.Session.request
        # is the API call of interest.
        if self.mock_request.call_args is None:
            raise AssertionError("No ESP API was called")
        if method is not None:
            actual_method = self.get_api_call_arg("method")
            if actual_method != method:
                self.fail(
                    "API was not called using %s. (%s was used instead.)"
                    % (method, actual_method)
                )
        if url is not None:
            actual_url = self.get_api_call_arg("url")
            if not actual_url.endswith(url):
                self.fail(
                    "API was not called at %s\n(It was called at %s)"
                    % (url, actual_url)
                )

    def get_api_call_arg(self, kwarg, required=True):
        """Returns an argument passed to the mock ESP API.

        Looks at the most recent call: first among its keyword arguments,
        then among its positional arguments.

        Fails test if API wasn't called, or if ``required`` is true and the
        argument wasn't supplied. Returns None for a missing argument when
        ``required`` is false.
        """
        if self.mock_request.call_args is None:
            raise AssertionError("API was not called")
        (args, kwargs) = self.mock_request.call_args
        if kwarg in kwargs:
            return kwargs[kwarg]

        # The patch uses autospec=True, so a call made through a Session
        # instance records that instance as the first positional argument.
        # Drop it so positions line up with the parameter names.
        if args and isinstance(args[0], requests.Session):
            args = args[1:]
        try:
            return args[self._SESSION_REQUEST_PARAMS.index(kwarg)]
        except (ValueError, IndexError):
            # Unknown parameter name, or not enough positional args
            pass

        if required:
            self.fail("API was called without required arg '%s'" % kwarg)
        return None

    def get_api_call_params(self, required=True):
        """Returns the query params sent to the mock ESP API."""
        return self.get_api_call_arg("params", required)

    def get_api_call_data(self, required=True):
        """Returns the raw data sent to the mock ESP API."""
        return self.get_api_call_arg("data", required)

    def get_api_call_json(self, required=True):
        """Returns the data sent to the mock ESP API, json-parsed"""
        # could be either the data param (as json str)
        # or the json param (needing formatting)
        value = self.get_api_call_arg("data", required=False)
        if value is not None:
            return json.loads(value)
        else:
            return self.get_api_call_arg("json", required)

    def get_api_call_headers(self, required=True):
        """Returns the headers sent to the mock ESP API"""
        return self.get_api_call_arg("headers", required)

    def get_api_call_files(self, required=True):
        """Returns the files sent to the mock ESP API"""
        return self.get_api_call_arg("files", required)

    def get_api_call_auth(self, required=True):
        """Returns the auth sent to the mock ESP API"""
        return self.get_api_call_arg("auth", required)

    def get_api_prepared_request(self):
        """Returns the PreparedRequest that would have been sent.

        Built from the keyword arguments of the most recent call.
        The recorded call itself is left unmodified.

        :raises AssertionError: if the API wasn't called
        """
        if self.mock_request.call_args is None:
            raise AssertionError("API was not called")
        (args, kwargs) = self.mock_request.call_args
        # Filter into a new dict, rather than popping from the recorded
        # kwargs, so later inspection of the call still sees every argument.
        request_kwargs = {
            name: value
            for name, value in kwargs.items()
            if name not in self._SESSION_ONLY_PARAMS
        }
        request = requests.Request(**request_kwargs)
        return request.prepare()

    def assert_esp_not_called(self, msg=None):
        """Fails the test if any call was made to the mock ESP API."""
        if self.mock_request.called:
            raise AssertionError(msg or "ESP API was called and shouldn't have been")
class SessionSharingTestCases(RequestsBackendMockAPITestCase):
    """Shared tests for requests-session reuse in a requests-based backend.

    To use for a particular ESP:
    - subclass this
    - add or override tests as appropriate
    """

    def __init__(self, methodName="runTest"):
        # When instantiated directly (rather than through a subclass),
        # substitute a no-op so the abstract base runs no real tests.
        is_abstract_base = self.__class__ is SessionSharingTestCases
        super().__init__("runNoTestsInBaseClass" if is_abstract_base else methodName)

    def runNoTestsInBaseClass(self):
        """No-op placeholder run in place of tests on the base class."""

    def setUp(self):
        """Also patch requests.Session.close, to count session closes."""
        super().setUp()
        close_patcher = patch("requests.Session.close", autospec=True)
        self.patch_close = close_patcher
        self.mock_close = close_patcher.start()
        self.addCleanup(close_patcher.stop)

    def test_connection_sharing(self):
        """Sending several messages at once goes through a single session"""
        messages = (
            ("Subject 1", "Body 1", "from@example.com", ["to@example.com"]),
            ("Subject 2", "Body 2", "from@example.com", ["to@example.com"]),
        )
        mail.send_mass_mail(messages)

        self.assertEqual(self.mock_request.call_count, 2)
        # Positional args recorded for each call; arg[0] (self) is the session
        first_args, _ = self.mock_request.call_args_list[0]
        second_args, _ = self.mock_request.call_args_list[1]
        self.assertEqual(first_args, second_args)
        self.assertEqual(self.mock_close.call_count, 1)

    def test_caller_managed_connections(self):
        """Calling code can create a long-lived connection it opens and closes"""
        connection = mail.get_connection()
        connection.open()

        recorded_args = []
        for subject in ("Subject 1", "Subject 2"):
            mail.send_mail(
                subject,
                "body",
                "from@example.com",
                ["to@example.com"],
                connection=connection,
            )
            recorded_args.append(self.mock_request.call_args[0])
            # the caller owns the connection, so nothing is closed yet
            self.assertEqual(self.mock_close.call_count, 0)

        first_args, second_args = recorded_args
        self.assertEqual(first_args, second_args)  # same session both times

        connection.close()
        self.assertEqual(self.mock_close.call_count, 1)

    def test_session_closed_after_exception(self):
        """The session is closed once when a send raises an API error"""
        self.set_mock_response(status_code=500)
        message_args = ("Subject", "Message", "from@example.com", ["to@example.com"])
        with self.assertRaises(AnymailAPIError):
            mail.send_mail(*message_args)
        self.assertEqual(self.mock_close.call_count, 1)

    def test_session_closed_after_fail_silently_exception(self):
        """The session is closed once when a failed send is silenced"""
        self.set_mock_response(status_code=500)
        message_args = ("Subject", "Message", "from@example.com", ["to@example.com"])
        sent = mail.send_mail(*message_args, fail_silently=True)
        self.assertEqual(sent, 0)
        self.assertEqual(self.mock_close.call_count, 1)

    def test_caller_managed_session_closed_after_exception(self):
        """A caller-opened connection stays open after an error until closed"""
        connection = mail.get_connection()
        connection.open()
        self.set_mock_response(status_code=500)
        message_args = ("Subject", "Message", "from@example.com", ["to@example.com"])
        with self.assertRaises(AnymailAPIError):
            mail.send_mail(*message_args, connection=connection)
        # the error must not close it; that is left to the caller
        self.assertEqual(self.mock_close.call_count, 0)
        connection.close()
        self.assertEqual(self.mock_close.call_count, 1)
|