from httpx import AsyncClient, Client
try:
    from httpx._config import UNSET
except ImportError:
    UNSET = None
from authlib.oauth2.rfc7521 import AssertionClient as _AssertionClient
from authlib.oauth2.rfc7523 import JWTBearerGrant
from authlib.oauth2 import OAuth2Error
from .utils import extract_client_kwargs
from .oauth2_client import OAuth2Auth

# Public API of this module: both the asynchronous and the synchronous
# assertion clients defined below are exported.
__all__ = ['AsyncAssertionClient', 'AssertionClient']


class AsyncAssertionClient(_AssertionClient, AsyncClient):
    """Asynchronous httpx client authenticated through an assertion grant.

    Combines the assertion-client logic of ``_AssertionClient`` with httpx's
    ``AsyncClient``. Unless told otherwise, every request makes sure a
    non-expired token is available and sends it via ``self.token_auth``.
    """

    token_auth_class = OAuth2Auth
    JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE
    # Grant type -> callable producing the assertion; presumably consumed by
    # the ``_AssertionClient`` base class (not visible in this module).
    ASSERTION_METHODS = {
        JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign,
    }
    DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE

    def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None,
                 claims=None, token_placement='header', scope=None, **kwargs):
        """Initialise both the httpx client and the assertion client.

        The named parameters are forwarded unchanged to
        ``_AssertionClient.__init__``. ``extract_client_kwargs`` selects the
        keyword arguments handed to ``AsyncClient.__init__`` (presumably
        removing them from ``kwargs`` -- confirm in ``.utils``); the
        ``kwargs`` dict as it stands afterwards goes to the assertion client.
        """
        client_kwargs = extract_client_kwargs(kwargs)
        AsyncClient.__init__(self, **client_kwargs)

        # ``session`` is None here: the token request is issued by the
        # ``_refresh_token`` coroutine defined on this class instead.
        _AssertionClient.__init__(
            self, session=None,
            token_endpoint=token_endpoint, issuer=issuer, subject=subject,
            audience=audience, grant_type=grant_type, claims=claims,
            token_placement=token_placement, scope=scope, **kwargs
        )

    async def request(self, method, url, withhold_token=False, auth=UNSET, **kwargs):
        """Send request with auto refresh token feature.

        :param method: HTTP method, passed through to ``AsyncClient.request``.
        :param url: Target URL, passed through to ``AsyncClient.request``.
        :param withhold_token: When true, skip token refresh and token auth.
        :param auth: Explicit auth to use. Left at the ``UNSET`` sentinel, the
            client's own token auth is applied; any other value (including
            ``None`` when ``UNSET`` is a distinct sentinel) is forwarded as-is.
        :param kwargs: Extra arguments for ``AsyncClient.request``.
        :return: The response returned by ``AsyncClient.request``.
        """
        # The default must be the same sentinel that is tested below;
        # otherwise a call that omits ``auth`` would never get the token.
        if not withhold_token and auth is UNSET:
            if not self.token or self.token.is_expired():
                await self.refresh_token()

            auth = self.token_auth
        return await super().request(method, url, auth=auth, **kwargs)

    async def _refresh_token(self, data):
        """POST ``data`` to the token endpoint and store the returned token.

        :param data: Form data for the token request.
        :return: The value of ``self.token`` after assignment.
        :raises OAuth2Error: If the JSON response contains an ``error`` key.
        """
        # withhold_token=True avoids recursing into the refresh logic above.
        resp = await self.request(
            'POST', self.token_endpoint, data=data, withhold_token=True)

        token = resp.json()
        if 'error' in token:
            raise OAuth2Error(
                error=token['error'],
                description=token.get('error_description')
            )
        self.token = token
        return self.token


class AssertionClient(_AssertionClient, Client):
    """Synchronous httpx client authenticated through an assertion grant.

    Combines the assertion-client logic of ``_AssertionClient`` with httpx's
    ``Client``. Unless told otherwise, every request makes sure a
    non-expired token is available and sends it via ``self.token_auth``.
    """

    token_auth_class = OAuth2Auth
    JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE
    # Grant type -> callable producing the assertion; presumably consumed by
    # the ``_AssertionClient`` base class (not visible in this module).
    ASSERTION_METHODS = {
        JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign,
    }
    DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE

    def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None,
                 claims=None, token_placement='header', scope=None, **kwargs):
        """Initialise both the httpx client and the assertion client.

        The named parameters are forwarded unchanged to
        ``_AssertionClient.__init__``. ``extract_client_kwargs`` selects the
        keyword arguments handed to ``Client.__init__`` (presumably removing
        them from ``kwargs`` -- confirm in ``.utils``); the ``kwargs`` dict as
        it stands afterwards goes to the assertion client.
        """
        client_kwargs = extract_client_kwargs(kwargs)
        Client.__init__(self, **client_kwargs)

        # This client is passed as its own ``session`` to the base class.
        _AssertionClient.__init__(
            self, session=self,
            token_endpoint=token_endpoint, issuer=issuer, subject=subject,
            audience=audience, grant_type=grant_type, claims=claims,
            token_placement=token_placement, scope=scope, **kwargs
        )

    def request(self, method, url, withhold_token=False, auth=UNSET, **kwargs):
        """Send request with auto refresh token feature.

        :param method: HTTP method, passed through to ``Client.request``.
        :param url: Target URL, passed through to ``Client.request``.
        :param withhold_token: When true, skip token refresh and token auth.
        :param auth: Explicit auth to use. Left at the ``UNSET`` sentinel, the
            client's own token auth is applied; any other value (including
            ``None`` when ``UNSET`` is a distinct sentinel) is forwarded as-is.
        :param kwargs: Extra arguments for ``Client.request``.
        :return: The response returned by ``Client.request``.
        """
        # The default must be the same sentinel that is tested below;
        # otherwise a call that omits ``auth`` would never get the token.
        if not withhold_token and auth is UNSET:
            if not self.token or self.token.is_expired():
                self.refresh_token()

            auth = self.token_auth
        return super().request(method, url, auth=auth, **kwargs)
