1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
|
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import json
from urllib import parse
try:
from unittest import mock
except ImportError: # python < 3.3
import mock # type: ignore
class Request:
    """An HTTP request a test expects the code under test to send.

    Every constraint is optional; ``assert_matches`` checks only the ones that were set.
    """

    def __init__(
        self,
        base_url=None,
        url=None,
        authority=None,
        url_substring=None,
        method=None,
        required_headers=None,
        required_data=None,
        required_params=None,
    ):
        """
        :param str base_url: expected URL with any query string (everything from "?") removed
        :param str url: expected full URL, compared exactly
        :param str authority: expected network location, i.e. ``urlparse(url).netloc``
        :param str url_substring: text the URL must contain
        :param str method: expected HTTP method
        :param dict required_headers: header name -> expected value; a "User-Agent" entry
            only needs to be contained in the actual header value
        :param dict required_data: body field name -> expected value
        :param dict required_params: query parameter name -> expected value
        """
        self.authority = authority
        self.base_url = base_url
        self.method = method
        self.url = url
        self.url_substring = url_substring
        # fresh dicts per instance so no mutable default is shared between instances
        self.required_headers = required_headers or {}
        self.required_data = required_data or {}
        self.required_params = required_params or {}

    def assert_matches(self, request):
        """Assert that ``request`` satisfies every constraint configured on this instance.

        All mismatches are collected before asserting, so one failure reports every discrepancy.

        :param request: the request sent by the code under test. It must expose ``url`` and
            ``method``, and ``query``/``headers``/``body`` mappings read with ``.get``
            (``body`` may be None).
        :raises AssertionError: if any configured constraint is not met
        """
        discrepancies = []

        def add_discrepancy(name, expected, actual):
            discrepancies.append("{}:\n\t expected: {}\n\t actual: {}".format(name, expected, actual))

        if self.base_url and self.base_url != request.url.split("?")[0]:
            add_discrepancy("base url", self.base_url, request.url)

        if self.url and self.url != request.url:
            add_discrepancy("url", self.url, request.url)

        if self.url_substring and self.url_substring not in request.url:
            add_discrepancy("url substring", self.url_substring, request.url)

        parsed = parse.urlparse(request.url)
        if self.authority and parsed.netloc != self.authority:
            add_discrepancy("authority", self.authority, parsed.netloc)

        if self.method and request.method != self.method:
            add_discrepancy("method", self.method, request.method)

        for param, expected_value in self.required_params.items():
            actual_value = request.query.get(param)
            if actual_value != expected_value:
                add_discrepancy(param, expected_value, actual_value)

        for header, expected_value in self.required_headers.items():
            actual_value = request.headers.get(header)
            if header.lower() == "user-agent":
                # UserAgentPolicy appends the value of $AZURE_HTTP_USER_AGENT, which is set in
                # pipelines, so we accept a user agent which merely contains the expected value.
                # A missing header (None) must be reported, not crash the containment test.
                if actual_value is None or expected_value not in actual_value:
                    add_discrepancy("user-agent", "contains " + expected_value, actual_value)
            elif actual_value != expected_value:
                add_discrepancy(header, expected_value, actual_value)

        if self.required_data:
            # a request without a body (None) simply has none of the required fields
            body = request.body or {}
            for field, expected_value in self.required_data.items():
                actual_value = body.get(field)
                if actual_value != expected_value:
                    add_discrepancy("form field {}".format(field), expected_value, actual_value)

        assert not discrepancies, "Unexpected request\n\t" + "\n\t".join(discrepancies)
def mock_response(status_code=200, headers=None, json_payload=None):
    """Build a ``mock.Mock`` standing in for an HTTP response.

    :param int status_code: value of the mock's ``status_code`` attribute
    :param dict headers: response headers; copied, so the caller's dict is never modified
    :param json_payload: when not None, ``text()`` returns it serialized with ``json.dumps``,
        and both the "content-type" header and ``content_type`` are set to "application/json"
    :returns: the configured mock response
    """
    # Copy so that adding "content-type" below cannot leak into a dict the caller reuses.
    response = mock.Mock(status_code=status_code, headers=dict(headers or {}))
    if json_payload is not None:
        # text() accepts (and ignores) an optional encoding, so it can be called either way
        response.text = lambda encoding=None: json.dumps(json_payload)
        response.headers["content-type"] = "application/json"
        response.content_type = "application/json"
    return response
def validating_transport(requests, responses):
    """Build a mock transport that checks each sent request against an expected one.

    The n-th call to the transport's ``send`` is validated with ``requests[n].assert_matches``
    and answered with ``responses[n]``.

    :param requests: expected requests, in the order they should be sent
    :param responses: responses to return, paired positionally with ``requests``
    :returns: a ``mock.Mock`` whose ``send(request, **kwargs)`` performs the validation
    :raises ValueError: if ``requests`` and ``responses`` differ in length
    """
    if len(requests) != len(responses):
        raise ValueError("each request must have one response")

    expected_count = len(requests)
    # iter() yields a one-shot iterator whether zip returns a list (Python 2) or an iterator (3)
    sessions = iter(zip(requests, responses))

    def validate_request(request, **kwargs):  # pylint:disable=unused-argument
        try:
            expected_request, response = next(sessions)
        except StopIteration:
            # Report an extra request as a test failure. A bare StopIteration is cryptic, and it
            # becomes RuntimeError if it escapes into a generator (PEP 479).
            raise AssertionError(
                "Unexpected request: the transport expected only {} request(s)".format(expected_count)
            )
        expected_request.assert_matches(request)
        return response

    return mock.Mock(send=validate_request)
|