File: assertion.py

package info (click to toggle)
microsoft-authentication-library-for-python 1.34.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,320 kB
  • sloc: python: 8,613; xml: 2,783; sh: 27; makefile: 19
file content (137 lines) | stat: -rw-r--r-- 5,690 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import time
import binascii
import base64
import uuid
import logging


logger = logging.getLogger(__name__)


def _str2bytes(raw):
    # A conversion based on duck-typing rather than six.text_type
    try:  # Assuming it is a string
        return raw.encode(encoding="utf-8")
    except:  # Otherwise we treat it as bytes and return it as-is
        return raw

def _encode_thumbprint(thumbprint):
    return base64.urlsafe_b64encode(binascii.a2b_hex(thumbprint)).decode()

class AssertionCreator(object):
    """Base class for assertion creators; sub-classes supply the actual format."""

    def create_normal_assertion(
            self, audience, issuer, subject, expires_at=None, expires_in=600,
            issued_at=None, assertion_id=None, **kwargs):
        """Build one assertion, as bytes, from the given claims.

        Parameter names follow https://tools.ietf.org/html/rfc7521#section-5,
        with one addition: ``expires_in`` is a lifetime in seconds that
        implementations turn into an absolute ``expires_at`` (UTC).

        This base implementation always raises ``NotImplementedError``.
        """
        raise NotImplementedError("Will be implemented by sub-class")

    def create_regenerative_assertion(
            self, audience, issuer, subject=None, expires_in=600, **kwargs):
        """Return a zero-argument callable that yields an assertion on demand.

        The callable caches the assertion produced by
        ``create_normal_assertion`` and reuses it, rebuilding it only once the
        cached one is due to expire. This avoids re-signing a client
        assertion on every request.
        """
        def regenerate():
            return self.create_normal_assertion(
                audience, issuer, subject, expires_in=expires_in, **kwargs)

        # Rebuild 60 seconds ahead of the assertion's own lifetime (never
        # negative) so a cached assertion is not handed out just before it expires.
        reuse_window = max(expires_in - 60, 0)
        return AutoRefresher(regenerate, expires_in=reuse_window)


class AutoRefresher(object):
    """Memoize a zero-argument factory's result for ``expires_in`` seconds.

    Each call returns the cached value while it is still fresh. Once
    ``expires_in`` seconds (measured with ``time.time()``) have passed since the
    value was produced, the factory is called again. Usage::

        r = AutoRefresher(time.time, expires_in=5)
        for i in range(15):
            print(r())  # the printed timestamp changes only every 5 seconds
            time.sleep(1)
    """
    def __init__(self, factory, expires_in=540):
        self._factory = factory  # zero-argument callable producing the value
        self._expires_in = expires_in  # seconds a produced value stays reusable
        self._buf = {}  # once populated: {"value": ..., "expires_at": timestamp}

    def __call__(self):
        """Return the cached value, regenerating it first if it has expired."""
        current = time.time()
        stale = self._buf.get("expires_at", 0) <= current
        if not stale:
            logger.debug("Reusing still valid assertion")
            return self._buf.get("value")
        logger.debug("Regenerating new assertion")
        fresh_value = self._factory()
        self._buf = {"value": fresh_value, "expires_at": current + self._expires_in}
        return fresh_value


class JwtAssertionCreator(AssertionCreator):
    """Create assertions in the JSON Web Token (JWT) format, signed via PyJWT."""

    def __init__(
        self, key, algorithm, sha1_thumbprint=None, headers=None,
        *,
        sha256_thumbprint=None,
    ):
        """Construct a Jwt assertion creator.

        Args:

            key (str):
                An unencrypted private key for signing, in a base64 encoded string.
                It can also be a cryptography ``PrivateKey`` object,
                which is how you can work with a previously-encrypted key.
                See also https://github.com/jpadilla/pyjwt/pull/525
            algorithm (str):
                "RS256", etc.. See https://pyjwt.readthedocs.io/en/latest/algorithms.html
                RSA and ECDSA algorithms require "pip install cryptography".
            sha1_thumbprint (str): The x5t aka X.509 certificate SHA-1 thumbprint, in hex.
            headers (dict): Additional headers, e.g. "kid" or "x5c" etc.
                The dict is copied, so the caller's object is never modified.
            sha256_thumbprint (str): The x5t#S256 aka X.509 certificate SHA-256 thumbprint, in hex.
        """
        self.key = key
        self.algorithm = algorithm
        # Copy, so the thumbprint headers added below do not leak into the
        # caller's dict (e.g. one dict shared by several creators).
        self.headers = dict(headers or {})
        if sha256_thumbprint:  # https://datatracker.ietf.org/doc/html/rfc7515#section-4.1.8
            self.headers["x5t#S256"] = _encode_thumbprint(sha256_thumbprint)
        if sha1_thumbprint:  # https://tools.ietf.org/html/rfc7515#section-4.1.7
            self.headers["x5t"] = _encode_thumbprint(sha1_thumbprint)

    def create_normal_assertion(
            self, audience, issuer, subject=None, expires_at=None, expires_in=600,
            issued_at=None, assertion_id=None, not_before=None,
            additional_claims=None, **kwargs):
        """Create a JWT Assertion.

        Parameters are defined in https://tools.ietf.org/html/rfc7523#section-3
        Key-value pairs in additional_claims will be added into payload as-is,
        overriding any same-named claim computed here. ``subject`` defaults to
        ``issuer``. ``exp`` defaults to now + ``expires_in`` seconds, ``iat`` to
        now, and ``jti`` to a random UUID4. Extra keyword arguments are accepted
        and ignored.

        Returns:
            bytes: The encoded JWT, normalized to bytes whichever PyJWT major
            version is installed.

        Raises:
            Whatever ``jwt.encode`` raises. For RSA/ECDSA/RSA-PSS algorithms, a
            hint about the optional "cryptography" dependency is logged first.
        """
        import jwt  # Lazy loading
        now = time.time()
        payload = {
            'aud': audience,
            'iss': issuer,
            'sub': subject or issuer,
            'exp': expires_at or (now + expires_in),
            'iat': issued_at or now,
            'jti': assertion_id or str(uuid.uuid4()),
            }
        if not_before:
            payload['nbf'] = not_before
        payload.update(additional_claims or {})
        try:
            str_or_bytes = jwt.encode(  # PyJWT 1 returns bytes, PyJWT 2 returns str
                payload, self.key, algorithm=self.algorithm, headers=self.headers)
        except Exception:
            # RS* (RSA), ES* (ECDSA) and PS* (RSA-PSS) are the PyJWT algorithm
            # families needing "cryptography". str() keeps a non-str algorithm
            # from raising a second, unrelated error here.
            if str(self.algorithm).startswith(("RS", "ES", "PS")):
                logger.exception(
                    'Some algorithms requires "pip install cryptography". '
                    'See https://pyjwt.readthedocs.io/en/latest/installation.html#cryptographic-dependencies-optional')
            raise
        return _str2bytes(str_or_bytes)  # We normalize them into bytes


# Deprecated aliases, kept only so existing callers keep working; they will be
# removed in future versions. New code should use AssertionCreator and
# JwtAssertionCreator directly.
Signer = AssertionCreator
JwtSigner = JwtAssertionCreator
# Old method name. JwtSigner is the very same class object as
# JwtAssertionCreator, so this adds sign_assertion() to JwtAssertionCreator too.
JwtSigner.sign_assertion = JwtSigner.create_normal_assertion