1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
|
import asyncio
import logging
from datetime import timedelta
import dateutil.parser
from botocore.compat import total_seconds
from botocore.exceptions import ClientError, TokenRetrievalError
from botocore.tokens import (
DeferredRefreshableToken,
FrozenAuthToken,
SSOTokenProvider,
TokenProviderChain,
_utc_now,
)
logger = logging.getLogger(__name__)
def create_token_resolver(session):
    """Build the token provider chain for ``session``.

    :param session: Session object passed to each token provider.
    :return: A ``TokenProviderChain`` whose single provider is an
        ``AioSSOTokenProvider`` constructed from ``session``.
    """
    sso_provider = AioSSOTokenProvider(session)
    return TokenProviderChain(providers=[sso_provider])
class AioDeferredRefreshableToken(DeferredRefreshableToken):
    """Deferred-refresh token whose refresh path is driven by coroutines.

    The refresh callable is awaited and the frozen token is guarded by an
    ``asyncio.Lock``. The refresh decision (``_should_refresh``), the
    expiry check (``_is_expired``) and ``_attempt_timeout`` are not defined
    in this class; they are expected to come from the base class.
    """

    def __init__(
        self, method, refresh_using, time_fetcher=_utc_now
    ):  # noqa: E501, lgtm [py/missing-call-to-init]
        """Store the refresh configuration without calling the parent init.

        :param method: Identifier of the provider; reported as ``provider``
            when a ``TokenRetrievalError`` is raised.
        :param refresh_using: Callable that is awaited to obtain a new
            frozen token.
        :param time_fetcher: Callable returning the current time.
        """
        self.method = method
        self._refresh_using = refresh_using
        self._time_fetcher = time_fetcher

        self._frozen_token = None
        self._next_refresh = None
        # Updates to the frozen token happen while holding this lock
        self._refresh_lock = asyncio.Lock()

    async def get_frozen_token(self):
        """Run a refresh pass if one is due and return the frozen token."""
        await self._refresh()
        return self._frozen_token

    async def _refresh(self):
        """Refresh the token under the lock when a refresh is due.

        A ``"mandatory"`` refresh always waits for the lock. Any other
        refresh type is skipped when the lock is already held.
        """
        kind = self._should_refresh()
        if not kind:
            # Nothing to do: no refresh is due
            return None

        if kind != "mandatory" and self._refresh_lock.locked():
            # Non-mandatory refresh and the lock is already taken, so
            # do not wait for it.
            return None

        async with self._refresh_lock:
            await self._protected_refresh()

    async def _protected_refresh(self):
        """Perform the refresh; the caller must hold the refresh lock.

        :raises Exception: whatever the refresh callable raised, when the
            refresh was ``"mandatory"``.
        :raises TokenRetrievalError: when the token is expired after the
            refresh attempt.
        """
        # The lock may have been held by another task that already did
        # the refresh, so evaluate the refresh condition again.
        kind = self._should_refresh()
        if not kind:
            return None

        try:
            started = self._time_fetcher()
            # Set before awaiting the refresh callable; presumably this
            # spaces out retry attempts -- confirm against _should_refresh.
            self._next_refresh = started + timedelta(
                seconds=self._attempt_timeout
            )
            self._frozen_token = await self._refresh_using()
        except Exception:
            logger.warning(
                "Refreshing token failed during the %s refresh period.",
                kind,
                exc_info=True,
            )
            if kind == "mandatory":
                # A mandatory refresh must surface its failure to the caller
                raise

        if not self._is_expired():
            return None

        # A token that was just refreshed should never be expired
        raise TokenRetrievalError(
            provider=self.method,
            error_msg="Token has expired and refresh failed",
        )
