File: proxy_fixtures.py

package info (click to toggle)
python-azure 20250603+git-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 851,724 kB
  • sloc: python: 7,362,925; ansic: 804; javascript: 287; makefile: 195; sh: 145; xml: 109
file content (323 lines) | stat: -rw-r--r-- 15,589 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from inspect import iscoroutinefunction
import logging
import os
from typing import TYPE_CHECKING
import urllib.parse as url_parse

import pytest

from azure.core.exceptions import ResourceNotFoundError
from azure.core.pipeline.policies import ContentDecodePolicy

# the functions we patch
try:
    from azure.core.pipeline.transport import RequestsTransport
except ImportError:
    # RequestsTransport may not be importable (it relies on the `requests` package). Only the synchronous
    # helpers redirect_traffic/restore_traffic reference it, so a missing import is tolerated here. Catch only
    # ImportError so that unrelated failures (or KeyboardInterrupt/SystemExit) are not silently swallowed.
    pass

from .helpers import get_test_id, is_live, is_live_and_not_recording
from .proxy_testcase import start_record_or_playback, stop_record_or_playback, transform_request
from .proxy_startup import test_proxy
from .sanitizers import add_batch_sanitizers, add_general_string_sanitizer, Sanitizer

if TYPE_CHECKING:
    from typing import Any, Callable, Dict, Optional, Tuple
    from pytest import FixtureRequest

# In pytest-asyncio>=0.19.0 async fixtures need to be marked with pytest_asyncio.fixture, not pytest.fixture, by default
# pytest_asyncio.fixture is only recently available (~0.17.0), so we need to account for an import error
try:
    from pytest_asyncio import fixture as async_fixture
except ImportError:
    from pytest import fixture as async_fixture


# Module-level logger; records propagate to the root logger's handlers, but can be filtered/configured per module
_LOGGER = logging.getLogger(__name__)


class EnvironmentVariableSanitizer:
    """Fetches environment variables and registers test proxy sanitizers for their real values.

    Every variable is associated with a fake value. In live mode the real environment value is returned; in
    playback the fake value is returned. Whenever the variable has a real value, a general string sanitizer is
    registered so the real value is replaced with the fake one in recordings.
    """

    def __init__(self) -> None:
        # Maps environment variable names to the fake values registered for them; accumulates across calls
        self._fake_values: "Dict[str, str]" = {}

    def sanitize(self, variable: str, value: str) -> "Optional[str]":
        """Registers a sanitizer that replaces the value of the specified environment variable with the provided value.

        :param str variable: Name of the environment variable to sanitize.
        :param str value: Value to sanitize the environment variable's value with.

        :returns: The real value of `variable` in live mode (None if it is unset), or the sanitized value in playback.
        """
        self._fake_values[variable] = value
        real_value = os.getenv(variable)
        if real_value:
            add_general_string_sanitizer(target=real_value, value=value, function_scoped=True)
        else:
            _LOGGER.info(
                "No value for %s was found, so a sanitizer could not be registered for the variable.", variable
            )

        return real_value if is_live() else value

    def sanitize_batch(self, variables: "Dict[str, str]") -> "Dict[str, Optional[str]]":
        """Registers sanitizers that replace the values of multiple environment variables with the provided values.

        :param variables: A dictionary mapping environment variable names to values they should be sanitized with.
            For example: {"SERVICE_CLIENT_ID": "fake_client_id", "SERVICE_ENDPOINT": "https://fake-endpoint.azure.net"}

        :returns: A new dictionary mapping the variables in `variables` to their real values in live mode (None for
            unset variables), or to their sanitized values in playback.
        """
        real_values = {}
        sanitizers = {Sanitizer.GENERAL_STRING: []}

        for variable, fake_value in variables.items():
            self._fake_values[variable] = fake_value
            real_value = os.getenv(variable)
            real_values[variable] = real_value
            # If the variable has a value to be sanitized, add a general string sanitizer for it to our batch request
            if real_value:
                sanitizers[Sanitizer.GENERAL_STRING].append({"target": real_value, "value": fake_value})

        add_batch_sanitizers(sanitizers)
        if is_live():
            return real_values
        # Return only this batch's fake values, as a fresh dict: the accumulated `_fake_values` may contain
        # variables from earlier calls and must not be mutated by callers.
        return {variable: self._fake_values[variable] for variable in variables}

    def get(self, variable: str) -> "Optional[str]":
        """Returns the value of the specified environment variable in live mode, or the sanitized value in playback.

        :param str variable: Name of the environment variable to fetch the value of.

        :returns: The real value of `variable` in live mode, or the sanitized value in playback. None if no value
            is available in the current mode.
        """
        return os.getenv(variable) if is_live() else self._fake_values.get(variable)


