1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
|
from threading import Lock
from hashlib import sha256
from .individual_cache import _IndividualCache as IndividualCache
from .individual_cache import _ExpiringMapping as ExpiringMapping
from .oauth2cli.http import Response
from .exceptions import MsalServiceError
# Grant type value that identifies the OAuth 2.0 Device Authorization Grant, see
# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
def _get_headers(response):
    """Return the headers of ``response``, or an empty dict if it has none.

    MSAL's HttpResponse did not expose ``headers`` until 1.23.0, see
    https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/581/files#diff-28866b706bc3830cd20485685f20fe79d45b58dce7050e68032e9d9372d68654R61
    so a response object without that attribute degrades to {} rather than
    raising an exception.
    """
    try:
        return response.headers
    except AttributeError:
        return {}
class RetryAfterParser(object):
    """Derive a throttling duration (in seconds) from an http response.

    :param default_value:
        Seconds to throttle when the response qualifies for throttling
        but carries no usable Retry-After header. ``None`` means 5.
    """

    FIELD_NAME_LOWER = "Retry-After".lower()

    def __init__(self, default_value=None):
        self._default_value = 5 if default_value is None else default_value

    def parse(self, *, result, **ignored):
        """Return seconds to throttle.

        :param result:
            The http response. Its ``status_code`` is read,
            and its headers are read if it has any.
        :param ignored: Any other keyword arguments are accepted and unused.
        :return:
            0 when the response is neither HTTP 429, nor 5xx,
            nor carrying a Retry-After header.
            Otherwise the Retry-After value, capped at 3600 seconds.
            The default value is used when the header is absent, is not an
            integer, or is negative.
        """
        response = result
        lowercase_headers = {k.lower(): v for k, v in _get_headers(response).items()}
        if not (response.status_code == 429 or response.status_code >= 500
                or self.FIELD_NAME_LOWER in lowercase_headers):
            return 0  # Quick exit
        retry_after = lowercase_headers.get(self.FIELD_NAME_LOWER, self._default_value)
        try:
            # AAD's retry_after uses integer format only
            # https://stackoverflow.microsoft.com/questions/264931/264932
            delay_seconds = int(retry_after)
        except (TypeError, ValueError):
            # ValueError: not an integer string (such as an HTTP-date).
            # TypeError: the http client handed us a non-string value (such as None).
            delay_seconds = self._default_value
        if delay_seconds < 0:
            # A negative delay is malformed, so treat it like an unparsable value
            delay_seconds = self._default_value
        return min(3600, delay_seconds)
def _extract_data(kwargs, key, default=None):
    """Look up ``key`` in the ``data`` payload of an http call's kwargs.

    :param kwargs: The keyword arguments of an http call; ``data`` is optional.
    :param key: The name of the field to read from ``data``.
    :param default:
        Returned when ``data`` is not a dict, or is a dict without ``key``.
        Honoring it in both cases is what lets callers chain lookups,
        such as "refresh_token, otherwise code, otherwise username".
    :return: The value stored under ``key``, otherwise ``default``.
    """
    data = kwargs.get("data", {})  # data is usually a dict, but occasionally a string
    if not isinstance(data, dict):
        return default
    return data.get(key, default)
class NormalizedResponse(Response):
    """An http response shaped like Response,
    carrying nothing but the data that we are going to keep in cache.
    """

    def __init__(self, raw_response):
        """Copy status code, text and lowercased headers out of raw_response."""
        super().__init__()
        self.status_code = raw_response.status_code
        self.text = raw_response.text

        # We keep all headers rather than a small allow-list (such as Retry-After),
        # because an allow-list tends to lose information (such as WWW-Authenticate).
        # The headers are expected to contain only public info,
        # because we throttle only error responses and public responses.
        normalized_headers = {}
        for name, value in _get_headers(raw_response).items():
            normalized_headers[name.lower()] = value
        self.headers = normalized_headers

        # Note: Do not alias raw_response.raise_for_status onto this instance,
        # because pickling this instance would then indirectly pickle
        # the whole raw_response.

    def raise_for_status(self):
        """Raise MsalServiceError when the status code is 400 or above."""
        if self.status_code < 400:
            return
        raise MsalServiceError(
            "HTTP Error: {}".format(self.status_code),
            # Historically required, keeping them for now
            error=None,
            error_description=None,
        )