class AioSSOTokenProvider(SSOTokenProvider):
    """SSO token provider whose refresh path is implemented as coroutines.

    Relies on attributes that are not defined in this class and are
    expected to come from the base class: ``_client``, ``_sso_config``,
    ``_token_loader``, ``_now``, ``_GRANT_TYPE``, ``_REFRESH_WINDOW`` and
    ``METHOD``.
    """

    async def _attempt_create_token(self, token):
        """Request a new access token using the cached refresh token.

        :param token: Cached token mapping; must contain ``clientId``,
            ``clientSecret``, ``refreshToken`` and
            ``registrationExpiresAt``.
        :return: New token dict. ``expiresAt`` is the current time plus
            the ``expiresIn`` seconds from the response. ``refreshToken``
            is included only when the response carries one.
        """
        # NOTE(review): self._client is entered as an async context manager
        # on every call -- confirm it can be entered more than once.
        async with self._client as sso_client:
            response = await sso_client.create_token(
                grantType=self._GRANT_TYPE,
                clientId=token["clientId"],
                clientSecret=token["clientSecret"],
                refreshToken=token["refreshToken"],
            )

        lifetime = timedelta(seconds=response["expiresIn"])
        refreshed = {
            "startUrl": self._sso_config["sso_start_url"],
            "region": self._sso_config["sso_region"],
            "accessToken": response["accessToken"],
            "expiresAt": self._now() + lifetime,
            # The client registration is cached together with the token
            "clientId": token["clientId"],
            "clientSecret": token["clientSecret"],
            "registrationExpiresAt": token["registrationExpiresAt"],
        }
        if "refreshToken" in response:
            refreshed["refreshToken"] = response["refreshToken"]

        logger.info("SSO Token refresh succeeded")
        return refreshed

    async def _refresh_access_token(self, token):
        """Try to refresh ``token``; return ``None`` when that is not possible.

        ``None`` is returned when a required key is missing from ``token``,
        when the client registration has expired, or when the create-token
        call raises ``ClientError``.

        :param token: Cached token mapping.
        :return: The new token dict, or ``None``.
        """
        required = (
            "refreshToken",
            "clientId",
            "clientSecret",
            "registrationExpiresAt",
        )
        absent = [name for name in required if name not in token]
        if absent:
            msg = f"Unable to refresh SSO token: missing keys: {absent}"
            logger.info(msg)
            return None

        registration_expiry = dateutil.parser.parse(
            token["registrationExpiresAt"]
        )
        if total_seconds(registration_expiry - self._now()) <= 0:
            logger.info(
                f"SSO token registration expired at {registration_expiry}"
            )
            return None

        try:
            refreshed = await self._attempt_create_token(token)
        except ClientError:
            logger.warning("SSO token refresh attempt failed", exc_info=True)
            return None
        return refreshed

    async def _refresher(self):
        """Load the cached SSO token, refreshing it when close to expiry.

        When fewer than ``_REFRESH_WINDOW`` seconds remain, a refresh is
        attempted. A successful refresh replaces the cached token and is
        saved through the token loader. Otherwise the cached token is used
        as loaded.

        :return: ``FrozenAuthToken`` built from the access token and its
            expiration.
        """
        start_url = self._sso_config["sso_start_url"]
        session_name = self._sso_config["session_name"]

        logger.info(f"Loading cached SSO token for {session_name}")
        token_dict = self._token_loader(start_url, session_name=session_name)
        expiration = dateutil.parser.parse(token_dict["expiresAt"])
        logger.debug(f"Cached SSO token expires at {expiration}")

        seconds_left = total_seconds(expiration - self._now())
        if seconds_left < self._REFRESH_WINDOW:
            refreshed = await self._refresh_access_token(token_dict)
            if refreshed is not None:
                token_dict = refreshed
                # Taken directly from the refreshed dict (not re-parsed)
                expiration = refreshed["expiresAt"]
                self._token_loader.save_token(
                    start_url, refreshed, session_name=session_name
                )

        access_token = token_dict["accessToken"]
        return FrozenAuthToken(access_token, expiration=expiration)

    def load_token(self):
        """Return a deferred refreshable token, or ``None`` without SSO config.

        :return: ``AioDeferredRefreshableToken`` that refreshes through
            ``_refresher``, or ``None`` when ``_sso_config`` is ``None``.
        """
        if self._sso_config is not None:
            return AioDeferredRefreshableToken(
                self.METHOD, self._refresher, time_fetcher=self._now
            )
        return None
|