class VariableRecorder:
    """Gives a test access to the variables stored alongside its recording.

    The dictionary passed to the constructor is kept by reference (not copied), so values recorded through this
    object are visible to anyone else holding that dictionary.
    """

    def __init__(self, variables: "Dict[str, str]") -> None:
        self.variables = variables

    def get_or_record(self, variable: str, default: str) -> str:
        """Looks up `variable`, storing `default` for it first if it has no value yet.

        While recording, `get_or_record("a", "b")` stores "b" as the value of `a` and returns "b". During playback
        the previously recorded value of `a` is returned instead. Semantically this mirrors `dict.setdefault`:
        https://docs.python.org/library/stdtypes.html#dict.setdefault.

        :param str variable: Name of the variable to look up or record.
        :param str default: Value to record when `variable` has none yet. Must be a string.

        :raises ValueError: If `default` is not a string; the test proxy can only record string values.
        :returns: str
        """
        if isinstance(default, str):
            if variable not in self.variables:
                self.variables[variable] = default
            return self.variables[variable]
        raise ValueError('"default" must be a string. The test proxy cannot record non-string variable values.')


@pytest.fixture(scope="session")
def environment_variables(test_proxy: None) -> EnvironmentVariableSanitizer:
    """Session-scoped fixture providing an EnvironmentVariableSanitizer.

    Because the fixture is session-scoped, every test that requests it within one test session receives the very
    same EnvironmentVariableSanitizer instance.

    :param test_proxy: The fixture responsible for starting up the test proxy server.
    :type test_proxy: None

    :returns: An EnvironmentVariableSanitizer object, on which:
        - `sanitize(a, b)` sanitizes the value of environment variable `a` with value `b`
        - `sanitize_batch(dict)` sanitizes the values of all variables in dictionary `dict`
        - `get(a)` returns the value of environment variable `a` for the current context (live or playback mode)
        See the definition of EnvironmentVariableSanitizer in
        https://github.com/Azure/azure-sdk-for-python/blob/main/tools/azure-sdk-tools/devtools_testutils/proxy_fixtures.py
        for more details.
    """
    sanitizer = EnvironmentVariableSanitizer()
    return sanitizer


@async_fixture
async def recorded_test(test_proxy: None, request: "FixtureRequest") -> "Dict[str, Any]":
    """Fixture that redirects network requests to target the azure-sdk-tools test proxy.

    Use with recorded tests. For more details and usage examples, refer to
    https://github.com/Azure/azure-sdk-for-python/blob/main/doc/dev/test_proxy_migration_guide.md.

    Teardown is exception-safe: once a proxy session has been started it is always stopped, and once traffic has
    been redirected the original transport is always restored, even if a later step raises.

    :param test_proxy: The fixture responsible for starting up the test proxy server.
    :type test_proxy: None
    :param request: The built-in `request` fixture.
    :type request: ~pytest.FixtureRequest

    :yields: A dictionary containing information relevant to the currently executing test; its "variables" entry
        holds the recorded test variables. If the current test session is live but recording is disabled, the
        "variables" entry is an empty dictionary.
    """
    if is_live_and_not_recording():
        yield {"variables": {}}  # yield an empty set of variables since recordings aren't used
        return

    test_id, recording_id, variables = start_proxy_session()
    try:
        # True if the function requesting the fixture is an async test
        if iscoroutinefunction(request._pyfuncitem.function):
            original_transport_func = await redirect_async_traffic(recording_id)
            try:
                yield {"variables": variables}  # yield relevant test info and allow tests to run
            finally:
                # Undo the transport patch even if the generator is closed early, so it can't leak into other tests
                restore_async_traffic(original_transport_func, request)
        else:
            original_transport_func = redirect_traffic(recording_id)
            try:
                yield {"variables": variables}  # yield relevant test info and allow tests to run
            finally:
                restore_traffic(original_transport_func, request)
    finally:
        # Always close the proxy session, even if redirecting or restoring traffic raised
        stop_record_or_playback(test_id, recording_id, variables)


