File: jwt.py

package info (click to toggle)
python-authlib 1.6.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 3,016 kB
  • sloc: python: 26,998; makefile: 53; sh: 14
file content (191 lines) | stat: -rw-r--r-- 6,185 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import calendar
import datetime
import random
import re

from authlib.common.encoding import json_dumps
from authlib.common.encoding import json_loads
from authlib.common.encoding import to_bytes
from authlib.common.encoding import to_unicode

from ..errors import DecodeError
from ..errors import InsecureClaimError
from ..rfc7515 import JsonWebSignature
from ..rfc7516 import JsonWebEncryption
from ..rfc7517 import Key
from ..rfc7517 import KeySet
from .claims import JWTClaims


class JsonWebToken:
    """Encode and decode JSON Web Tokens.

    Signing/verification is delegated to a :class:`JsonWebSignature`
    instance and encryption/decryption to a :class:`JsonWebEncryption`
    instance, both built from the same ``algorithms`` and
    ``private_headers`` arguments.
    """

    # Claim names that are rejected outright by ``check_sensitive_data``.
    SENSITIVE_NAMES = ("password", "token", "secret", "secret_key")
    # Claim values matching any alternative below are rejected.
    # Thanks to sentry SensitiveDataFilter
    SENSITIVE_VALUES = re.compile(
        r"|".join(
            (
                # credit card numbers, see
                # http://www.richardsramblings.com/regex/credit-card-numbers/
                r"\b(?:3[47]\d|(?:4\d|5[1-5]|65)\d{2}|6011)\d{12}\b",
                # PEM-style private key blocks
                r"-----BEGIN[A-Z ]+PRIVATE KEY-----.+-----END[A-Z ]+PRIVATE KEY-----",
                # US social security numbers
                r"^\b(?!(000|666|9))\d{3}-(?!00)\d{2}-(?!0000)\d{4}\b",
            )
        ),
        re.DOTALL,
    )

    def __init__(self, algorithms, private_headers=None):
        """Create the underlying JWS and JWE handlers.

        :param algorithms: forwarded to both handlers
        :param private_headers: forwarded to both handlers
        """
        self._jws = JsonWebSignature(algorithms, private_headers=private_headers)
        self._jwe = JsonWebEncryption(algorithms, private_headers=private_headers)

    def check_sensitive_data(self, payload):
        """Reject a payload that appears to carry sensitive information.

        :param payload: mapping of claim names to claim values
        :raise: InsecureClaimError (carrying the offending claim name) when
            a claim name is listed in ``SENSITIVE_NAMES`` or a string value
            matches ``SENSITIVE_VALUES``
        """
        for name in payload:
            # the claim name itself may be forbidden
            if name in self.SENSITIVE_NAMES:
                raise InsecureClaimError(name)

            # only string values are scanned for sensitive patterns
            value = payload[name]
            if not isinstance(value, str):
                continue
            if self.SENSITIVE_VALUES.search(value):
                raise InsecureClaimError(name)

    def encode(self, header, payload, key, check=True):
        """Serialize ``payload`` as a JWT using ``header`` and ``key``.

        Both ``header`` and ``payload`` are modified in place: ``typ``
        defaults to ``"JWT"``, ``find_encode_key`` may add a ``kid`` to the
        header, and ``datetime`` values of the ``exp``, ``iat`` and ``nbf``
        claims are replaced by integer timestamps.

        :param header: A dict of JOSE header values; the presence of
            ``enc`` selects JWE, otherwise JWS is used
        :param payload: A dict of claims to be encoded
        :param key: key (or key set) used to sign or encrypt
        :param check: when true, run :meth:`check_sensitive_data` first
        :return: the compact serialization produced by the JWS/JWE handler
        """
        header.setdefault("typ", "JWT")

        # time claims given as datetime become timestamps; the datetime's
        # UTC time tuple is what gets converted
        for name in ("exp", "iat", "nbf"):
            moment = payload.get(name)
            if isinstance(moment, datetime.datetime):
                payload[name] = calendar.timegm(moment.utctimetuple())

        if check:
            self.check_sensitive_data(payload)

        key = find_encode_key(key, header)
        text = to_bytes(json_dumps(payload))
        serializer = self._jwe if "enc" in header else self._jws
        return serializer.serialize_compact(header, text, key)

    def decode(self, s, key, claims_cls=None, claims_options=None, claims_params=None):
        """Deserialize the JWT ``s`` and wrap its claims.

        The number of ``.`` separators decides the format: two for JWS,
        four for JWE.

        :param s: text of JWT
        :param key: key material, or a callable used directly as the key
            loader; anything else is normalized with ``prepare_raw_key``
            and wrapped by ``create_load_key``
        :param claims_cls: class to be used for JWT claims, ``JWTClaims``
            when omitted
        :param claims_options: `options` parameters for claims_cls
        :param claims_params: `params` parameters for claims_cls
        :return: claims_cls instance
        :raise: DecodeError when the separator count is neither 2 nor 4;
            errors raised by the JWS/JWE handlers propagate unchanged
        """
        if claims_cls is None:
            claims_cls = JWTClaims

        load_key = key if callable(key) else create_load_key(prepare_raw_key(key))

        s = to_bytes(s)
        separators = s.count(b".")
        if separators == 2:
            handler = self._jws
        elif separators == 4:
            handler = self._jwe
        else:
            raise DecodeError("Invalid input segments length")

        data = handler.deserialize_compact(s, load_key, decode_payload)
        return claims_cls(
            data["payload"],
            data["header"],
            options=claims_options,
            params=claims_params,
        )


