File: tokens.py

package info (click to toggle)
python-aiobotocore 2.13.1-1.1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 832 kB
  • sloc: python: 10,572; makefile: 71
file content (161 lines) | stat: -rw-r--r-- 5,757 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import asyncio
import logging
from datetime import timedelta

import dateutil.parser
from botocore.compat import total_seconds
from botocore.exceptions import ClientError, TokenRetrievalError
from botocore.tokens import (
    DeferredRefreshableToken,
    FrozenAuthToken,
    SSOTokenProvider,
    TokenProviderChain,
    _utc_now,
)

logger = logging.getLogger(__name__)


def create_token_resolver(session):
    """Build the bearer-token resolver chain for an aiobotocore session.

    The chain currently consists of a single :class:`AioSSOTokenProvider`
    bound to *session*.

    :param session: The session handed to each provider in the chain.
    :returns: A ``TokenProviderChain`` wrapping the providers.
    """
    return TokenProviderChain(providers=[AioSSOTokenProvider(session)])


class AioDeferredRefreshableToken(DeferredRefreshableToken):
    """Asyncio counterpart of botocore's ``DeferredRefreshableToken``.

    Caches the token produced by an async ``refresh_using`` callable and
    refreshes it lazily whenever :meth:`get_frozen_token` is awaited.
    The refresh policy helpers used below (``_should_refresh``,
    ``_is_expired`` and ``_attempt_timeout``) are not defined in this
    class; they are inherited from the botocore base class.
    """

    def __init__(
        self, method, refresh_using, time_fetcher=_utc_now
    ):  # noqa: E501, lgtm [py/missing-call-to-init]
        """Set up the wrapper state without calling the base ``__init__``.

        :param method: Provider name; reported as ``provider`` when a
            ``TokenRetrievalError`` is raised.
        :param refresh_using: Zero-argument callable whose awaited result
            becomes the new frozen token.
        :param time_fetcher: Callable returning the current time; defaults
            to botocore's ``_utc_now``.
        """
        # The base-class __init__ is intentionally skipped (see the lgtm
        # suppression above); the attributes the methods below rely on are
        # all assigned here.
        self._time_fetcher = time_fetcher
        self._refresh_using = refresh_using
        self.method = method

        # The frozen token is protected by this lock
        # (an asyncio.Lock, so waiting for it suspends the awaiting task
        # instead of blocking the event loop).
        self._refresh_lock = asyncio.Lock()
        self._frozen_token = None
        self._next_refresh = None

    async def get_frozen_token(self):
        """Return the cached token, refreshing it first if one is due.

        Exceptions raised by a mandatory refresh, and the
        ``TokenRetrievalError`` raised for an expired token, propagate to
        the caller.
        """
        await self._refresh()
        return self._frozen_token

    async def _refresh(self):
        """Refresh the frozen token if ``_should_refresh`` asks for it.

        A "mandatory" refresh always waits for the lock. Any other refresh
        type is skipped entirely when another task currently holds the
        lock, so only one task performs a non-mandatory refresh at a time.
        """
        # If we don't need to refresh just return
        refresh_type = self._should_refresh()
        if not refresh_type:
            return None

        # Block for refresh if we're in the mandatory refresh window
        block_for_refresh = refresh_type == "mandatory"
        if block_for_refresh or not self._refresh_lock.locked():
            async with self._refresh_lock:
                await self._protected_refresh()

    async def _protected_refresh(self):
        """Perform the refresh; the caller must hold ``_refresh_lock``.

        On success the result of ``refresh_using`` replaces the frozen
        token. On failure, a mandatory refresh re-raises the exception.
        Any other refresh type only logs a warning, and the previously
        cached token is kept.

        :raises TokenRetrievalError: if ``self._is_expired()`` reports the
            cached token as expired after the attempt.
        """
        # This should only be called after acquiring the refresh lock
        # Another task may have already refreshed, double check refresh
        refresh_type = self._should_refresh()
        if not refresh_type:
            return None

        try:
            now = self._time_fetcher()
            # Stamped before awaiting the refresher, so the next-attempt time
            # is recorded even if the refresh raises. Presumably the base
            # class's _should_refresh uses it to throttle retries for
            # _attempt_timeout seconds -- confirm against botocore.
            self._next_refresh = now + timedelta(seconds=self._attempt_timeout)
            self._frozen_token = await self._refresh_using()
        except Exception:
            logger.warning(
                "Refreshing token failed during the %s refresh period.",
                refresh_type,
                exc_info=True,
            )
            if refresh_type == "mandatory":
                # This refresh was mandatory, error must be propagated back
                raise

        if self._is_expired():
            # Fresh credentials should never be expired
            raise TokenRetrievalError(
                provider=self.method,
                error_msg="Token has expired and refresh failed",
            )


