File: jwt_helper.py

package info (click to toggle)
thunderbird 1%3A143.0.1-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 4,703,968 kB
  • sloc: cpp: 7,770,492; javascript: 5,943,842; ansic: 3,918,754; python: 1,418,263; xml: 653,354; asm: 474,045; java: 183,079; sh: 111,238; makefile: 20,410; perl: 14,359; objc: 13,059; yacc: 4,583; pascal: 3,405; lex: 1,720; ruby: 999; exp: 762; sql: 715; awk: 580; php: 436; lisp: 430; sed: 69; csh: 10
file content (76 lines) | stat: -rw-r--r-- 3,001 bytes parent folder | download | duplicates (8)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import json
import base64
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding

def decode_jwt(token, key=None):
    """Decode a compact JWT and verify its RS256 signature.

    Args:
        token: The JWT as a string of three dot-separated segments
            (header, payload, signature).
        key: Optional public key in JWK form, used to check the signature
            (the refresh case). When omitted, the ``key`` entry of the
            decoded payload is used instead (the registration case). In
            that case a passing check only shows that the token matches
            the key it carries itself.

    Returns:
        A tuple ``(decoded_header, decoded_payload, verify_succeeded)``.
        On any failure -- a malformed token, an empty or undecodable
        header or payload, an unusable key, or a signature mismatch -- the
        result is ``(None, None, False)``.
    """
    try:
        # Decode the header and payload. A token that does not split into
        # exactly three segments raises here and is reported as a failure.
        header, payload, signature = token.split('.')
        decoded_header = decode_base64_json(header)
        decoded_payload = decode_base64_json(payload)

        # If decoding produced nothing usable, return nothing.
        if not decoded_header or not decoded_payload:
            return None, None, False

        # If there is a key passed in (for refresh), use that for checking
        # the signature below. Otherwise (for registration), use the key
        # sent within the JWT. Identity comparison is used so that only a
        # genuinely absent key triggers the fallback.
        if key is None:
            key = decoded_payload.get('key')
        public_key = serialization.load_pem_public_key(jwk_to_pem(key))

        # Verifying the signature raises if it fails, which is turned into
        # the failure result by the handler below.
        verify_rs256_signature(header, payload, signature, public_key)
        return decoded_header, decoded_payload, True
    except Exception:
        # Deliberately broad: callers get a pass/fail answer, never an
        # exception, whatever step rejected the token.
        return None, None, False

def jwk_to_pem(jwk_data):
    """Convert an RSA public key in JWK form into PEM bytes.

    ``jwk_data`` may be a JSON string or an already-parsed mapping. Only
    keys whose ``kty`` is ``"RSA"`` are accepted; any other value raises
    ``ValueError``. The result is the PEM-encoded SubjectPublicKeyInfo
    form of the public key.
    """
    if isinstance(jwk_data, str):
        jwk = json.loads(jwk_data)
    else:
        jwk = jwk_data

    key_type = jwk.get("kty")
    if key_type != "RSA":
        raise ValueError(f"Unsupported key type: {key_type}")

    # "n" (modulus) and "e" (exponent) are decoded, in that order, and
    # read as big-endian integers.
    modulus, exponent = [
        int.from_bytes(decode_base64url(jwk[field]), 'big')
        for field in ("n", "e")
    ]

    return rsa.RSAPublicNumbers(exponent, modulus).public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

def verify_rs256_signature(encoded_header, encoded_payload, signature, public_key):
    """Check a signature over the still-encoded header and payload.

    The signed message is ``<encoded_header>.<encoded_payload>`` as UTF-8
    bytes, checked with PKCS#1 v1.5 padding and SHA-256. Nothing is
    returned; ``public_key.verify`` raises when the signature does not
    match.
    """
    signing_input = f'{encoded_header}.{encoded_payload}'
    message = signing_input.encode('utf-8')
    signature_bytes = decode_base64(signature)
    # This will throw an exception if verification fails.
    public_key.verify(signature_bytes, message, padding.PKCS1v15(), hashes.SHA256())

def add_base64_padding(encoded_data):
    """Append ``=`` characters until the length is a multiple of four.

    Input whose length is already a multiple of four is returned as is.
    """
    # -len % 4 is the distance up to the next multiple of four (0 if aligned).
    pad_len = -len(encoded_data) % 4
    if pad_len:
        encoded_data += '=' * pad_len
    return encoded_data

def decode_base64url(encoded_data):
    """Decode base64url text, with or without trailing padding, to bytes.

    The input is padded, its URL-safe characters (``-`` and ``_``) are
    mapped to the standard alphabet (``+`` and ``/``), and the result is
    decoded with the standard base64 decoder.
    """
    padded = add_base64_padding(encoded_data)
    # One translate pass replaces both URL-safe characters.
    standard_alphabet = padded.translate(str.maketrans("-_", "+/"))
    return base64.b64decode(standard_alphabet)

def decode_base64(encoded_data):
    """Pad the input and decode it with the URL-safe base64 decoder."""
    return base64.urlsafe_b64decode(add_base64_padding(encoded_data))

def decode_base64_json(encoded_data):
    """Decode a base64 segment and parse the resulting bytes as JSON."""
    raw_bytes = decode_base64(encoded_data)
    return json.loads(raw_bytes)