def decode_payload(bytes_payload):
    """Parse a serialized payload into a dict of claims.

    :param bytes_payload: raw payload, converted with ``to_unicode`` and
        parsed with ``json_loads``
    :return: the parsed payload
    :raise: DecodeError when conversion/parsing raises ``ValueError`` or
        when the parsed value is not a dict
    """
    try:
        claims = json_loads(to_unicode(bytes_payload))
    except ValueError as exc:
        raise DecodeError("Invalid payload value") from exc

    if isinstance(claims, dict):
        return claims
    raise DecodeError("Invalid payload type")


def prepare_raw_key(raw):
    """Normalize user-supplied key material.

    * a ``KeySet`` is returned untouched
    * a string that starts with ``{`` and ends with ``}`` is parsed with
      ``json_loads``
    * a tuple or list is wrapped as ``{"keys": raw}``
    * anything else is returned as given

    :param raw: key material in any of the forms above
    :return: the normalized key material
    """
    if isinstance(raw, KeySet):
        return raw

    looks_like_json = (
        isinstance(raw, str) and raw.startswith("{") and raw.endswith("}")
    )
    if looks_like_json:
        return json_loads(raw)

    if isinstance(raw, (tuple, list)):
        return {"keys": raw}
    return raw


def find_encode_key(key, header):
    """Pick the key to encode with and record its ``kid`` in ``header``.

    ``header`` may be modified in place: whenever a key is chosen without
    the header naming one, the chosen key's ``kid`` is written into it.

    :param key: a ``KeySet``, a dict with a ``keys`` list, a single key
        dict, a ``Key`` instance, or any other key value
    :param header: dict of header values; its ``kid`` entry, when set,
        selects the key from a set
    :return: the selected key
    :raise: ValueError when a dict key set has no entry matching the
        header's ``kid``
    """
    if isinstance(key, KeySet):
        wanted = header.get("kid")
        if wanted:
            return key.find_by_kid(wanted)

        # no kid requested: pick any key and advertise it through the header
        chosen = random.choice(key.keys)
        header["kid"] = chosen.kid
        return chosen

    is_dict = isinstance(key, dict)

    if is_dict and "keys" in key:
        candidates = key["keys"]
        wanted = header.get("kid")
        for candidate in candidates:
            if candidate.get("kid") == wanted:
                return candidate

        # nothing matched; only an unset/empty kid allows a random pick
        if wanted:
            raise ValueError("Invalid JSON Web Key Set")
        chosen = random.choice(candidates)
        header["kid"] = chosen["kid"]
        return chosen

    # single key: copy its kid into the header when it has one
    if is_dict and "kid" in key:
        header["kid"] = key["kid"]
    elif isinstance(key, Key) and key.kid:
        header["kid"] = key.kid
    return key


def create_load_key(key):
    """Build a ``load_key(header, payload)`` callable bound to ``key``.

    The returned callable resolves the key as follows:

    * ``KeySet``: looked up with ``find_by_kid`` using the header's ``kid``
    * dict with a ``keys`` list: the entry whose ``kid`` equals the
      header's ``kid``; without a ``kid`` in the header, the sole entry of
      a one-element list
    * anything else: ``key`` itself

    :param key: key material the returned callable resolves against
    :return: the ``load_key`` callable
    """

    def load_key(header, payload):
        if isinstance(key, KeySet):
            return key.find_by_kid(header.get("kid"))

        # plain key material is handed back unchanged
        if not (isinstance(key, dict) and "keys" in key):
            return key

        candidates = key["keys"]
        wanted = header.get("kid")

        if wanted is None:
            # without a kid the choice is only unambiguous for a single key
            if len(candidates) == 1:
                return candidates[0]
        else:
            for candidate in candidates:
                if candidate.get("kid") == wanted:
                    return candidate
        raise ValueError("Invalid JSON Web Key Set")

    return load_key