1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
|
from six import BytesIO, text_type
from six.moves.urllib.parse import urlparse, urlencode, urlunparse
import copy
import json
import zlib
from .util import CaseInsensitiveDict
def replace_headers(request, replacements):
    """
    Rewrite the request's headers as directed by ``replacements`` and return
    the request.

    ``replacements`` is an iterable of ``(key, value)`` pairs. Each value may be:
      1. A plain string, which becomes the header's new value.
      2. None, which removes the header.
      3. A callable invoked as ``value(key=..., value=..., request=...)``
         that returns the new string value, or None to remove the header.

    Pairs whose key is not currently a header are ignored. The header
    mapping is copied, edited, and assigned back to ``request.headers``;
    the same request object is returned.
    """
    headers = request.headers.copy()
    for name, replacement in replacements:
        if name not in headers:
            continue
        old_value = headers.pop(name)
        new_value = (
            replacement(key=name, value=old_value, request=request)
            if callable(replacement)
            else replacement
        )
        if new_value is not None:
            headers[name] = new_value
    request.headers = headers
    return request
def remove_headers(request, headers_to_remove):
    """
    Strip every header named in ``headers_to_remove`` from ``request``.

    Retained for API backward compatibility: it simply forwards to
    replace_headers() with a None replacement (meaning "remove") for each name.
    """
    return replace_headers(request, [(name, None) for name in headers_to_remove])
def replace_query_parameters(request, replacements):
    """
    Replace query parameters in request according to replacements.

    The replacements should be a list of (key, value) pairs where the value
    can be any of:
      1. A simple replacement string value.
      2. None to remove the given query parameter.
      3. A callable which accepts (key, value, request) and returns a string
         value or None (None removes the parameter).

    Parameters not named in replacements are kept as they are. A parameter
    that appears several times in ``request.query`` is passed through its
    replacement once per occurrence. ``request.uri`` is rebuilt with the new
    query string, and the (mutated) request is returned.
    """
    query = request.query
    replacements = dict(replacements)
    new_query = []
    for k, ov in query:
        if k not in replacements:
            new_query.append((k, ov))
            continue
        rv = replacements[k]
        if callable(rv):
            rv = rv(key=k, value=ov, request=request)
        if rv is not None:
            new_query.append((k, rv))
    # Index 4 of the 6-tuple returned by urlparse() is the query component.
    uri_parts = list(urlparse(request.uri))
    uri_parts[4] = urlencode(new_query)
    request.uri = urlunparse(uri_parts)
    return request
def remove_query_parameters(request, query_parameters_to_remove):
    """
    Delete the named query parameters from the request's URI.

    Kept for API backward compatibility; it delegates to
    replace_query_parameters() with a None replacement for every name.
    """
    return replace_query_parameters(
        request, [(name, None) for name in query_parameters_to_remove]
    )
def _is_json_content_type(content_type):
    """Return True if content_type names the application/json media type.

    Media-type parameters such as ``; charset=utf-8`` are ignored and the
    comparison is case-insensitive. Non-string values yield False.
    """
    if not isinstance(content_type, (str, text_type)):
        return False
    media_type = content_type.split(";", 1)[0].strip().lower()
    return media_type == "application/json"


def _replace_json_post_data(request, replacements):
    """Apply replacements to the top-level keys of a JSON-object request body."""
    body = request.body
    if not isinstance(body, text_type):
        body = body.decode("utf-8")
    json_data = json.loads(body)
    if not isinstance(json_data, dict):
        # Only keys of a top-level JSON object can be replaced; arrays and
        # scalars are left exactly as sent.
        return
    for k, rv in replacements.items():
        if k in json_data:
            ov = json_data.pop(k)
            if callable(rv):
                rv = rv(key=k, value=ov, request=request)
            if rv is not None:
                json_data[k] = rv
    request.body = json.dumps(json_data).encode("utf-8")


