File: helpers.py

package info (click to toggle)
python-azure 20250603+git-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 851,724 kB
  • sloc: python: 7,362,925; ansic: 804; javascript: 287; makefile: 195; sh: 145; xml: 109
file content (102 lines) | stat: -rw-r--r-- 3,935 bytes parent folder | download | duplicates (7)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import json
from urllib import parse

try:
    from unittest import mock
except ImportError:  # python < 3.3
    import mock  # type: ignore


class Request:
    """An expected HTTP request, compared against an actual one by :meth:`assert_matches`.

    Each criterion is optional; a criterion left as ``None`` (or empty) is not checked.

    :param str base_url: expected URL with its query string (everything from ``?``) removed
    :param str url: exact expected URL, query string included
    :param str authority: expected network location (``netloc``) of the URL
    :param str url_substring: text the URL must contain
    :param str method: expected HTTP method
    :param dict required_headers: header name -> expected value. For ``User-Agent`` the expected
        value only has to be contained in the actual header.
    :param dict required_data: body field name -> expected value
    :param dict required_params: query parameter name -> expected value
    """

    def __init__(
        self,
        base_url=None,
        url=None,
        authority=None,
        url_substring=None,
        method=None,
        required_headers=None,
        required_data=None,
        required_params=None,
    ):
        self.authority = authority
        self.base_url = base_url
        self.method = method
        self.url = url
        self.url_substring = url_substring
        self.required_headers = required_headers or {}
        self.required_data = required_data or {}
        self.required_params = required_params or {}

    def assert_matches(self, request):
        """Assert that ``request`` meets every criterion configured on this instance.

        All mismatches are collected first and reported together in a single assertion message.

        :param request: the request actually sent. ``url`` and ``method`` are always read;
            ``query``, ``headers`` and ``body`` are read only when params, headers or data are
            required (presumably an azure-core ``HttpRequest`` -- confirm against callers)
        :raises AssertionError: if any configured criterion does not match
        """
        discrepancies = []

        def add_discrepancy(name, expected, actual):
            discrepancies.append("{}:\n\t expected: {}\n\t   actual: {}".format(name, expected, actual))

        if self.base_url and self.base_url != request.url.split("?")[0]:
            add_discrepancy("base url", self.base_url, request.url)

        if self.url and self.url != request.url:
            add_discrepancy("url", self.url, request.url)

        if self.url_substring and self.url_substring not in request.url:
            add_discrepancy("url substring", self.url_substring, request.url)

        parsed = parse.urlparse(request.url)
        if self.authority and parsed.netloc != self.authority:
            add_discrepancy("authority", self.authority, parsed.netloc)

        if self.method and request.method != self.method:
            add_discrepancy("method", self.method, request.method)

        for param, expected_value in self.required_params.items():
            actual_value = request.query.get(param)
            if actual_value != expected_value:
                add_discrepancy(param, expected_value, actual_value)

        for header, expected_value in self.required_headers.items():
            actual_value = request.headers.get(header)

            # UserAgentPolicy appends the value of $AZURE_HTTP_USER_AGENT, which is set in
            # pipelines, so we accept a user agent which merely contains the expected value.
            # A missing header must be reported as a discrepancy rather than crash the "in" test.
            if header.lower() == "user-agent":
                if actual_value is None or expected_value not in actual_value:
                    add_discrepancy("user-agent", "contains " + expected_value, actual_value)
            elif actual_value != expected_value:
                add_discrepancy(header, expected_value, actual_value)

        if self.required_data:
            # A request with no body is treated as lacking every required field.
            body = request.body or {}
            for field, expected_value in self.required_data.items():
                actual_value = body.get(field)
                if actual_value != expected_value:
                    # Name the field so the failure message says which one differed.
                    add_discrepancy("form field '{}'".format(field), expected_value, actual_value)

        assert not discrepancies, "Unexpected request\n\t" + "\n\t".join(discrepancies)


def mock_response(status_code=200, headers=None, json_payload=None):
    """Create a mock HTTP response.

    :param int status_code: value of the response's ``status_code`` attribute
    :param dict headers: response headers. The mapping is copied, so the caller's dict is
        never modified.
    :param json_payload: if not None, ``response.text()`` returns it serialized with
        :func:`json.dumps`, and both the ``content-type`` header and ``content_type`` attribute
        are set to ``application/json``
    :return: a :class:`mock.Mock` standing in for the response
    """
    # Copy the headers: adding "content-type" below must not leak into a dict the caller
    # may reuse for other responses.
    response = mock.Mock(status_code=status_code, headers=dict(headers or {}))
    if json_payload is not None:
        # Serialized on each call to text(); ``encoding`` is accepted only for signature
        # compatibility and is ignored.
        response.text = lambda encoding=None: json.dumps(json_payload)
        response.headers["content-type"] = "application/json"
        response.content_type = "application/json"
    return response


def validating_transport(requests, responses):
    """Build a mock transport that checks each sent request against an expectation.

    The n-th call to ``send`` validates its request with ``requests[n].assert_matches`` and
    returns ``responses[n]``.

    :param requests: expected requests in the order they should be sent; each must provide an
        ``assert_matches(request)`` method (for example :class:`Request`)
    :param responses: responses to return, paired positionally with ``requests``
    :raises ValueError: if ``requests`` and ``responses`` differ in length
    :return: a :class:`mock.Mock` whose ``send`` performs the validation. ``send`` raises
        AssertionError when a request doesn't match, or when more requests arrive than expected.
    """
    if len(requests) != len(responses):
        raise ValueError("each request must have one response")

    # zip is already a one-shot iterator, so next() can consume it directly.
    sessions = zip(requests, responses)

    def validate_request(request, **kwargs):  # pylint:disable=unused-argument
        # Sessions are 2-tuples, so None unambiguously means "exhausted". Report that as a test
        # failure instead of letting StopIteration escape into (and silently end) caller iteration.
        session = next(sessions, None)
        if session is None:
            raise AssertionError(
                "Unexpected request: received more than the {} expected request(s): {}".format(
                    len(requests), request
                )
            )
        expected_request, response = session
        expected_request.assert_matches(request)
        return response

    return mock.Mock(send=validate_request)