1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187
|
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
#--------------------------------------------------------------------------
from contextlib import contextmanager
import copy
import functools
import inspect
import json
import os
import os.path
import re
import time
import zlib

import vcr
from vcr.filters import decode_response

from .common_extendedtestcase import ExtendedTestCase
class TestMode(object):
    """Names of the supported test-execution modes plus predicates over them.

    Every predicate compares the given mode string case-insensitively.
    """

    none = 'None'  # unit tests only: no recording file is read or written
    playback = 'Playback'
    record = 'Record'
    live = 'Live'

    @staticmethod
    def is_playback(mode):
        """Return True if *mode* names the playback mode."""
        wanted = TestMode.playback.lower()
        return mode.lower() == wanted

    @staticmethod
    def need_recordingfile(mode):
        """Return True if *mode* reads (playback) or writes (record) a cassette."""
        modes_with_file = (TestMode.playback.lower(), TestMode.record.lower())
        return mode.lower() in modes_with_file

    @staticmethod
    def need_real_credentials(mode):
        """Return True if *mode* talks to the real service (live or record)."""
        modes_hitting_service = (TestMode.live.lower(), TestMode.record.lower())
        return mode.lower() in modes_hitting_service
class RecordingTestCase(ExtendedTestCase):
    """Test-case base that records HTTP traffic with vcrpy or replays it.

    The mode is read from ``testsettings_local.json`` in
    ``self.working_folder`` (falling back to playback). It decides whether
    :meth:`recording` replays a cassette, records a fresh one, or is a no-op.
    Subclasses override :meth:`_scrub` to strip secrets from recorded
    requests and responses.
    """

    def __init__(self, *args, **kwargs):
        super(RecordingTestCase, self).__init__(*args, **kwargs)

    def setUp(self):
        """Initialise ``self.test_mode`` and ``self.qualified_test_name``."""
        super(RecordingTestCase, self).setUp()
        self.init_test_mode()

        # example of qualified test name:
        #   test_mgmt_network.test_public_ip_addresses
        # i.e. <source file name without extension>.<test method name>
        _, filename = os.path.split(inspect.getsourcefile(type(self)))
        name, _ = os.path.splitext(filename)
        self.qualified_test_name = '{0}.{1}'.format(
            name,
            self._testMethodName,
        )

    def init_test_mode(self):
        """Set ``self.test_mode`` from the optional local settings file.

        Reads the ``mode`` entry of
        ``<working_folder>/testsettings_local.json``. If the file is absent,
        unreadable, not valid JSON, or lacks ``mode``, nothing is read.
        ``self.test_mode`` then becomes :attr:`TestMode.playback` unless it
        already holds a non-None value.
        """
        try:
            path = os.path.join(self.working_folder, 'testsettings_local.json')
            with open(path) as testsettings_local_file:
                test_settings = json.load(testsettings_local_file)
            self.test_mode = test_settings['mode']
        except (EnvironmentError, ValueError, KeyError, TypeError, AttributeError):
            # Best effort: a missing or malformed settings file just means
            # "use the default mode" below. Listing concrete errors (instead
            # of a bare except) lets KeyboardInterrupt/SystemExit propagate.
            # EnvironmentError covers IOError/OSError on Python 2 and 3;
            # ValueError covers JSON decoding errors.
            pass

        if getattr(self, 'test_mode', None) is None:
            self.test_mode = TestMode.playback

    def sleep(self, seconds):
        """Sleep for *seconds*, skipping the wait entirely in playback mode."""
        if not self.is_playback():
            time.sleep(seconds)

    def is_playback(self):
        """Return True if this test runs in playback mode."""
        return TestMode.is_playback(self.test_mode)

    def recording(self):
        """Return the context manager the test body should run inside.

        In playback and record modes this is a vcrpy cassette stored at
        ``<working_folder>/recordings/<qualified_test_name>.yaml``. Playback
        uses ``record_mode='none'``; record mode uses ``record_mode='all'``.
        Both modes drop the ``authorization`` header and run requests and
        responses through the scrubbing hooks. In every other mode a no-op
        context manager is returned.
        """
        if TestMode.need_recordingfile(self.test_mode):
            cassette_name = '{0}.yaml'.format(self.qualified_test_name)

            my_vcr = vcr.VCR(
                before_record_request=self._scrub_sensitive_request_info,
                before_record_response=self._scrub_sensitive_response_info,
                record_mode='none' if TestMode.is_playback(self.test_mode) else 'all',
            )

            self.assertIsNotNone(self.working_folder)
            return my_vcr.use_cassette(
                os.path.join(self.working_folder, 'recordings', cassette_name),
                filter_headers=['authorization'],
            )
        else:
            return self._nop_context_manager()

    def get_resource_name(self, name):
        """Return *name* followed by a checksum of the qualified test name.

        Each test gets different resource names, but a given test gets the
        same name on every run, which playback needs. A short 32-bit
        checksum keeps names within the length limits most resources have.
        Side effect: the checksum is stored in ``self.checksum``.
        """
        # The checksum is Adler-32 (zlib.adler32), not CRC-32. Do not change
        # the algorithm: names derived from it are embedded in recordings.
        self.checksum = zlib.adler32(self.qualified_test_name.encode()) & 0xffffffff
        # '{:x}' yields the same lowercase hex digits as hex()[2:] but never
        # the trailing 'L' that hex() adds to Python 2 longs.
        return '{0}{1:x}'.format(name, self.checksum)

    def _scrub_sensitive_request_info(self, request):
        """vcrpy ``before_record_request`` hook; mutates and returns *request*.

        In every mode, collapses ``//`` in the URI to ``/`` except right after
        a ``:`` (the scheme separator). Outside playback it also passes the
        URI and, when present, the body through :meth:`_scrub`.
        """
        # WARNING: for an unclear URL-parsing reason, URIs sometimes contain '//':
        # - Python 2.7 for 2.7/3.3/3.4 (...Microsoft.Compute//availabilitySets...)
        # - Python 3.5 (...Microsoft.Compute/availabilitySets...)
        # Presumably normalised so recordings match across Python versions.
        request.uri = re.sub(r'(?<!:)//', '/', request.uri)
        if not TestMode.is_playback(self.test_mode):
            request.uri = self._scrub(request.uri)
            if request.body is not None:
                request.body = self._scrub(request.body)
        return request

    def _scrub_sensitive_response_info(self, response):
        """vcrpy ``before_record_response`` hook.

        In playback *response* is returned untouched. Otherwise a deep copy
        is decoded with ``decode_response``. In the copy, every
        ``retry-after`` header becomes ``'0'`` and every other header value
        and the body string go through :meth:`_scrub`. The copy is returned.
        """
        if TestMode.is_playback(self.test_mode):
            return response

        # We need to make a copy because vcr doesn't make one for us.
        # Without this, changing the contents of the dicts would change
        # the contents returned to the caller - not just the contents
        # getting saved to disk. That would be a problem with headers
        # such as 'location', often used in the request uri of a
        # subsequent service call.
        response = copy.deepcopy(response)
        # decode_response is supposed to make a copy but does not do it
        # correctly: https://github.com/kevin1024/vcrpy/issues/264
        response = decode_response(response)

        def internal_scrub(key, val):
            # Recorded retry delays are zeroed, presumably so playback does
            # not wait while polling.
            if key.lower() == 'retry-after':
                return '0'
            return self._scrub(val)

        headers = response.get('headers')
        if headers:
            # Iterate over a snapshot because values are reassigned in place.
            for header_name, val in list(headers.items()):
                if isinstance(val, list):
                    for i, element in enumerate(val):
                        val[i] = internal_scrub(header_name, element)
                else:
                    headers[header_name] = internal_scrub(header_name, val)

        body = response.get('body')
        if body:
            body_str = body.get('string')
            if body_str:
                response['body']['string'] = self._scrub(body_str)

        return response

    def _scrub(self, val):
        """Hook for subclasses: return *val* with sensitive data replaced.

        The default implementation returns *val* unchanged.
        """
        return val

    def _scrub_using_dict(self, val, real_to_fake_dict):
        """Return *val* with each key of *real_to_fake_dict* replaced by its value.

        *val* may be text or ``bytes``. For ``bytes``, keys and values are
        converted with ``.encode()``. Empty/falsy keys and entries that map
        a value to itself are skipped.
        """
        # Replace longer secrets first so a secret containing another one
        # ('val10' vs 'val1', but also 'ab' vs 'b') is replaced as a whole.
        # Ties are broken by the string itself to keep the order deterministic.
        replacements = sorted(
            (key for key in real_to_fake_dict if key),
            key=lambda key: (len(key), key),
            reverse=True,
        )
        is_bytes = isinstance(val, bytes)
        for real_val in replacements:
            fake_val = real_to_fake_dict[real_val]
            if real_val == fake_val:
                continue
            if is_bytes:
                val = val.replace(real_val.encode(), fake_val.encode())
            else:
                val = val.replace(real_val, fake_val)
        return val

    @contextmanager
    def _nop_context_manager(self):
        """Context manager that does nothing; used when no cassette is needed."""
        yield
def record(test):
    """Decorate a test method so its body runs inside ``self.recording()``.

    ``functools.wraps`` copies the test's ``__name__``, ``__doc__``,
    ``__module__`` and other metadata onto the wrapper. Test runners and
    reports therefore still identify the original test, which previously
    kept only ``__name__``. Any extra positional/keyword arguments are
    forwarded to *test*.
    """
    @functools.wraps(test)
    def recording_test(self, *args, **kwargs):
        with self.recording():
            test(self, *args, **kwargs)
    return recording_test
|