@pytest.fixture
def variable_recorder(recorded_test: "Dict[str, Any]") -> VariableRecorder:
    """Fixture wrapping the recorded test variables from `recorded_test` in a VariableRecorder.

    Requesting this fixture also activates `recorded_test`, so network traffic is routed through the test proxy.

    :param recorded_test: The fixture responsible for redirecting network traffic to target the test proxy. Its
        value is a dictionary describing the current test; the "variables" entry holds the variables recorded
        with the test.
    :type recorded_test: Dict[str, Any]

    :returns: A VariableRecorder object. `get_or_record(a, b)` on it returns the recorded value of `a` in playback
        mode, or records the value `b` in recording mode. See the definition of VariableRecorder in
        https://github.com/Azure/azure-sdk-for-python/blob/main/tools/azure-sdk-tools/devtools_testutils/proxy_fixtures.py
        for more details.
    """
    recorded_variables = recorded_test["variables"]
    return VariableRecorder(recorded_variables)


# ----------HELPERS----------


def start_proxy_session() -> "Tuple[str, str, Dict[str, str]]":
    """Starts a recording or playback session with the test proxy for the test that is currently running.

    :returns: A tuple `(test_id, recording_id, variables)`:
        - `test_id` identifies the currently executing test,
        - `recording_id` identifies the proxy session that was just started,
        - `variables` maps test variable names to string values; it is an empty dictionary if no variables were
          stored when the test was recorded.
    """
    current_test_id = get_test_id()
    session_recording_id, recorded_variables = start_record_or_playback(current_test_id)
    return current_test_id, session_recording_id, recorded_variables


async def redirect_async_traffic(recording_id: str) -> "Callable":
    """Patches `AioHttpTransport.send` so that asynchronous requests are routed through the test proxy.

    :param str recording_id: Recording ID of the currently executing test.

    :returns: The unpatched `AioHttpTransport.send`, so that it can be put back once the test finishes.
    """
    from azure.core.pipeline.transport import AioHttpTransport

    original_send = AioHttpTransport.send

    async def proxied_send(*args, **kwargs):
        # The patch is installed on the class, so args[0] is the transport instance and args[1] the request being
        # sent. transform_request's return value is unused, so it is relied on to rewrite the request in place.
        transform_request(args[1], recording_id)
        response = await original_send(*args, **kwargs)

        # Swap the proxy's scheme/host in the request URL for those of x-recording-upstream-base-uri, so the request
        # looks like it went to the original endpoint instead of the proxy. Without this, things like LROPollers can
        # get broken by polling the wrong endpoint.
        proxied_url = url_parse.urlparse(response.request.url)
        upstream = url_parse.urlparse(response.request.headers["x-recording-upstream-base-uri"])
        response.request.url = proxied_url._replace(scheme=upstream.scheme, netloc=upstream.netloc).geturl()
        return response

    AioHttpTransport.send = proxied_send
    return original_send