def _replace_form_post_data(request, replacements):
    """Apply replacements to a ``key=value&key2=value2`` form-encoded body."""
    body = request.body
    if isinstance(body, text_type):
        body = body.encode("utf-8")
    new_pairs = []
    for pair in body.split(b"&"):
        # partition() always returns a 3-tuple; sep is b"" (never None) when
        # the pair has no "=", e.g. a bare flag like b"debug".
        k, sep, ov = pair.partition(b"=")
        rk = k.decode("utf-8")
        if rk not in replacements:
            new_pairs.append(pair)
            continue
        rv = replacements[rk]
        if callable(rv):
            rv = rv(key=rk, value=ov.decode("utf-8"), request=request)
        if rv is None:
            continue
        if isinstance(rv, text_type):
            rv = rv.encode("utf-8")
        # A bare key that receives a value needs an "=" separator, otherwise
        # key and value would be glued together.
        new_pairs.append(b"".join([k, sep or b"=", rv]))
    request.body = b"&".join(new_pairs)


def replace_post_data_parameters(request, replacements):
    """
    Replace post data in request--either form data or json--according to
    replacements. The replacements should be a list of (key, value) pairs where
    the value can be any of:
      1. A simple replacement string value.
      2. None to remove the given parameter.
      3. A callable which accepts (key, value, request) and returns a string
         value or None.

    Only non-empty bodies of POST requests are rewritten; BytesIO bodies are
    left alone. A body is treated as JSON when its Content-Type media type is
    application/json (parameters such as charset are allowed), in which case
    replacements apply to the top-level keys of a JSON object. Any other body
    is treated as ``key=value&...`` form data. The request is mutated in place
    and returned.
    """
    if not request.body:
        # Nothing to replace
        return request
    replacements = dict(replacements)
    if request.method == "POST" and not isinstance(request.body, BytesIO):
        if _is_json_content_type(request.headers.get("Content-Type")):
            _replace_json_post_data(request, replacements)
        else:
            _replace_form_post_data(request, replacements)
    return request
def remove_post_data_parameters(request, post_data_parameters_to_remove):
    """
    Drop the named parameters from the request's POST body.

    Backward-compatible shim over replace_post_data_parameters(): each name
    is paired with a None replacement, which means "remove".
    """
    return replace_post_data_parameters(
        request, [(name, None) for name in post_data_parameters_to_remove]
    )
def decode_response(response):
    """
    If the response is compressed with gzip or deflate:
      1. decompress the response body
      2. delete the content-encoding header
      3. update content-length header to decompressed length

    Only the first listed content-encoding is examined. The response is
    deep-copied before any change, so the caller's object is not mutated;
    the (possibly decoded) copy is returned. An empty body is passed through
    unchanged (there is nothing to decompress), but the headers are still
    updated as above.
    """

    def is_compressed(headers):
        encoding = headers.get("content-encoding", [])
        return encoding and encoding[0] in ("gzip", "deflate")

    def decompress_body(body, encoding):
        """Return body decompressed according to encoding using zlib.

        gzip needs wbits = zlib.MAX_WBITS | 16 so zlib expects the gzip
        header and trailer. deflate is first tried as a zlib-wrapped stream
        and, if that fails, as a raw deflate stream (wbits = -zlib.MAX_WBITS),
        because some servers send raw deflate under this encoding.
        """
        if not body:
            # e.g. HEAD or 204 responses that still carry a Content-Encoding
            # header: zlib.decompress(b"") would raise "incomplete stream".
            return body
        if encoding == "gzip":
            return zlib.decompress(body, zlib.MAX_WBITS | 16)
        # encoding == "deflate"
        try:
            return zlib.decompress(body)
        except zlib.error:
            return zlib.decompress(body, -zlib.MAX_WBITS)

    # Deepcopy here in case `headers` contain objects that could
    # be mutated by a shallow copy and corrupt the real response.
    response = copy.deepcopy(response)
    headers = CaseInsensitiveDict(response["headers"])
    if is_compressed(headers):
        encoding = headers["content-encoding"][0]
        # Decompress before touching headers so a failure leaves nothing
        # half-updated (the copy is discarded when the exception propagates).
        new_body = decompress_body(response["body"]["string"], encoding)
        headers["content-encoding"].remove(encoding)
        if not headers["content-encoding"]:
            del headers["content-encoding"]
        response["body"]["string"] = new_body
        headers["content-length"] = [str(len(new_body))]
        response["headers"] = dict(headers)
    return response
|