class AioSSOTokenProvider(SSOTokenProvider):
    """Async SSO bearer-token provider layered on botocore's provider.

    Only the methods on the path of the SSO-OIDC ``create_token`` request
    are overridden, so that the request can be awaited. The ``_client``,
    ``_sso_config``, ``_now``, ``_token_loader`` and class constants used
    here all come from the ``SSOTokenProvider`` base class.
    """

    async def _attempt_create_token(self, token):
        """Exchange the cached refresh token for a fresh access token.

        :param token: Cached token dict holding ``clientId``,
            ``clientSecret``, ``refreshToken`` and ``registrationExpiresAt``.
        :returns: A new token dict in cache layout. The client registration
            is copied forward from *token*. ``refreshToken`` is included
            only when the service returned one; the old value is not
            carried over.
        """
        async with self._client as oidc:
            reply = await oidc.create_token(
                grantType=self._GRANT_TYPE,
                clientId=token["clientId"],
                clientSecret=token["clientSecret"],
                refreshToken=token["refreshToken"],
            )
        lifetime = timedelta(seconds=reply["expiresIn"])
        refreshed = {
            "startUrl": self._sso_config["sso_start_url"],
            "region": self._sso_config["sso_region"],
            "accessToken": reply["accessToken"],
            "expiresAt": self._now() + lifetime,
            # The client registration is stored with the token so that a
            # later refresh can reuse it.
            "clientId": token["clientId"],
            "clientSecret": token["clientSecret"],
            "registrationExpiresAt": token["registrationExpiresAt"],
        }
        if "refreshToken" in reply:
            refreshed["refreshToken"] = reply["refreshToken"]
        logger.info("SSO Token refresh succeeded")
        return refreshed

    async def _refresh_access_token(self, token):
        """Try to refresh *token*, returning ``None`` when that is impossible.

        Returns ``None`` (after logging) in any of these cases:

        - a key required for the refresh is missing;
        - the client registration has already expired;
        - the ``create_token`` call raised ``ClientError``.

        Otherwise returns the dict built by :meth:`_attempt_create_token`.
        """
        required = (
            "refreshToken",
            "clientId",
            "clientSecret",
            "registrationExpiresAt",
        )
        absent = [name for name in required if name not in token]
        if absent:
            logger.info(
                f"Unable to refresh SSO token: missing keys: {absent}"
            )
            return None

        registered_until = dateutil.parser.parse(
            token["registrationExpiresAt"]
        )
        seconds_left = total_seconds(registered_until - self._now())
        if seconds_left <= 0:
            logger.info(f"SSO token registration expired at {registered_until}")
            return None

        try:
            refreshed = await self._attempt_create_token(token)
        except ClientError:
            logger.warning("SSO token refresh attempt failed", exc_info=True)
            return None
        else:
            return refreshed

    async def _refresher(self):
        """Load the cached SSO token, refreshing and re-caching it if needed.

        A refresh is attempted when fewer than ``_REFRESH_WINDOW`` seconds
        remain before the cached token expires. A successful refresh is
        written back through ``_token_loader.save_token``. A failed one
        falls back to the cached token.

        :returns: A ``FrozenAuthToken`` for the access token in use.
        """
        start_url = self._sso_config["sso_start_url"]
        session_name = self._sso_config["session_name"]
        logger.info(f"Loading cached SSO token for {session_name}")
        current = self._token_loader(start_url, session_name=session_name)
        expires_at = dateutil.parser.parse(current["expiresAt"])
        logger.debug(f"Cached SSO token expires at {expires_at}")

        refreshed = None
        if total_seconds(expires_at - self._now()) < self._REFRESH_WINDOW:
            refreshed = await self._refresh_access_token(current)

        if refreshed is not None:
            current = refreshed
            expires_at = current["expiresAt"]
            self._token_loader.save_token(
                start_url, current, session_name=session_name
            )

        return FrozenAuthToken(current["accessToken"], expiration=expires_at)

    def load_token(self):
        """Return a lazily refreshed token wrapper, or ``None`` without SSO config."""
        if self._sso_config is not None:
            return AioDeferredRefreshableToken(
                self.METHOD, self._refresher, time_fetcher=self._now
            )
        return None