from threading import Lock
from hashlib import sha256

from .individual_cache import _IndividualCache as IndividualCache
from .individual_cache import _ExpiringMapping as ExpiringMapping
from .oauth2cli.http import Response
from .exceptions import MsalServiceError


# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
# The grant_type value of a Device Authorization Grant token request.
# ThrottledHttpClient compares the POST data's "grant_type" against this value,
# so that device-flow polling responses are excluded from the HTTP 400 cache.
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"


def _get_headers(response):
    """Return the headers mapping of ``response``, or ``{}`` when unavailable.

    :param response: Any http response object; it may or may not carry
        a ``headers`` attribute.
    :return: The response's ``headers`` when it is present and truthy,
        otherwise an empty dict. The result is always safe to iterate.
    """
    # MSAL's HttpResponse did not have headers until 1.23.0
    # https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/581/files#diff-28866b706bc3830cd20485685f20fe79d45b58dce7050e68032e9d9372d68654R61
    # This helper ensures graceful degradation to {} without exception.
    # The "or {}" also covers a response whose headers attribute exists
    # but is None, which would otherwise break callers calling .items() on it.
    return getattr(response, "headers", None) or {}


class RetryAfterParser(object):
    """Decide how many seconds a response should be throttled (cached) for.

    A response is throttled when its status code is 429 or 5xx,
    or when it carries a Retry-After header regardless of its status code.
    """
    FIELD_NAME_LOWER = "Retry-After".lower()

    def __init__(self, default_value=None):
        """
        :param default_value: Seconds to throttle when the response qualifies
            for throttling but has no usable Retry-After header.
            ``None`` means 5 seconds.
        """
        self._default_value = 5 if default_value is None else default_value

    def parse(self, *, result, **ignored):
        """Return seconds to throttle

        :param result: The http response. Its ``status_code`` is read,
            and its headers (if any) are looked up case-insensitively.
        :param ignored: Any other keyword arguments are accepted and unused.
        :return: 0 when the response should not be throttled;
            otherwise a delay between 0 and 3600 seconds inclusive.
        """
        response = result
        lowercase_headers = {k.lower(): v for k, v in _get_headers(response).items()}
        if not (response.status_code == 429 or response.status_code >= 500
                or self.FIELD_NAME_LOWER in lowercase_headers):
            return 0  # Quick exit
        retry_after = lowercase_headers.get(self.FIELD_NAME_LOWER, self._default_value)
        try:
            # AAD's retry_after uses integer format only
            # https://stackoverflow.microsoft.com/questions/264931/264932
            delay_seconds = int(retry_after)
        except (ValueError, TypeError):
            # ValueError: a non-integer string, such as an HTTP-date.
            # TypeError: a value int() cannot accept at all, such as None.
            # Either way, fall back to the default rather than crashing.
            delay_seconds = self._default_value
        # Cap at one hour, and never report a negative duration:
        # a negative Retry-After is invalid and is treated as "do not throttle".
        return max(0, min(3600, delay_seconds))


def _extract_data(kwargs, key, default=None):
    """Return ``kwargs["data"][key]``, or ``default`` when it is unavailable.

    :param kwargs: The keyword arguments of an http request call.
    :param key: The name of the form field to look up inside ``data``.
    :param default: The value to return when ``data`` is not a dict,
        or when ``data`` is a dict that does not contain ``key``.
    :return: The field's value, or ``default``.
    """
    data = kwargs.get("data", {})  # data is usually a dict, but occasionally a string
    if not isinstance(data, dict):
        return default
    # Honor the default for a missing key too, so that callers can chain
    # fallbacks, such as refresh_token -> code -> username.
    return data.get(key, default)


class NormalizedResponse(Response):
    """A snapshot of an http response, shaped like Response,
    which holds nothing but the data we are going to store in cache.
    """
    def __init__(self, raw_response):
        super().__init__()
        self.status_code = raw_response.status_code
        self.text = raw_response.text
        # We once tried keeping just a handful of headers (such as Retry-After),
        # but that tends to lose information (such as WWW-Authenticate).
        # So every header is kept, with its name lowercased. They are expected
        # to contain only public info,
        # because we throttle only error responses and public responses.
        normalized_headers = {}
        for name, value in _get_headers(raw_response).items():
            normalized_headers[name.lower()] = value
        self.headers = normalized_headers

    ## Note: Do not borrow the raw response's method like the following line,
    ## because pickling this object would then indirectly pickle the whole raw_response
    # self.raise_for_status = raw_response.raise_for_status
    def raise_for_status(self):
        """Raise MsalServiceError when the status code is 400 or above."""
        if self.status_code < 400:
            return
        message = "HTTP Error: {}".format(self.status_code)
        # error and error_description were historically required; kept for now
        raise MsalServiceError(message, error=None, error_description=None)


