# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import urllib.parse as url_parse
from azure.core.exceptions import ResourceNotFoundError
from azure.core.pipeline.policies import ContentDecodePolicy
from azure.core.pipeline.transport import AioHttpTransport
try:
import httpx
AsyncHTTPXTransport = httpx.AsyncHTTPTransport
except ImportError:
httpx = None
AsyncHTTPXTransport = None
from ..helpers import is_live_and_not_recording, trim_kwargs_from_test_function
from ..proxy_testcase import (
RecordedTransport,
_transform_args,
_transform_httpx_args,
get_test_id,
start_record_or_playback,
restore_httpx_response_url,
stop_record_or_playback,
)
def recorded_by_proxy_async(*transports):
    """
    Decorator for recording and playing back test proxy sessions in async tests.

    Args:
        *transports: Which transport(s) to record. Pass one or more comma separated RecordedTransport enum values
            (the matching string values of those enum members are accepted as well).
            - No args (default): Record AioHttpTransport.send calls (azure.core).
            - RecordedTransport.AZURE_CORE: Record AioHttpTransport.send calls. Same as the default above.
            - RecordedTransport.HTTPX: Record AsyncHTTPXTransport.handle_async_request calls.
            - RecordedTransport.AZURE_CORE, RecordedTransport.HTTPX: Record both transports.

    Returns:
        When used bare (``@recorded_by_proxy_async``), the wrapped test function. When called with arguments
        (``@recorded_by_proxy_async(...)``), a decorator to apply to the test function.

    Notes:
        - Values that match neither transport are ignored.
        - RecordedTransport.HTTPX is skipped (with a logged warning) when httpx could not be imported.
        - If nothing usable remains, AioHttpTransport.send is recorded.
        - Each transport is patched at most once, in the order it was first requested.

    Usages:
        from devtools_testutils.aio import recorded_by_proxy_async
        from devtools_testutils import RecordedTransport

        # If your test uses azure.core only network calls (default)
        @recorded_by_proxy_async
        async def test(...): ...

        # Explicitly enable azure.core recordings only (equivalent to the above)
        @recorded_by_proxy_async(RecordedTransport.AZURE_CORE)
        async def test(...): ...

        # If your test uses httpx only for network calls
        @recorded_by_proxy_async(RecordedTransport.HTTPX)
        async def test(...): ...

        # If your test uses both azure.core and httpx for network calls
        @recorded_by_proxy_async(RecordedTransport.AZURE_CORE, RecordedTransport.HTTPX)
        async def test(...): ...
    """
    azure_core_target = (AioHttpTransport, "send")

    # Bare decorator usage: @recorded_by_proxy_async -- the single positional argument is the test function itself.
    if len(transports) == 1 and callable(transports[0]):
        return _make_proxy_decorator_async([azure_core_target])(transports[0])

    # Parameterized decorator usage: @recorded_by_proxy_async(...)
    # If no transports were specified, default to azure.core.
    requested = transports or (RecordedTransport.AZURE_CORE,)

    transport_list = []
    for transport in requested:
        if transport == RecordedTransport.AZURE_CORE or (
            isinstance(transport, str) and transport == RecordedTransport.AZURE_CORE.value
        ):
            target = azure_core_target
        elif transport == RecordedTransport.HTTPX or (
            isinstance(transport, str) and transport == RecordedTransport.HTTPX.value
        ):
            if AsyncHTTPXTransport is None:
                # httpx is an optional dependency; make the skipped request visible instead of dropping it silently.
                logging.getLogger().warning(
                    "RecordedTransport.HTTPX was requested but httpx is not installed; "
                    "httpx traffic will not be recorded."
                )
                continue
            target = (AsyncHTTPXTransport, "handle_async_request")
        else:
            # Unrecognized values are ignored.
            continue

        # De-duplicate on the resolved (owner, method) pair: patching the same method twice would wrap the
        # already-wrapped call and rewrite every request twice.
        if target not in transport_list:
            transport_list.append(target)

    # If still no transports, fall back to azure.core.
    if not transport_list:
        transport_list = [azure_core_target]

    # Return the decorator that will be applied to the test function.
    return _make_proxy_decorator_async(transport_list)
