1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
|
import time
import binascii
import base64
import uuid
import logging
logger = logging.getLogger(__name__)
def _str2bytes(raw):
# A conversion based on duck-typing rather than six.text_type
try: # Assuming it is a string
return raw.encode(encoding="utf-8")
except: # Otherwise we treat it as bytes and return it as-is
return raw
def _encode_thumbprint(thumbprint):
return base64.urlsafe_b64encode(binascii.a2b_hex(thumbprint)).decode()
class AssertionCreator(object):
def create_normal_assertion(
self, audience, issuer, subject, expires_at=None, expires_in=600,
issued_at=None, assertion_id=None, **kwargs):
"""Create an assertion in bytes, based on the provided claims.
All parameter names are defined in https://tools.ietf.org/html/rfc7521#section-5
except the expires_in is defined here as lifetime-in-seconds,
which will be automatically translated into expires_at in UTC.
"""
raise NotImplementedError("Will be implemented by sub-class")
def create_regenerative_assertion(
self, audience, issuer, subject=None, expires_in=600, **kwargs):
"""Create an assertion as a callable,
which will then compute the assertion later when necessary.
This is a useful optimization to reuse the client assertion.
"""
return AutoRefresher( # Returns a callable
lambda a=audience, i=issuer, s=subject, e=expires_in, kwargs=kwargs:
self.create_normal_assertion(a, i, s, expires_in=e, **kwargs),
expires_in=max(expires_in-60, 0))
class AutoRefresher(object):
"""Cache the output of a factory, and auto-refresh it when necessary. Usage::
r = AutoRefresher(time.time, expires_in=5)
for i in range(15):
print(r()) # the timestamp change only after every 5 seconds
time.sleep(1)
"""
def __init__(self, factory, expires_in=540):
self._factory = factory
self._expires_in = expires_in
self._buf = {}
def __call__(self):
EXPIRES_AT, VALUE = "expires_at", "value"
now = time.time()
if self._buf.get(EXPIRES_AT, 0) <= now:
logger.debug("Regenerating new assertion")
self._buf = {VALUE: self._factory(), EXPIRES_AT: now + self._expires_in}
else:
logger.debug("Reusing still valid assertion")
return self._buf.get(VALUE)
class JwtAssertionCreator(AssertionCreator):
def __init__(
self, key, algorithm, sha1_thumbprint=None, headers=None,
*,
sha256_thumbprint=None,
):
"""Construct a Jwt assertion creator.
Args:
key (str):
An unencrypted private key for signing, in a base64 encoded string.
It can also be a cryptography ``PrivateKey`` object,
which is how you can work with a previously-encrypted key.
See also https://github.com/jpadilla/pyjwt/pull/525
algorithm (str):
"RS256", etc.. See https://pyjwt.readthedocs.io/en/latest/algorithms.html
RSA and ECDSA algorithms require "pip install cryptography".
sha1_thumbprint (str): The x5t aka X.509 certificate SHA-1 thumbprint.
headers (dict): Additional headers, e.g. "kid" or "x5c" etc.
sha256_thumbprint (str): The x5t#S256 aka X.509 certificate SHA-256 thumbprint.
"""
self.key = key
self.algorithm = algorithm
self.headers = headers or {}
if sha256_thumbprint: # https://datatracker.ietf.org/doc/html/rfc7515#section-4.1.8
self.headers["x5t#S256"] = _encode_thumbprint(sha256_thumbprint)
if sha1_thumbprint: # https://tools.ietf.org/html/rfc7515#section-4.1.7
self.headers["x5t"] = _encode_thumbprint(sha1_thumbprint)
def create_normal_assertion(
self, audience, issuer, subject=None, expires_at=None, expires_in=600,
issued_at=None, assertion_id=None, not_before=None,
additional_claims=None, **kwargs):
"""Create a JWT Assertion.
Parameters are defined in https://tools.ietf.org/html/rfc7523#section-3
Key-value pairs in additional_claims will be added into payload as-is.
"""
import jwt # Lazy loading
now = time.time()
payload = {
'aud': audience,
'iss': issuer,
'sub': subject or issuer,
'exp': expires_at or (now + expires_in),
'iat': issued_at or now,
'jti': assertion_id or str(uuid.uuid4()),
}
if not_before:
payload['nbf'] = not_before
payload.update(additional_claims or {})
try:
str_or_bytes = jwt.encode( # PyJWT 1 returns bytes, PyJWT 2 returns str
payload, self.key, algorithm=self.algorithm, headers=self.headers)
return _str2bytes(str_or_bytes) # We normalize them into bytes
except:
if self.algorithm.startswith("RS") or self.algorithm.startswith("ES"):
logger.exception(
'Some algorithms requires "pip install cryptography". '
'See https://pyjwt.readthedocs.io/en/latest/installation.html#cryptographic-dependencies-optional')
raise
# Obsolete. For backward compatibility. They will be removed in future versions.
Signer = AssertionCreator # For backward compatibility
JwtSigner = JwtAssertionCreator # For backward compatibility
JwtSigner.sign_assertion = JwtAssertionCreator.create_normal_assertion # For backward compatibility
|