def redirect_traffic(recording_id: str) -> "Callable":
    """Patches `RequestsTransport.send` so that synchronous requests are routed through the test proxy.

    :param str recording_id: Recording ID of the currently executing test.

    :returns: The unpatched `RequestsTransport.send`, so that it can be put back once the test finishes.
    """
    original_send = RequestsTransport.send

    def proxied_send(*args, **kwargs):
        # The patch is installed on the class, so args[0] is the transport instance and args[1] the request being
        # sent. transform_request's return value is unused, so it is relied on to rewrite the request in place.
        transform_request(args[1], recording_id)
        response = original_send(*args, **kwargs)

        # Swap the proxy's scheme/host in the request URL for those of x-recording-upstream-base-uri, so the request
        # looks like it went to the original endpoint instead of the proxy. Without this, things like LROPollers can
        # get broken by polling the wrong endpoint.
        proxied_url = url_parse.urlparse(response.request.url)
        upstream = url_parse.urlparse(response.request.headers["x-recording-upstream-base-uri"])
        response.request.url = proxied_url._replace(scheme=upstream.scheme, netloc=upstream.netloc).geturl()
        return response

    RequestsTransport.send = proxied_send
    return original_send


def restore_async_traffic(original_transport_func: "Callable", request: "FixtureRequest") -> None:
    """Resets asynchronous network traffic to no longer target the test proxy.

    If the test failed with a ResourceNotFoundError (which during playback indicates a recording mismatch), the
    test proxy's error message is logged. Extracting that message is best-effort and never raises.

    :param original_transport_func: The original transport function used by the currently executing test.
    :type original_transport_func: Callable
    :param request: The built-in `request` pytest fixture.
    :type request: ~pytest.FixtureRequest
    """
    from azure.core.pipeline.transport import AioHttpTransport

    AioHttpTransport.send = original_transport_func  # test finished running -- tear down

    # Exceptions are logged here instead of being raised because of how pytest handles error raising from inside
    # fixtures and hooks. Raising from a fixture raises an error in addition to the test failure report, and the
    # test proxy error is logged before the test failure output (making it difficult to find in pytest output).
    # Raising from a hook isn't allowed, and produces an internal error that disrupts test execution.
    # ResourceNotFoundErrors during playback indicate a recording mismatch
    error = getattr(request.node, "test_error", None)
    if not isinstance(error, ResourceNotFoundError):
        return

    message = None
    try:
        error_body = ContentDecodePolicy.deserialize_from_http_generics(error.response)
    except Exception:  # pylint: disable=broad-except
        # Diagnostic-only path: a missing, empty, or undecodable response must not break test teardown
        error_body = None
    # The body may be None (empty) or a plain string (non-JSON content), neither of which supports .get()
    if isinstance(error_body, dict):
        message = error_body.get("message") or error_body.get("Message")
    _LOGGER.error("\n\n-----Test proxy playback error:-----\n\n%s", message or error)


def restore_traffic(original_transport_func: "Callable", request: "FixtureRequest") -> None:
    """Resets network traffic to no longer target the test proxy.

    If the test failed with a ResourceNotFoundError (which during playback indicates a recording mismatch), the
    test proxy's error message is logged. Extracting that message is best-effort and never raises.

    :param original_transport_func: The original transport function used by the currently executing test.
    :type original_transport_func: Callable
    :param request: The built-in `request` pytest fixture.
    :type request: ~pytest.FixtureRequest
    """
    RequestsTransport.send = original_transport_func  # test finished running -- tear down

    # Exceptions are logged here instead of being raised because of how pytest handles error raising from inside
    # fixtures and hooks. Raising from a fixture raises an error in addition to the test failure report, and the
    # test proxy error is logged before the test failure output (making it difficult to find in pytest output).
    # Raising from a hook isn't allowed, and produces an internal error that disrupts test execution.
    # ResourceNotFoundErrors during playback indicate a recording mismatch
    error = getattr(request.node, "test_error", None)
    if not isinstance(error, ResourceNotFoundError):
        return

    message = None
    try:
        error_body = ContentDecodePolicy.deserialize_from_http_generics(error.response)
    except Exception:  # pylint: disable=broad-except
        # Diagnostic-only path: a missing, empty, or undecodable response must not break test teardown
        error_body = None
    # The body may be None (empty) or a plain string (non-JSON content), neither of which supports .get()
    if isinstance(error_body, dict):
        message = error_body.get("message") or error_body.get("Message")
    _LOGGER.error("\n\n-----Test proxy playback error:-----\n\n%s", message or error)