File: proxy_testcase_async.py

package info (click to toggle)
python-azure 20260113%2Bgit-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 786,404 kB
  • sloc: python: 6,519,100; ansic: 804; javascript: 287; sh: 204; makefile: 198; xml: 109
file content (185 lines) | stat: -rw-r--r-- 8,278 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import urllib.parse as url_parse

from azure.core.exceptions import ResourceNotFoundError
from azure.core.pipeline.policies import ContentDecodePolicy
from azure.core.pipeline.transport import AioHttpTransport

try:
    import httpx

    AsyncHTTPXTransport = httpx.AsyncHTTPTransport
except ImportError:
    httpx = None
    AsyncHTTPXTransport = None

from ..helpers import is_live_and_not_recording, trim_kwargs_from_test_function
from ..proxy_testcase import (
    RecordedTransport,
    _transform_args,
    _transform_httpx_args,
    get_test_id,
    start_record_or_playback,
    restore_httpx_response_url,
    stop_record_or_playback,
)


def recorded_by_proxy_async(*transports):
    """
    Decorator for recording and playing back test proxy sessions in async tests.

    Args:
        *transports: Which transport(s) to record. Pass one or more comma separated RecordedTransport enum values
            (or their string values).
            - No args (default): Record AioHttpTransport.send calls (azure.core).
            - RecordedTransport.AZURE_CORE: Record AioHttpTransport.send calls. Same as the default above.
            - RecordedTransport.HTTPX: Record AsyncHTTPXTransport.handle_async_request calls. Skipped when httpx
              is not installed.
            - RecordedTransport.AZURE_CORE, RecordedTransport.HTTPX: Record both transports.
            Repeated values are only patched once; unrecognized values are ignored with a warning. If nothing
            usable remains, azure.core is recorded.

    Usages:

      from devtools_testutils.aio import recorded_by_proxy_async
      from devtools_testutils import RecordedTransport

      # If your test uses azure.core only network calls (default)
      @recorded_by_proxy_async
      async def test(...): ...

      # Explicitly enable azure.core recordings only (equivalent to the above)
      @recorded_by_proxy_async(RecordedTransport.AZURE_CORE)
      async def test(...): ...

      # If your test uses httpx only for network calls
      @recorded_by_proxy_async(RecordedTransport.HTTPX)
      async def test(...): ...

      # If your test uses both azure.core and httpx for network calls
      @recorded_by_proxy_async(RecordedTransport.AZURE_CORE, RecordedTransport.HTTPX)
      async def test(...): ...
    """

    # Bare decorator usage: @recorded_by_proxy_async
    if len(transports) == 1 and callable(transports[0]):
        test_func = transports[0]
        transport_list = [(AioHttpTransport, "send")]
        return _make_proxy_decorator_async(transport_list)(test_func)

    # Parameterized decorator usage: @recorded_by_proxy_async(...)
    # If no transports specified, default to azure.core
    requested = transports or (RecordedTransport.AZURE_CORE,)

    # Resolve each requested value to an (owner class, method name) patch target, keeping the caller's order so
    # patching is deterministic (iterating a set would make it depend on hash randomization).
    transport_list = []
    for transport in requested:
        if transport == RecordedTransport.AZURE_CORE or (
            isinstance(transport, str) and transport == RecordedTransport.AZURE_CORE.value
        ):
            target = (AioHttpTransport, "send")
        elif transport == RecordedTransport.HTTPX or (
            isinstance(transport, str) and transport == RecordedTransport.HTTPX.value
        ):
            if AsyncHTTPXTransport is None:
                # httpx is not installed, so there is nothing to patch for this transport
                continue
            target = (AsyncHTTPXTransport, "handle_async_request")
        else:
            logging.getLogger(__name__).warning(
                "Ignoring unrecognized transport %r passed to recorded_by_proxy_async.", transport
            )
            continue

        # The same transport may be requested twice (e.g. as the enum member and as its string value). Patching a
        # method twice would stack two proxy wrappers and rewrite every request twice, so add each target once.
        if target not in transport_list:
            transport_list.append(target)

    # If still no transports, fall back to azure.core
    if not transport_list:
        transport_list = [(AioHttpTransport, "send")]

    # Return the decorator that will be applied to the test function
    return _make_proxy_decorator_async(transport_list)