class ThrottledHttpClientBase(object):
    """Wrap an http_client so that its traffic can be throttled through a cache.

    The purpose of this base class is that:
    1. Its post() and get() always hand back a NormalizedResponse
    2. Its __init__() will NOT throttle twice, even when the caller
       accidentally nests one ThrottledHttpClient inside another.

    A subclass is only expected to dynamically decorate its own
    post() and get() methods from within its __init__() method.
    """
    def __init__(self, http_client, *, http_cache=None):
        if isinstance(http_client, ThrottledHttpClientBase):
            # It is already throttled, so we reach for its raw (unthrottled) http client
            raw_client = http_client.http_client
        else:
            raw_client = http_client
        self.http_client = raw_client

        backing_store = {} if http_cache is None else http_cache
        # The expiring mapping cleans up after itself automatically
        self._expiring_mapping = ExpiringMapping(
            mapping=backing_store,
            capacity=1024,  # Keeps the cache from blowing up, especially for CCA
            lock=Lock(),  # TODO: This should ideally also allow customization
        )

    def post(self, *args, **kwargs):
        """Forward a POST to the raw http client and normalize its response."""
        raw_response = self.http_client.post(*args, **kwargs)
        return NormalizedResponse(raw_response)

    def get(self, *args, **kwargs):
        """Forward a GET to the raw http client and normalize its response."""
        raw_response = self.http_client.get(*args, **kwargs)
        return NormalizedResponse(raw_response)

    def close(self):
        """Close the raw http client and return whatever its close() returns."""
        return self.http_client.close()

    @staticmethod
    def _hash(raw):
        """Return the hex sha256 digest of the utf-8 encoded repr() of raw."""
        encoded = repr(raw).encode("utf-8")
        digest = sha256(encoded)
        return digest.hexdigest()


class ThrottledHttpClient(ThrottledHttpClientBase):
    """The throttled http client for MSAL's non-managed identity clients."""
    def __init__(self, *args, default_throttle_time=None, **kwargs):
        """Wrap self.post() and self.get() with caching decorators at runtime"""
        super().__init__(*args, **kwargs)

        def retry_after_key(func, args, kwargs):
            url = args[0]  # Typically containing authority and tenant
            client_id = _extract_data(kwargs, "client_id")  # Per internal specs
            scope = _extract_data(kwargs, "scope")  # Per internal specs
            # Each of the followings approximates the "account" concept,
            # so that throttling can work on a per-account basis.
            # TODO: We may want to disable it for confidential client, though
            ropc_account = _extract_data(kwargs, "username")  # "account" of ROPC
            code_account = _extract_data(  # "account" of auth code grant
                kwargs, "code", ropc_account)
            account = _extract_data(  # "account" during refresh
                kwargs, "refresh_token", code_account)
            return "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format(
                url, client_id, scope, self._hash(account))

        def bad_request_key(func, args, kwargs):
            url = args[0]  # Typically containing authority and tenant
            # Literally all parameters go into the key, including short-lived
            # ones which contain timestamps (WS-Trust or POP assertion),
            # since ExpiringMapping will clean them up automatically.
            #
            # There is also no need to implement
            # "interactive requests would reset the cache",
            # because the token cache layer operates on top of http cache layer,
            # so acquire_token_silent() would be unblocked automatically.
            #
            # Besides, acquire_token_silent(..., force_refresh=True) will NOT
            # bypass http cache, because there is no real gain from that.
            # We would not bother implementing it, nor do we want to encourage
            # the acquire_token_silent(..., force_refresh=True) pattern.
            everything = str(kwargs.get("params")) + str(kwargs.get("data"))
            return "POST {} hash={} 400".format(url, self._hash(everything))

        def bad_request_expires_in(result=None, kwargs=None, **ignored):
            # Only the exact HTTP 400 errors are cached (rather than all 4xx),
            # because those are the ones defined in OAuth2
            # (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2)
            # Other 4xx errors might have different requirements e.g.
            # "407 Proxy auth required" would need a key including http headers.
            if result.status_code != 400:
                return 0
            data = kwargs.get("data")
            if isinstance(data, dict) and data.get("grant_type") == DEVICE_AUTH_GRANT:
                return 0  # Device Flow is excluded: its retry is expected and regulated
            header_names = set(h.lower() for h in _get_headers(result))
            if RetryAfterParser.FIELD_NAME_LOWER in header_names:
                return 0  # That case is left to the Retry-After decorator
            return 60

        def discovery_key(func, args, kwargs):
            url = args[0]  # Sometimes containing inline params
            return "GET {} hash={} 2xx".format(
                url, self._hash(kwargs.get("params", "")))

        def discovery_expires_in(result=None, **ignored):
            if 200 <= result.status_code < 300:
                return 3600*24
            return 0

        # Internal specs requires throttling on at least token endpoint,
        # what we have here is a generic patch for POST on all endpoints.
        throttle_by_retry_after = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=retry_after_key,
            expires_in=RetryAfterParser(default_throttle_time or 5).parse,
        )
        self.post = throttle_by_retry_after(self.post)

        # This second layer covers the "UI required cache"
        throttle_bad_request = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=bad_request_key,
            expires_in=bad_request_expires_in,
        )
        self.post = throttle_bad_request(self.post)

        # Typically those discovery GETs
        cache_discovery = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=discovery_key,
            expires_in=discovery_expires_in,
        )
        self.get = cache_discovery(self.get)
