| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
 100
 101
 102
 
 | # ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import json
from urllib import parse
try:
    from unittest import mock
except ImportError:  # python < 3.3
    import mock  # type: ignore
class Request:
    """Describes the HTTP request a test expects a transport to send.

    Every constructor argument is optional; :meth:`assert_matches` checks only the
    properties that were given.

    :param str base_url: expected URL with everything from the first ``?`` removed
    :param str url: expected full URL, query string included
    :param str authority: expected network location (``netloc``) of the URL
    :param str url_substring: text the URL must contain
    :param str method: expected HTTP method
    :param dict required_headers: headers that must be present with exactly these values,
        except ``User-Agent``, whose actual value need only contain the expected one
    :param dict required_data: fields that must appear in the request body with these values
    :param dict required_params: query parameters that must have these values
    """

    def __init__(
        self,
        base_url=None,
        url=None,
        authority=None,
        url_substring=None,
        method=None,
        required_headers=None,
        required_data=None,
        required_params=None,
    ):
        self.authority = authority
        self.base_url = base_url
        self.method = method
        self.url = url
        self.url_substring = url_substring
        self.required_headers = required_headers or {}
        self.required_data = required_data or {}
        self.required_params = required_params or {}

    def assert_matches(self, request):
        """Assert that ``request`` meets every expectation configured on this instance.

        Mismatches are collected first, so one failure reports all of them together.

        :param request: the request sent through the transport. It must expose ``url``,
            ``method``, ``headers`` and ``query``, and ``body`` when ``required_data`` is set.
            NOTE(review): ``headers``/``query``/``body`` are assumed to be mappings
            supporting ``.get``; confirm against the real request type.
        :raises AssertionError: if any expectation is not met
        """
        discrepancies = []

        def add_discrepancy(name, expected, actual):
            discrepancies.append("{}:\n\t expected: {}\n\t   actual: {}".format(name, expected, actual))

        if self.base_url and self.base_url != request.url.split("?")[0]:
            add_discrepancy("base url", self.base_url, request.url)

        if self.url and self.url != request.url:
            add_discrepancy("url", self.url, request.url)

        if self.url_substring and self.url_substring not in request.url:
            add_discrepancy("url substring", self.url_substring, request.url)

        parsed = parse.urlparse(request.url)
        if self.authority and parsed.netloc != self.authority:
            add_discrepancy("authority", self.authority, parsed.netloc)

        if self.method and request.method != self.method:
            add_discrepancy("method", self.method, request.method)

        for param, expected_value in self.required_params.items():
            actual_value = request.query.get(param)
            if actual_value != expected_value:
                add_discrepancy(param, expected_value, actual_value)

        for header, expected_value in self.required_headers.items():
            actual_value = request.headers.get(header)
            # UserAgentPolicy appends the value of $AZURE_HTTP_USER_AGENT, which is set in
            # pipelines, so we accept a user agent which merely contains the expected value.
            if header.lower() == "user-agent":
                # A missing header must be reported, not crash on `in None`.
                if actual_value is None or expected_value not in actual_value:
                    add_discrepancy("user-agent", "contains " + expected_value, actual_value)
            elif actual_value != expected_value:
                add_discrepancy(header, expected_value, actual_value)

        if self.required_data:
            # A request without a body fails every required field rather than
            # raising AttributeError on None.
            body = request.body or {}
            for field, expected_value in self.required_data.items():
                actual_value = body.get(field)
                if actual_value != expected_value:
                    add_discrepancy("form field {}".format(field), expected_value, actual_value)

        assert not discrepancies, "Unexpected request\n\t" + "\n\t".join(discrepancies)
def mock_response(status_code=200, headers=None, json_payload=None):
    """Build a mock HTTP response for use with a mocked transport.

    :param int status_code: value of the response's ``status_code`` attribute
    :param dict headers: response headers; copied into a new plain dict, so the caller's
        mapping is never modified
    :param json_payload: if not None, serialized with :func:`json.dumps` and returned by
        ``response.text()``; ``content-type`` is then set to ``application/json`` in both
        ``response.headers`` and ``response.content_type``
    :returns: a :class:`mock.Mock` configured as described
    """
    # Copy the headers: the content-type assignment below must not leak into a dict
    # the caller may reuse for other responses.
    response = mock.Mock(status_code=status_code, headers=dict(headers or {}))
    if json_payload is not None:
        # text() accepts and ignores an ``encoding`` argument, presumably to match the
        # signature of the real response's text method.
        response.text = lambda encoding=None: json.dumps(json_payload)
        response.headers["content-type"] = "application/json"
        response.content_type = "application/json"
    return response
def validating_transport(requests, responses):
    """Create a mock transport that checks each sent request against an expectation.

    The n-th call to ``send`` is validated with ``requests[n].assert_matches(request)``;
    if that passes, it returns ``responses[n]``.

    :param requests: expected requests, i.e. objects with an ``assert_matches`` method
        such as :class:`Request`, in the order they should be sent
    :param responses: responses to return, one per expected request, in the same order
    :raises ValueError: if ``requests`` and ``responses`` differ in length
    :returns: a :class:`mock.Mock` whose ``send`` performs the validation. Any extra
        ``send`` call beyond the expected ones raises ``AssertionError``.
    """
    if len(requests) != len(responses):
        raise ValueError("each request must have one response")

    # iter() gives a one-shot iterator whether zip returns a list or an iterator.
    sessions = iter(zip(requests, responses))

    def validate_request(request, **kwargs):  # pylint:disable=unused-argument
        session = next(sessions, None)
        if session is None:
            # A bare StopIteration escaping here would be confusing, and would become a
            # RuntimeError if raised inside a generator, so fail like a mismatch instead.
            raise AssertionError("Unexpected request: no more requests expected, got {!r}".format(request))
        expected_request, response = session
        expected_request.assert_matches(request)
        return response

    return mock.Mock(send=validate_request)
 |