def _make_proxy_decorator_async(transports):
    """Build a decorator that routes an async test's HTTP traffic through the test proxy.

    :param transports: ``(owner, method_name)`` pairs naming the transport methods to monkeypatch while each
        decorated test runs. Methods on httpx transports are rewritten with ``_transform_httpx_args``; all others
        with ``_transform_args``.
    :return: A decorator that wraps an async test function in a record/playback session.
    """

    def _decorator(test_func):
        async def record_wrap(*args, **kwargs):
            # Work on a copy so trimming never touches the caller's kwargs dict
            trimmed_kwargs = dict(kwargs)
            trim_kwargs_from_test_function(test_func, trimmed_kwargs)

            if is_live_and_not_recording():
                return await test_func(*args, **trimmed_kwargs)

            test_id = get_test_id()
            recording_id, variables = start_record_or_playback(test_id)

            # Build a wrapper factory so each patched method closes over its own original
            def make_combined_call(original_transport_func, is_httpx=False):
                async def combined_call(*call_args, **call_kwargs):
                    if is_httpx:
                        adjusted_args, adjusted_kwargs = _transform_httpx_args(recording_id, *call_args, **call_kwargs)
                        result = await original_transport_func(*adjusted_args, **adjusted_kwargs)
                        restore_httpx_response_url(result)
                    else:
                        adjusted_args, adjusted_kwargs = _transform_args(recording_id, *call_args, **call_kwargs)
                        result = await original_transport_func(*adjusted_args, **adjusted_kwargs)
                        # rewrite request.url to the original upstream for LROs, etc.
                        parsed_result = url_parse.urlparse(result.request.url)
                        upstream_uri = url_parse.urlparse(result.request.headers["x-recording-upstream-base-uri"])
                        upstream_uri_dict = {"scheme": upstream_uri.scheme, "netloc": upstream_uri.netloc}
                        original_target = parsed_result._replace(**upstream_uri_dict).geturl()
                        result.request.url = original_target
                    return result

                return combined_call

            test_variables = None
            test_run = False
            originals = []
            try:
                # Patch inside the try: if patching fails part-way, the finally block still restores whatever was
                # already patched and stops the proxy session instead of leaking both into later tests.
                for owner, name in transports:
                    original = getattr(owner, name)
                    # Check if this is an httpx transport by comparing with httpx transport classes
                    is_httpx_transport = (AsyncHTTPXTransport is not None and owner is AsyncHTTPXTransport) or (
                        httpx is not None and owner.__module__.startswith("httpx")
                    )
                    setattr(owner, name, make_combined_call(original, is_httpx=is_httpx_transport))
                    originals.append((owner, name, original))

                try:
                    test_variables = await test_func(*args, variables=variables, **trimmed_kwargs)
                    test_run = True
                except TypeError as error:
                    if "unexpected keyword argument" in str(error) and "variables" in str(error):
                        logger = logging.getLogger()
                        logger.info(
                            "This test can't accept variables as input. "
                            "Accept `**kwargs` and/or a `variables` parameter to use recorded variables."
                        )
                    else:
                        raise

                # The test rejected the `variables` keyword; run it again without it
                if not test_run:
                    test_variables = await test_func(*args, **trimmed_kwargs)

            except ResourceNotFoundError as error:
                raise _with_playback_troubleshooting(error) from error

            finally:
                # restore in reverse order
                for owner, name, original in reversed(originals):
                    setattr(owner, name, original)
                stop_record_or_playback(test_id, recording_id, test_variables)

            return test_variables

        return record_wrap

    return _decorator


def _with_playback_troubleshooting(error):
    """Return a copy of ``error`` whose message is prefixed with test-proxy troubleshooting guidance.

    The detail text is taken from the response body's ``message``/``Message`` field when the error carries a
    response whose body deserializes to a dict; otherwise ``str(error)`` is used, so the original problem is never
    masked by an ``AttributeError`` raised while building the diagnostic.

    :param error: The ResourceNotFoundError raised while the test ran.
    :type error: ~azure.core.exceptions.ResourceNotFoundError
    :return: A new ResourceNotFoundError carrying the same response.
    :rtype: ~azure.core.exceptions.ResourceNotFoundError
    """
    troubleshoot = "Playback failure -- for help resolving, see https://aka.ms/azsdk/python/test-proxy/troubleshoot."
    message = None
    if error.response is not None:
        error_body = ContentDecodePolicy.deserialize_from_http_generics(error.response)
        if isinstance(error_body, dict):
            message = error_body.get("message") or error_body.get("Message")
    if not message:
        message = str(error)
    return ResourceNotFoundError(
        message=f"{troubleshoot} Error details:\n{message}",
        response=error.response,
    )