def _make_proxy_decorator_async(transports):
    """Build a decorator that routes the given transports through the test proxy while a test runs.

    Args:
        transports: Iterable of ``(owner, method_name)`` pairs. For the duration of the decorated test,
            ``method_name`` on each ``owner`` is replaced by a wrapper that transforms the outgoing call
            arguments for the current recording; the original attribute is put back afterwards.

    Returns:
        A decorator that takes an async test function and returns the async ``record_wrap`` wrapper around it.
        ``record_wrap`` returns whatever the test function returns.

    Raises:
        ResourceNotFoundError: Re-raised from the test with a troubleshooting link prepended to the message.
    """

    def _decorator(test_func):
        async def record_wrap(*args, **kwargs):
            # Copy kwargs before handing them to trim_kwargs_from_test_function; its return value is unused
            # here, so it presumably edits the dict in place.
            trimmed_kwargs = dict(kwargs)
            trim_kwargs_from_test_function(test_func, trimmed_kwargs)

            # Early exit: no proxy session and no patching in this mode.
            if is_live_and_not_recording():
                return await test_func(*args, **trimmed_kwargs)

            test_id = get_test_id()
            recording_id, variables = start_record_or_playback(test_id)

            # Build a wrapper factory so each patched method closes over its own original
            def make_combined_call(original_transport_func, is_httpx=False):
                async def combined_call(*call_args, **call_kwargs):
                    if is_httpx:
                        adjusted_args, adjusted_kwargs = _transform_httpx_args(
                            recording_id, *call_args, **call_kwargs
                        )
                        result = await original_transport_func(*adjusted_args, **adjusted_kwargs)
                        restore_httpx_response_url(result)
                    else:
                        adjusted_args, adjusted_kwargs = _transform_args(recording_id, *call_args, **call_kwargs)
                        result = await original_transport_func(*adjusted_args, **adjusted_kwargs)

                        # rewrite request.url to the original upstream for LROs, etc.
                        parsed_result = url_parse.urlparse(result.request.url)
                        upstream_uri = url_parse.urlparse(result.request.headers["x-recording-upstream-base-uri"])
                        upstream_uri_dict = {"scheme": upstream_uri.scheme, "netloc": upstream_uri.netloc}
                        original_target = parsed_result._replace(**upstream_uri_dict).geturl()
                        result.request.url = original_target
                    return result

                return combined_call

            test_variables = None
            test_run = False
            originals = []

            try:
                # Patch inside the try block: if patching fails part-way, the finally clause still restores
                # whatever was already patched and stops the session started above.
                for owner, name in transports:
                    original = getattr(owner, name)
                    # Check if this is an httpx transport by comparing with httpx transport classes
                    is_httpx_transport = (AsyncHTTPXTransport is not None and owner is AsyncHTTPXTransport) or (
                        httpx is not None and owner.__module__.startswith("httpx")
                    )
                    setattr(owner, name, make_combined_call(original, is_httpx=is_httpx_transport))
                    originals.append((owner, name, original))

                try:
                    test_variables = await test_func(*args, variables=variables, **trimmed_kwargs)
                    test_run = True
                except TypeError as error:
                    # Only swallow the TypeError caused by passing `variables` to a test that can't take it.
                    if "unexpected keyword argument" in str(error) and "variables" in str(error):
                        logger = logging.getLogger()
                        logger.info(
                            "This test can't accept variables as input. "
                            "Accept `**kwargs` and/or a `variables` parameter to use recorded variables."
                        )
                    else:
                        raise

                # Retry without `variables` when the first attempt was rejected for that keyword.
                if not test_run:
                    test_variables = await test_func(*args, **trimmed_kwargs)

            except ResourceNotFoundError as error:
                error_body = ContentDecodePolicy.deserialize_from_http_generics(error.response)
                troubleshoot = (
                    "Playback failure -- for help resolving, see https://aka.ms/azsdk/python/test-proxy/troubleshoot."
                )
                # The decoded body is not guaranteed to be a mapping (it can be empty or non-JSON); only look
                # up the message keys when it is, so the real error is not masked by an AttributeError.
                message = None
                if isinstance(error_body, dict):
                    message = error_body.get("message") or error_body.get("Message")
                if not message:
                    message = getattr(error, "message", None)
                error_with_message = ResourceNotFoundError(
                    message=f"{troubleshoot} Error details:\n{message}",
                    response=error.response,
                )
                raise error_with_message from error

            finally:
                # restore in reverse order
                for owner, name, original in reversed(originals):
                    setattr(owner, name, original)
                stop_record_or_playback(test_id, recording_id, test_variables)

            return test_variables

        return record_wrap

    return _decorator
|