File: throttled_http_client.py

package info (click to toggle)
microsoft-authentication-library-for-python 1.34.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,320 kB
  • sloc: python: 8,613; xml: 2,783; sh: 27; makefile: 19
file content (179 lines) | stat: -rw-r--r-- 9,023 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
from threading import Lock
from hashlib import sha256

from .individual_cache import _IndividualCache as IndividualCache
from .individual_cache import _ExpiringMapping as ExpiringMapping
from .oauth2cli.http import Response
from .exceptions import MsalServiceError


# The OAuth 2.0 Device Authorization Grant "grant_type" value.
# ThrottledHttpClient compares a POST body's "grant_type" against it so that
# device-flow polling (which is expected to retry) is not put into the
# HTTP 400 cache.
# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"


def _get_headers(response):
    # MSAL's HttpResponse did not have headers until 1.23.0
    # https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/581/files#diff-28866b706bc3830cd20485685f20fe79d45b58dce7050e68032e9d9372d68654R61
    # This helper ensures graceful degradation to {} without exception
    return getattr(response, "headers", {})


class RetryAfterParser(object):
    """Decide how many seconds a response should be throttled.

    A response is throttled when it is an HTTP 429, any 5xx, or carries a
    Retry-After header. The delay comes from the Retry-After header when it
    holds an integer. Otherwise the configured default is used. The result is
    always clamped to the range [0, 3600].
    """
    FIELD_NAME_LOWER = "Retry-After".lower()

    def __init__(self, default_value=None):
        # Seconds to throttle when no usable Retry-After value is present;
        # None means 5 seconds.
        self._default_value = 5 if default_value is None else default_value

    def parse(self, *, result, **ignored):
        """Return seconds to throttle; 0 means "do not throttle".

        :param result: The http response, exposing ``status_code`` and
            (optionally) ``headers``.
        :param ignored: Any other keyword arguments are accepted and ignored.
        """
        response = result
        lowercase_headers = {k.lower(): v for k, v in _get_headers(response).items()}
        if not (response.status_code == 429 or response.status_code >= 500
                or self.FIELD_NAME_LOWER in lowercase_headers):
            return 0  # Quick exit
        retry_after = lowercase_headers.get(self.FIELD_NAME_LOWER, self._default_value)
        try:
            # AAD's retry_after uses integer format only
            # https://stackoverflow.microsoft.com/questions/264931/264932
            delay_seconds = int(retry_after)
        except (TypeError, ValueError):
            # ValueError covers e.g. the HTTP-date form of Retry-After, and
            # TypeError covers a non-string header value. Neither should
            # break the request, so fall back to the default delay.
            delay_seconds = self._default_value
        # Cap at one hour, and never return a negative TTL, which would only
        # store an already-expired cache entry.
        return max(0, min(3600, delay_seconds))


def _extract_data(kwargs, key, default=None):
    data = kwargs.get("data", {})  # data is usually a dict, but occasionally a string
    return data.get(key) if isinstance(data, dict) else default


class NormalizedResponse(Response):
    """An http response shaped like Response, holding only cacheable data.

    Only ``status_code``, ``text`` and the headers (with lower-cased names)
    are copied from the raw response. No reference to the raw response
    itself is kept.
    """
    def __init__(self, raw_response):
        super().__init__()
        self.status_code = raw_response.status_code
        self.text = raw_response.text
        # All headers are kept. An earlier attempt stored only a small subset
        # (such as Retry-After) and lost useful information (such as
        # WWW-Authenticate). Headers are expected to contain only public info,
        # because only error responses and public responses are throttled.
        lowered = {name.lower(): value
                   for name, value in _get_headers(raw_response).items()}
        self.headers = lowered

    # Deliberately NOT `self.raise_for_status = raw_response.raise_for_status`:
    # pickling that bound method would indirectly pickle the whole raw_response.
    def raise_for_status(self):
        """Raise MsalServiceError when ``status_code`` is 400 or above."""
        if self.status_code < 400:
            return
        raise MsalServiceError(
            "HTTP Error: {}".format(self.status_code),
            # Historically required, keeping them for now
            error=None,
            error_description=None,
        )


class ThrottledHttpClientBase(object):
    """Throttle a given http_client by storing and retrieving data from cache.

    This base exists so that:
    1. The base post() and get() return a NormalizedResponse.
    2. The base __init__() does NOT re-throttle, even if a caller accidentally
       nests ThrottledHttpClient instances.

    Subclasses only need to decorate their post() and get() dynamically,
    inside their own __init__().
    """
    def __init__(self, http_client, *, http_cache=None):
        if isinstance(http_client, ThrottledHttpClientBase):
            # Unwrap to the raw (unthrottled) client, so that nesting does not
            # stack another layer of throttling on top of it.
            http_client = http_client.http_client
        self.http_client = http_client
        backing_store = {} if http_cache is None else http_cache
        self._expiring_mapping = ExpiringMapping(  # It will automatically clean up
            mapping=backing_store,
            capacity=1024,  # To prevent cache blowing up especially for CCA
            lock=Lock(),  # TODO: This should ideally also allow customization
        )

    def post(self, *args, **kwargs):
        """Forward to the raw client's post() and normalize its response."""
        raw = self.http_client.post(*args, **kwargs)
        return NormalizedResponse(raw)

    def get(self, *args, **kwargs):
        """Forward to the raw client's get() and normalize its response."""
        raw = self.http_client.get(*args, **kwargs)
        return NormalizedResponse(raw)

    def close(self):
        """Close the underlying raw http client, returning its result."""
        return self.http_client.close()

    @staticmethod
    def _hash(raw):
        """Return the hex SHA-256 digest of ``repr(raw)`` (UTF-8 encoded)."""
        digest = sha256(repr(raw).encode("utf-8"))
        return digest.hexdigest()


