1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
|
"""Token Verifier module"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from .. import TokenValidationError
from ..rest_async import AsyncRestClient
from .token_verifier import AsymmetricSignatureVerifier, JwksFetcher, TokenVerifier
if TYPE_CHECKING:
from aiohttp import ClientSession
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
class AsyncAsymmetricSignatureVerifier(AsymmetricSignatureVerifier):
    """Asynchronous verifier for RSA signatures backed by public key certificates.

    Signing keys are looked up through an ``AsyncJwksFetcher`` pointed at
    ``jwks_url``.

    Args:
        jwks_url (str): The url where the JWK set is located.
        algorithm (str, optional): The expected signing algorithm.
            Defaults to "RS256".
    """

    def __init__(self, jwks_url: str, algorithm: str = "RS256") -> None:
        super().__init__(jwks_url, algorithm)
        # Assigned after the parent initializer has run, so the async fetcher
        # is the one this instance ends up holding in ``_fetcher``.
        self._fetcher = AsyncJwksFetcher(jwks_url)

    def set_session(self, session: ClientSession) -> None:
        """Reuse an existing client session for key requests.

        The session is handed straight to the underlying fetcher.

        Args:
            session (aiohttp.ClientSession): The client session to reuse. It
                should be closed manually or used within a context manager.
        """
        self._fetcher.set_session(session)

    async def _fetch_key(self, key_id=None):
        """Look up a key by delegating to the fetcher's ``get_key``.

        Args:
            key_id (str): The key's key id.

        Returns:
            Whatever the fetcher's ``get_key`` returns for ``key_id``.
        """
        return await self._fetcher.get_key(key_id)

    async def verify_signature(self, token) -> dict[str, Any]:
        """Verify the signature of the given JSON web token.

        The key id is extracted from the token with ``_get_kid``, the matching
        key is fetched asynchronously, and the token is then decoded with
        ``_decode_jwt`` using that key.

        Args:
            token (str): The JWT to get its signature verified.

        Returns:
            The result of ``_decode_jwt`` for the token and the fetched key.

        Raises:
            TokenValidationError: if the token cannot be decoded, the algorithm
                is invalid or the token's signature doesn't match the
                calculated one.
        """
        signing_key = await self._fetch_key(key_id=self._get_kid(token))
        return self._decode_jwt(token, signing_key)
class AsyncJwksFetcher(JwksFetcher):
    """Class that asynchronously fetches and holds a JSON web key set.

    This class makes use of an in-memory cache. For it to work properly,
    define this instance once and re-use it.

    Args:
        jwks_url (str): The url where the JWK set is located.
        cache_ttl (int, optional): The lifetime of the JWK set cache in seconds.
            Defaults to 600 seconds.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # All arguments are forwarded unchanged to the parent fetcher.
        super().__init__(*args, **kwargs)
        self._async_client = AsyncRestClient(None)

    def set_session(self, session: ClientSession) -> None:
        """Set Client Session to improve performance by reusing session.

        Args:
            session (aiohttp.ClientSession): The client session which should be
                closed manually or within context manager.
        """
        self._async_client.set_session(session)

    async def _fetch_jwks(self, force: bool = False) -> dict[str, RSAPublicKey]:
        """Return the JWK set, refreshing it over the network when needed.

        When the cache is still valid and ``force`` is false, the cached value
        is returned as-is and ``_cache_is_fresh`` is set to ``False``.
        Otherwise the cache is emptied and a request to ``jwks_url`` is made;
        on success the response is stored through ``_cache_jwks``.

        The refresh is best-effort: an ordinary error while requesting or
        caching the key set is swallowed and the (emptied) cache value is
        returned, leaving the caller to report a missing key.

        Args:
            force (bool, optional): whether to ignore the cache and force a
                network request or not. Defaults to False.

        Returns:
            The current cache value (empty if a refresh was attempted and
            failed).
        """
        if not (force or self._cache_expired()):
            # Served from cache: flag it so get_key() knows a forced refresh
            # is still worth trying if the key is missing.
            self._cache_is_fresh = False
            return self._cache_value

        self._cache_value = {}
        try:
            jwks = await self._async_client.get(self._jwks_url)
            self._cache_jwks(jwks)
        except Exception:
            # Deliberately best-effort. ``Exception`` (not a bare ``except``)
            # so that asyncio.CancelledError, KeyboardInterrupt and SystemExit
            # still propagate instead of being silently discarded.
            pass
        return self._cache_value

    async def get_key(self, key_id: str) -> RSAPublicKey:
        """Obtains the JWK associated with the given key id.

        The key set is consulted first as-is; if the key is absent and the
        cache is not flagged as fresh, one forced refresh is attempted before
        giving up.

        Args:
            key_id (str): The id of the key to fetch.

        Returns:
            the JWK associated with the given key id.

        Raises:
            TokenValidationError: when a key with that id cannot be found
        """
        keys = await self._fetch_jwks()

        if keys and key_id in keys:
            return keys[key_id]

        if not self._cache_is_fresh:
            # The key may have been rotated in since the cache was filled.
            keys = await self._fetch_jwks(force=True)
            if keys and key_id in keys:
                return keys[key_id]

        raise TokenValidationError(f'RSA Public Key with ID "{key_id}" was not found.')