class ThrottledHttpClientBase(object):
    """Throttle the given http_client by storing and retrieving data from cache.

    The purposes of this base class are:

    1. Its post() and get() always return a NormalizedResponse.
    2. Its __init__() will NOT throttle twice,
       even if the caller accidentally nests one throttled client in another.

    A subclass is only expected to dynamically decorate its post() and get()
    inside its own __init__().
    """

    def __init__(self, http_client, *, http_cache=None):
        """Remember the raw http client and set up the expiring cache.

        :param http_client:
            The http client to wrap. If it is a ThrottledHttpClientBase,
            its raw (unthrottled) http client is used instead.
        :param http_cache:
            Optional mapping used as the cache storage. A new dict when None.
        """
        if isinstance(http_client, ThrottledHttpClientBase):
            raw_client = http_client.http_client  # Unwrap, to avoid re-throttling
        else:
            raw_client = http_client
        self.http_client = raw_client

        storage = {} if http_cache is None else http_cache
        self._expiring_mapping = ExpiringMapping(  # It will automatically clean up
            mapping=storage,
            capacity=1024,  # To prevent cache blowing up especially for CCA
            lock=Lock(),  # TODO: This should ideally also allow customization
        )

    def post(self, *args, **kwargs):
        """Send a POST via the raw http client; return a NormalizedResponse."""
        raw_response = self.http_client.post(*args, **kwargs)
        return NormalizedResponse(raw_response)

    def get(self, *args, **kwargs):
        """Send a GET via the raw http client; return a NormalizedResponse."""
        raw_response = self.http_client.get(*args, **kwargs)
        return NormalizedResponse(raw_response)

    def close(self):
        """Close the raw http client, and return whatever its close() returns."""
        return self.http_client.close()

    @staticmethod
    def _hash(raw):
        """Return the hex SHA-256 digest of the UTF-8 encoded repr() of raw."""
        serialized = repr(raw).encode("utf-8")
        return sha256(serialized).hexdigest()
class ThrottledHttpClient(ThrottledHttpClientBase):
    """The throttled http client for MSAL's clients other than managed identity."""

    def __init__(self, *args, default_throttle_time=None, **kwargs):
        """Wrap self.post() and self.get() with cache decorators at runtime.

        :param default_throttle_time:
            Passed to RetryAfterParser as its default number of seconds.
            A falsy value (None or 0) means 5.

        All other arguments are forwarded to ThrottledHttpClientBase.
        """
        super(ThrottledHttpClient, self).__init__(*args, **kwargs)

        def retry_after_key(func, args, kwargs):
            url = args[0]  # Typically containing authority and tenant
            client_id = _extract_data(kwargs, "client_id")  # Per internal specs
            scope = _extract_data(kwargs, "scope")  # Per internal specs
            # Each of these approximates the "account" concept,
            # so that throttling can be per-account.
            # TODO: We may want to disable it for confidential client, though
            account = _extract_data(
                kwargs, "refresh_token",  # "account" during refresh
                _extract_data(
                    kwargs, "code",  # "account" of auth code grant
                    _extract_data(kwargs, "username")))  # "account" of ROPC
            return "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format(
                url, client_id, scope, self._hash(account))

        def bad_request_key(func, args, kwargs):
            url = args[0]  # Typically containing authority and tenant
            # Literally every parameter goes into the hash, including short-lived
            # ones carrying timestamps (WS-Trust or POP assertion),
            # because ExpiringMapping will clean them up automatically.
            #
            # There is also no need to implement
            # "interactive requests would reset the cache",
            # because acquire_token_silent() gets unblocked automatically,
            # given that the token cache layer operates on top of the http cache layer.
            #
            # And acquire_token_silent(..., force_refresh=True) will NOT
            # bypass the http cache, because there is no real gain from that.
            # We won't bother to implement it, nor do we want to encourage
            # the acquire_token_silent(..., force_refresh=True) pattern.
            fingerprint = self._hash(
                str(kwargs.get("params")) + str(kwargs.get("data")))
            return "POST {} hash={} 400".format(url, fingerprint)

        def bad_request_expires_in(result=None, kwargs=None, **ignored):
            # Only exact HTTP 400 errors are cached (rather than all 4xx),
            # because they are the ones defined in OAuth2
            # (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2)
            # Other 4xx errors might have different requirements e.g.
            # "407 Proxy auth required" would need a key including http headers.
            if result.status_code != 400:
                return 0
            data = kwargs.get("data")
            if isinstance(data, dict) and data.get("grant_type") == DEVICE_AUTH_GRANT:
                return 0  # Device Flow's retry is expected and regulated
            header_names = {h.lower() for h in _get_headers(result)}
            if RetryAfterParser.FIELD_NAME_LOWER in header_names:
                return 0  # Leave it to the Retry-After decorator
            return 60

        def discovery_key(func, args, kwargs):
            url = args[0]  # Sometimes containing inline params
            return "GET {} hash={} 2xx".format(
                url, self._hash(kwargs.get("params", "")))

        def discovery_expires_in(result=None, **ignored):
            if 200 <= result.status_code < 300:
                return 3600*24
            return 0

        # Internal specs require throttling on at least the token endpoint;
        # this is a generic patch for POST on all endpoints.
        throttle_by_retry_after = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=retry_after_key,
            expires_in=RetryAfterParser(default_throttle_time or 5).parse,
        )
        self.post = throttle_by_retry_after(self.post)

        # This layer covers the "UI required cache"
        throttle_bad_request = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=bad_request_key,
            expires_in=bad_request_expires_in,
        )
        self.post = throttle_bad_request(self.post)

        # Typically those discovery GETs
        cache_discovery = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=discovery_key,
            expires_in=discovery_expires_in,
        )
        self.get = cache_discovery(self.get)
|