class ThrottledHttpClient(ThrottledHttpClientBase):
    """A throttled http client that is used by MSAL's non-managed identity clients."""
    def __init__(self, *args, default_throttle_time=None, **kwargs):
        """Decorate this instance's post() and get() with response caches.

        :param default_throttle_time: Seconds to throttle a throttle-worthy
            POST response that has no usable Retry-After value. A falsy value
            (None or 0) means 5 seconds.

        All other arguments are forwarded to ThrottledHttpClientBase.

        post() is wrapped twice: first by the 429/5xx/Retry-After cache, then
        by the HTTP 400 cache, which therefore forms the outer layer. get() is
        wrapped by a 2xx cache. All three share ``self._expiring_mapping``.
        """
        super().__init__(*args, **kwargs)

        def throttling_key(func, args, kwargs):
            # Internal specs require throttling on at least the token endpoint.
            # Here we have a generic patch for POST on all endpoints.
            url = args[0]  # It is the url, typically containing authority and tenant
            client_id = _extract_data(kwargs, "client_id")  # Per internal specs
            scope = _extract_data(kwargs, "scope")  # Per internal specs
            # Approximations of the "account" concept, to support per-account
            # throttling: refresh_token ("account" during refresh) is asked for
            # first, with code ("account" of auth code grant) and username
            # ("account" of ROPC) passed as nested defaults.
            # TODO: We may want to disable it for confidential client, though
            account_hint = _extract_data(
                kwargs, "refresh_token",
                _extract_data(kwargs, "code", _extract_data(kwargs, "username")))
            return "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format(
                url, client_id, scope, self._hash(account_hint))

        def ui_required_key(func, args, kwargs):
            # Literally all parameters are hashed, even short-lived ones that
            # contain timestamps (WS-Trust or POP assertion), because they are
            # cleaned up automatically by ExpiringMapping.
            #
            # There is no need to implement "interactive requests would reset
            # the cache", because acquire_token_silent() is unblocked
            # automatically: the token cache layer sits on top of this layer.
            #
            # acquire_token_silent(..., force_refresh=True) will NOT bypass this
            # http cache either, because there is no real gain from that, and
            # we do not want to encourage that pattern.
            url = args[0]  # It is the url, typically containing authority and tenant
            request_fingerprint = str(kwargs.get("params")) + str(kwargs.get("data"))
            return "POST {} hash={} 400".format(url, self._hash(request_fingerprint))

        def ui_required_ttl(result=None, kwargs=None, **ignored):
            # Only exact HTTP 400 errors (rather than all 4xx) are cached, since
            # they are the ones defined in OAuth2
            # (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2).
            # Other 4xx errors might have different requirements, e.g.
            # "407 Proxy auth required" would need a key including http headers.
            if result.status_code != 400:
                return 0
            body = kwargs.get("data")
            if isinstance(body, dict) and body.get("grant_type") == DEVICE_AUTH_GRANT:
                return 0  # Device Flow retries are expected and regulated
            header_names = {name.lower() for name in _get_headers(result)}
            if RetryAfterParser.FIELD_NAME_LOWER in header_names:
                return 0  # Leave it to the Retry-After cache
            return 60

        def discovery_key(func, args, kwargs):
            return "GET {} hash={} 2xx".format(
                args[0],  # It is the url, sometimes containing inline params
                self._hash(kwargs.get("params", "")),
            )

        def discovery_ttl(result=None, **ignored):
            # Only 2xx responses are cached, for one day.
            if 200 <= result.status_code < 300:
                return 24 * 3600
            return 0

        retry_after_cache = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=throttling_key,
            expires_in=RetryAfterParser(default_throttle_time or 5).parse,
        )
        self.post = retry_after_cache(self.post)

        ui_required_cache = IndividualCache(  # It covers the "UI required cache"
            mapping=self._expiring_mapping,
            key_maker=ui_required_key,
            expires_in=ui_required_ttl,
        )
        self.post = ui_required_cache(self.post)

        discovery_cache = IndividualCache(  # Typically those discovery GETs
            mapping=self._expiring_mapping,
            key_maker=discovery_key,
            expires_in=discovery_ttl,
        )
        self.get = discovery_cache(self.get)