class AsyncTokenVerifier(TokenVerifier):
    """Class that verifies ID tokens following the steps defined in the OpenID Connect spec.

    An OpenID Connect ID token is not meant to be consumed until it's verified.

    Args:
        signature_verifier (AsyncAsymmetricSignatureVerifier): The instance that
            knows how to verify the signature.
        issuer (str): The expected issuer claim value.
        audience (str): The expected audience claim value.
        leeway (int, optional): The clock skew to accept when verifying date
            related claims in seconds. Defaults to 0 seconds.

    Raises:
        TypeError: if ``signature_verifier`` is missing or is not an
            ``AsyncAsymmetricSignatureVerifier``.
    """

    def __init__(
        self,
        signature_verifier: AsyncAsymmetricSignatureVerifier,
        issuer: str,
        audience: str,
        leeway: int = 0,
    ) -> None:
        # NOTE(review): the parent initializer is not called here; all state is
        # set directly below. Confirm that is intentional before changing it.
        if not signature_verifier or not isinstance(
            signature_verifier, AsyncAsymmetricSignatureVerifier
        ):
            raise TypeError(
                "signature_verifier must be an instance of AsyncAsymmetricSignatureVerifier."
            )

        self.iss = issuer
        self.aud = audience
        self.leeway = leeway
        self._sv = signature_verifier
        self._clock = None  # legacy testing requirement

    def set_session(self, session: ClientSession) -> None:
        """Set Client Session to improve performance by reusing session.

        The session is forwarded to the signature verifier.

        Args:
            session (aiohttp.ClientSession): The client session which should be
                closed manually or within context manager.
        """
        self._sv.set_session(session)

    async def verify(
        self,
        token: str,
        nonce: str | None = None,
        max_age: int | None = None,
        organization: str | None = None,
    ) -> dict[str, Any]:
        """Attempts to verify the given ID token, following the steps defined in the OpenID Connect spec.

        Args:
            token (str): The JWT to verify.
            nonce (str, optional): The nonce value sent during authentication.
            max_age (int, optional): The max_age value sent during authentication.
            organization (str, optional): The expected organization ID (org_id)
                or organization name (org_name) claim value. This should be
                specified when logging in to an organization.

        Returns:
            the decoded payload from the token

        Raises:
            TokenValidationError: when the token is missing or not a string,
                cannot be decoded, the token signing algorithm is not the
                expected one, the token signature is invalid or the token has
                a claim missing or with unexpected value.
        """
        # Verify token presence
        if not token or not isinstance(token, str):
            raise TokenValidationError("ID token is required but missing.")

        # Verify algorithm and signature
        payload = await self._sv.verify_signature(token)

        # Verify claims
        self._verify_payload(payload, nonce, max_age, organization)

        return payload
|