1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
|
# coding: utf-8
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import asyncio
import functools
from azure_devtools.scenario_tests.patches import mock_in_unit_test
from azure_devtools.scenario_tests.utilities import trim_kwargs_from_test_function
from azure.core.credentials import AccessToken
from .testcase import StorageTestCase
# Log record layout: timestamp, logger name left-aligned in a 20-char field,
# level name left-aligned in a 5-char field, then the message. Not referenced in
# this module; presumably consumed by logging setup elsewhere in the test suite.
LOGGING_FORMAT = "%(asctime)s %(name)-20s %(levelname)-5s %(message)s"
class AsyncFakeTokenCredential(object):
    """Stand-in async token credential for non-live (playback) test runs.

    Provides an awaitable ``get_token`` like an azure-core async credential, but
    always returns the same placeholder
    :class:`~azure.core.credentials.AccessToken` instead of contacting an
    identity provider.
    """

    def __init__(self):
        # Fixed placeholder token; expires_on=0 since a fake has no real expiry.
        self.token = AccessToken("YOU SHALL NOT PASS", 0)

    async def get_token(self, *scopes, **kwargs):
        """Return the placeholder token.

        :param str scopes: Requested scopes; ignored.
        :keyword kwargs: Additional options a caller may pass to a credential
            (for example ``claims`` or ``tenant_id``); accepted and ignored so
            the fake does not raise ``TypeError`` when they are supplied.
        :returns: The fixed placeholder token created in ``__init__``.
        :rtype: ~azure.core.credentials.AccessToken
        """
        return self.token
def patch_play_responses(unit_test):
    """Patch vcrpy's aiohttp ``play_responses`` for the duration of ``unit_test``.

    Fixes a bug affecting blob tests by applying
    https://github.com/kevin1024/vcrpy/pull/511 to vcrpy 3.0.0. The replacement
    replays a recorded redirect chain by following ``Location`` headers and
    attaches the intermediate 3xx responses as the final response's history.
    Unlike the upstream patch, the redirect target is resolved with ``URL.join``
    so absolute locations and locations carrying a query string are handled.

    :param unit_test: The running test case; the patch is installed through
        ``mock_in_unit_test`` against it.
    :returns: None when vcrpy's aiohttp stubs cannot be imported (nothing to
        patch); otherwise whatever ``mock_in_unit_test`` returns.
    """
    try:
        from vcr.stubs.aiohttp_stubs import _serialize_headers, build_response, Request, URL
    except ImportError:
        # vcrpy (or its aiohttp stubs) is unavailable, so there is nothing to patch.
        return None

    def fixed_play_responses(cassette, vcr_request):
        history = []
        vcr_response = cassette.play_response(vcr_request)
        response = build_response(vcr_request, vcr_response, history)

        # Keep replaying recorded redirects until a non-3xx response, or a 3xx
        # response without a Location header, is reached.
        while 300 <= response.status <= 399:
            if "location" not in response.headers:
                break

            # join() resolves the Location per RFC 3986 (relative path, absolute
            # URL, query string). with_path() would percent-encode a '?' into the
            # path and treat an absolute URL as a path.
            next_url = URL(response.url).join(URL(response.headers["location"]))

            # Build a stub request purely to look up the best-matching request
            # recorded in the cassette.
            vcr_request = Request("GET", str(next_url), None, _serialize_headers(response.request_info.headers))
            vcr_request = cassette.find_requests_with_most_matches(vcr_request)[0][0]

            history.append(response)
            vcr_response = cassette.play_response(vcr_request)
            response = build_response(vcr_request, vcr_response, history)

        return response

    return mock_in_unit_test(unit_test, "vcr.stubs.aiohttp_stubs.play_responses", fixed_play_responses)
class AsyncStorageTestCase(StorageTestCase):
    """Storage test case base for tests that exercise the async (aio) clients."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Register the vcrpy aiohttp redirect fix as a replay patch.
        self.replay_patches.append(patch_play_responses)

    @staticmethod
    def _get_or_create_event_loop():
        """Return the thread's current event loop, installing a new one if needed.

        Newer Python versions no longer create a loop implicitly in
        ``get_event_loop()``, and a loop closed by earlier code cannot run
        anything, so both cases fall back to a fresh loop that is set as current.
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop

    @staticmethod
    def await_prepared_test(test_fn):
        """Synchronous wrapper for async test methods.

        Used to avoid making changes upstream to AbstractPreparer, which doesn't
        await the functions it wraps. The returned function runs ``test_fn`` to
        completion on the current event loop and returns its result.
        """

        @functools.wraps(test_fn)
        def run(test_class_instance, *args, **kwargs):
            # Presumably removes kwargs (e.g. preparer outputs) that test_fn
            # does not declare; it mutates ``kwargs`` in place.
            trim_kwargs_from_test_function(test_fn, kwargs)
            loop = AsyncStorageTestCase._get_or_create_event_loop()
            # Forward positional arguments too instead of silently dropping them.
            return loop.run_until_complete(test_fn(test_class_instance, *args, **kwargs))

        return run

    def generate_oauth_token(self):
        """Return an async credential appropriate for the current test mode.

        Live runs get an ``azure.identity.aio.ClientSecretCredential`` built from
        the TENANT_ID, CLIENT_ID and CLIENT_SECRET settings; non-live runs get
        the placeholder credential from ``generate_fake_token``.
        """
        if self.is_live:
            # Imported inside the live branch so non-live runs never import it.
            from azure.identity.aio import ClientSecretCredential
            return ClientSecretCredential(
                self.get_settings_value("TENANT_ID"),
                self.get_settings_value("CLIENT_ID"),
                self.get_settings_value("CLIENT_SECRET"),
            )
        return self.generate_fake_token()

    def generate_fake_token(self):
        """Return a new placeholder async credential for non-live runs."""
        return AsyncFakeTokenCredential()
|