1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
|
# Copyright 2020 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requires Python 2.6+ and Openssl 1.0+
#
import contextlib
from azurelinuxagent.common.protocol.wire import WireProtocol
from azurelinuxagent.common.utils import restutil
from tests.lib.tools import patch
from tests.lib import wire_protocol_data
@contextlib.contextmanager
def mock_wire_protocol(mock_wire_data_file, http_get_handler=None, http_post_handler=None, http_put_handler=None, do_not_mock=lambda method, url: False, fail_on_unknown_request=True, save_to_history=False, detect_protocol=True):
    """
    Creates a WireProtocol object that handles requests to the WireServer, the Host GA Plugin, and some requests to storage (requests that provide mock data
    in wire_protocol_data.py).

    The data returned by those requests is read from the files specified by 'mock_wire_data_file' (which must follow the structure of the data
    files defined in tests/lib/wire_protocol_data.py).

    The caller can also provide handler functions for specific HTTP methods using the http_*_handler arguments. GET handlers are invoked as
    handler(url, **kwargs); POST and PUT handlers are invoked as handler(url, data, **kwargs). The return value of the handler function is
    interpreted similarly to the "return_value" argument of patch(): if it is an exception the exception is raised or, if it is any object
    other than None, the value is returned by the mock. If the handler function returns None the call is handled using the mock wireserver
    data or passed to the original restutil.http_request.

    The 'do_not_mock' lambda can be used to skip the mocks for specific requests; if the lambda returns True, the mocks won't be applied and the
    original common.utils.restutil.http_request will be invoked instead.

    If 'fail_on_unknown_request' is True (the default), a request that is handled neither by a handler function nor by the mock wireserver
    data raises ValueError; otherwise it is passed to the original restutil.http_request.

    If 'detect_protocol' is True (the default), WireProtocol.detect() is invoked once the mocks are started; the 'save_to_history' parameter
    is passed thru in that call.

    The returned protocol object maintains a list of "tracked" urls. When a handler function returns a value that is not None the url for the
    request is automatically added to the tracked list. The handler function can add other items to this list using the track_url() method on
    the mock.

    The return value of this function is an instance of WireProtocol augmented with these properties/methods:

        * mock_wire_data - the WireProtocolData constructed from the mock_wire_data_file parameter.
        * start() - starts the patchers for http_request and CryptUtil
        * stop() - stops the patchers; it is safe to call it more than once
        * track_url(url) - adds the given item to the list of tracked urls.
        * get_tracked_urls() - returns the list of tracked urls.
        * set_http_handlers(http_get_handler=None, http_post_handler=None, http_put_handler=None) - replaces the HTTP handlers and
          clears the list of tracked urls.
        * do_not_mock - the 'do_not_mock' lambda given as parameter.

    NOTE: This function patches common.utils.restutil.http_request and common.protocol.wire.CryptUtil; you need to be aware of this if your
    tests patch those methods or others in the call stack (e.g. restutil.get, restutil._http_request, etc)
    """
    tracked_urls = []

    # use a helper function to keep the HTTP handlers (they need to be modified by set_http_handlers() and
    # Python 2.* does not support nonlocal declarations)
    def http_handlers(get, post, put):
        http_handlers.get = get
        http_handlers.post = post
        http_handlers.put = put
        # replacing the handlers also resets the list of tracked urls
        del tracked_urls[:]

    http_handlers(get=http_get_handler, post=http_post_handler, put=http_put_handler)

    #
    # function used to patch restutil.http_request
    #
    original_http_request = restutil.http_request

    def http_request(method, url, data, timeout, **kwargs):
        # call the original restutil.http_request if the request should NOT be mocked
        if protocol.do_not_mock(method, url):
            return original_http_request(method, url, data, timeout, **kwargs)

        # if there is a handler for the request, use it
        handler = None
        if method == 'GET':
            handler = http_handlers.get
        elif method == 'POST':
            handler = http_handlers.post
        elif method == 'PUT':
            handler = http_handlers.put

        if handler is not None:
            if method == 'GET':
                return_value = handler(url, **kwargs)
            else:
                return_value = handler(url, data, **kwargs)
            # a None return value means "not handled"; fall through to the mock wireserver data in that case
            if return_value is not None:
                tracked_urls.append(url)
                if isinstance(return_value, Exception):
                    raise return_value
                return return_value

        # if the request was not handled try to use the mock wireserver data
        try:
            if method == 'GET':
                return protocol.mock_wire_data.mock_http_get(url, **kwargs)
            if method == 'POST':
                return protocol.mock_wire_data.mock_http_post(url, data, **kwargs)
            if method == 'PUT':
                return protocol.mock_wire_data.mock_http_put(url, data, **kwargs)
        except NotImplementedError:
            # the mock data has no response for this request; handled below
            pass

        # if there was not a response for the request then fail it or call the original restutil.http_request
        if fail_on_unknown_request:
            raise ValueError('Unknown HTTP request: {0} [{1}]'.format(url, method))
        return original_http_request(method, url, data, timeout, **kwargs)

    #
    # functions to start/stop the mocks
    #
    def start():
        patched = patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request)
        patched.start()
        start.http_request_patch = patched

        patched = patch("azurelinuxagent.common.protocol.wire.CryptUtil", side_effect=protocol.mock_wire_data.mock_crypt_util)
        patched.start()
        start.crypt_util_patch = patched

    start.http_request_patch = None
    start.crypt_util_patch = None

    def stop():
        # Forget each patcher once it has been stopped, so that calling stop() more than once (e.g. explicitly from a test and
        # then again when the context manager exits) does not stop an already-stopped patcher; versions of mock prior to
        # Python 3.8 raise RuntimeError in that case.
        if start.crypt_util_patch is not None:
            start.crypt_util_patch.stop()
            start.crypt_util_patch = None
        if start.http_request_patch is not None:
            start.http_request_patch.stop()
            start.http_request_patch = None

    #
    # create the protocol object
    #
    protocol = WireProtocol(restutil.KNOWN_WIRESERVER_IP)
    protocol.mock_wire_data = wire_protocol_data.WireProtocolData(mock_wire_data_file)
    protocol.start = start
    protocol.stop = stop
    protocol.track_url = lambda url: tracked_urls.append(url)  # pylint: disable=unnecessary-lambda
    protocol.get_tracked_urls = lambda: tracked_urls
    protocol.set_http_handlers = lambda http_get_handler=None, http_post_handler=None, http_put_handler=None:\
        http_handlers(get=http_get_handler, post=http_post_handler, put=http_put_handler)
    protocol.do_not_mock = do_not_mock

    # go do it
    try:
        protocol.start()
        if detect_protocol:
            protocol.detect(save_to_history=save_to_history)
        yield protocol
    finally:
        protocol.stop()
class MockHttpResponse:
    """
    Minimal HTTP response stand-in; presumably returned by the HTTP handlers given to mock_wire_protocol().

    Exposes the attributes 'status', 'reason', 'body' and 'headers', plus the read() and getheaders() methods.
    """
    def __init__(self, status, body=b'', headers=None, reason=None):
        # status code and reason phrase are stored exactly as given
        self.status = status
        self.reason = reason
        # payload handed back by read()
        self.body = body
        # a fresh empty list per instance when no headers are supplied
        if headers is None:
            headers = []
        self.headers = headers

    def read(self, *_ignored):
        """Returns the entire body; any positional arguments (e.g. a size) are accepted and ignored."""
        return self.body

    def getheaders(self):
        """Returns the headers given to the constructor (an empty list by default)."""
        return self.headers
|