File: api_jwt.py

package info (click to toggle)
pyjwt 2.11.0-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 864 kB
  • sloc: python: 5,151; makefile: 141
file content (596 lines) | stat: -rw-r--r-- 22,141 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
from __future__ import annotations

import json
import os
import warnings
from calendar import timegm
from collections.abc import Container, Iterable, Sequence
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Union, cast

from .api_jws import PyJWS, _jws_global_obj
from .exceptions import (
    DecodeError,
    ExpiredSignatureError,
    ImmatureSignatureError,
    InvalidAudienceError,
    InvalidIssuedAtError,
    InvalidIssuerError,
    InvalidJTIError,
    InvalidSubjectError,
    MissingRequiredClaimError,
)
from .warnings import RemovedInPyjwt3Warning

# Annotation-only imports and aliases. This branch runs under a static type
# checker (TYPE_CHECKING) or when the SPHINX_BUILD environment variable is set
# to a non-empty value (presumably by the documentation build -- confirm).
# The names defined here are therefore not available at normal runtime and
# must only be used inside annotations.
if TYPE_CHECKING or bool(os.getenv("SPHINX_BUILD", "")):
    import sys

    if sys.version_info >= (3, 10):
        from typing import TypeAlias
    else:
        # Python 3.9 and lower
        from typing_extensions import TypeAlias

    from .algorithms import has_crypto
    from .api_jwk import PyJWK
    from .types import FullOptions, Options, SigOptions

    # The key aliases only include the algorithm-specific key types when
    # ``has_crypto`` is true; otherwise they fall back to PyJWK/str/bytes.
    if has_crypto:
        from .algorithms import AllowedPrivateKeys, AllowedPublicKeys

        AllowedPrivateKeyTypes: TypeAlias = Union[AllowedPrivateKeys, PyJWK, str, bytes]
        AllowedPublicKeyTypes: TypeAlias = Union[AllowedPublicKeys, PyJWK, str, bytes]
    else:
        AllowedPrivateKeyTypes: TypeAlias = Union[PyJWK, str, bytes]  # type: ignore
        AllowedPublicKeyTypes: TypeAlias = Union[PyJWK, str, bytes]  # type: ignore


class PyJWT:
    def __init__(self, options: Options | None = None) -> None:
        """Create a JWT encoder/decoder with its own set of options.

        :param options: optional overrides layered on top of the defaults
            returned by ``_get_default_options``
        :type options: jwt.types.Options or None
        """
        # The defaults have to be stored first: ``_merge_options`` uses
        # ``self.options`` as the base that the overrides are applied to.
        defaults = self._get_default_options()
        self.options: FullOptions = defaults
        if options is not None:
            self.options = self._merge_options(options)

        # The JWS layer only receives the signature-related options.
        sig_options = self._get_sig_options()
        self._jws = PyJWS(options=sig_options)

    @staticmethod
    def _get_default_options() -> FullOptions:
        return {
            "verify_signature": True,
            "verify_exp": True,
            "verify_nbf": True,
            "verify_iat": True,
            "verify_aud": True,
            "verify_iss": True,
            "verify_sub": True,
            "verify_jti": True,
            "require": [],
            "strict_aud": False,
            "enforce_minimum_key_length": False,
        }

    def _get_sig_options(self) -> SigOptions:
        return {
            "verify_signature": self.options["verify_signature"],
            "enforce_minimum_key_length": self.options.get(
                "enforce_minimum_key_length", False
            ),
        }

    def _merge_options(self, options: Options | None = None) -> FullOptions:
        if options is None:
            return self.options

        # (defensive) set defaults for verify_x to False if verify_signature is False
        if not options.get("verify_signature", True):
            options["verify_exp"] = options.get("verify_exp", False)
            options["verify_nbf"] = options.get("verify_nbf", False)
            options["verify_iat"] = options.get("verify_iat", False)
            options["verify_aud"] = options.get("verify_aud", False)
            options["verify_iss"] = options.get("verify_iss", False)
            options["verify_sub"] = options.get("verify_sub", False)
            options["verify_jti"] = options.get("verify_jti", False)
        return {**self.options, **options}

    def encode(
        self,
        payload: dict[str, Any],
        key: AllowedPrivateKeyTypes,
        algorithm: str | None = "HS256",
        headers: dict[str, Any] | None = None,
        json_encoder: type[json.JSONEncoder] | None = None,
        sort_headers: bool = True,
    ) -> str:
        """Encode the ``payload`` as JSON Web Token.

        :param payload: JWT claims, e.g. ``dict(iss=..., aud=..., sub=...)``
        :type payload: dict[str, typing.Any]
        :param key: a key suitable for the chosen algorithm:

            * for **asymmetric algorithms**: PEM-formatted private key, a multiline string
            * for **symmetric algorithms**: plain string, sufficiently long for security

        :type key: str or bytes or PyJWK or :py:class:`jwt.algorithms.AllowedPrivateKeys`
        :param algorithm: algorithm to sign the token with, e.g. ``"ES256"``.
            If ``headers`` includes ``alg``, it will be preferred to this parameter.
            If ``key`` is a :class:`PyJWK` object, by default the key algorithm will be used.
        :type algorithm: str or None
        :param headers: additional JWT header fields, e.g. ``dict(kid="my-key-id")``.
        :type headers: dict[str, typing.Any] or None
        :param json_encoder: custom JSON encoder for ``payload`` and ``headers``
        :type json_encoder: json.JSONEncoder or None

        :rtype: str
        :returns: a JSON Web Token

        :raises TypeError: if ``payload`` is not a ``dict``
        """
        # Check that we get a dict
        if not isinstance(payload, dict):
            raise TypeError(
                "Expecting a dict object, as JWT only supports "
                "JSON objects as payloads."
            )

        # Payload
        payload = payload.copy()
        for time_claim in ["exp", "iat", "nbf"]:
            # Convert datetime to a intDate value in known time-format claims
            if isinstance(payload.get(time_claim), datetime):
                payload[time_claim] = timegm(payload[time_claim].utctimetuple())

        # Issue #1039, iss being set to non-string
        if "iss" in payload and not isinstance(payload["iss"], str):
            raise TypeError("Issuer (iss) must be a string.")

        json_payload = self._encode_payload(
            payload,
            headers=headers,
            json_encoder=json_encoder,
        )

        return self._jws.encode(
            json_payload,
            key,
            algorithm,
            headers,
            json_encoder,
            sort_headers=sort_headers,
        )

    def _encode_payload(
        self,
        payload: dict[str, Any],
        headers: dict[str, Any] | None = None,
        json_encoder: type[json.JSONEncoder] | None = None,
    ) -> bytes:
        """
        Encode a given payload to the bytes to be signed.

        This method is intended to be overridden by subclasses that need to
        encode the payload in a different way, e.g. compress the payload.
        """
        return json.dumps(
            payload,
            separators=(",", ":"),
            cls=json_encoder,
        ).encode("utf-8")

    def decode_complete(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeyTypes = "",
        algorithms: Sequence[str] | None = None,
        options: Options | None = None,
        # deprecated arg, remove in pyjwt3
        verify: bool | None = None,
        # could be used as passthrough to api_jws, consider removal in pyjwt3
        detached_payload: bytes | None = None,
        # passthrough arguments to _validate_claims
        # consider putting in options
        audience: str | Iterable[str] | None = None,
        issuer: str | Container[str] | None = None,
        subject: str | None = None,
        leeway: float | timedelta = 0,
        # kwargs
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Identical to ``jwt.decode`` except for return value which is a dictionary containing the token header (JOSE Header),
        the token payload (JWT Payload), and token signature (JWT Signature) on the keys "header", "payload",
        and "signature" respectively.

        :param jwt: the token to be decoded
        :type jwt: str or bytes
        :param key: the key suitable for the allowed algorithm
        :type key: str or bytes or PyJWK or :py:class:`jwt.algorithms.AllowedPublicKeys`

        :param algorithms: allowed algorithms, e.g. ``["ES256"]``

            .. warning::

               Do **not** compute the ``algorithms`` parameter based on
               the ``alg`` from the token itself, or on any other data
               that an attacker may be able to influence, as that might
               expose you to various vulnerabilities (see `RFC 8725 §2.1
               <https://www.rfc-editor.org/rfc/rfc8725.html#section-2.1>`_). Instead,
               either hard-code a fixed value for ``algorithms``, or
               configure it in the same place you configure the
               ``key``. Make sure not to mix symmetric and asymmetric
               algorithms that interpret the ``key`` in different ways
               (e.g. HS\\* and RS\\*).
        :type algorithms: typing.Sequence[str] or None

        :param jwt.types.Options options: extended decoding and validation options
            Refer to :py:class:`jwt.types.Options` for more information.

        :param audience: optional, the value for ``verify_aud`` check
        :type audience: str or typing.Iterable[str] or None
        :param issuer: optional, the value for ``verify_iss`` check
        :type issuer: str or typing.Container[str] or None
        :param leeway: a time margin in seconds for the expiration check
        :type leeway: float or datetime.timedelta
        :rtype: dict[str, typing.Any]
        :returns: Decoded JWT with the JOSE Header on the key ``header``, the JWS
         Payload on the key ``payload``, and the JWS Signature on the key ``signature``.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode_complete() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                stacklevel=2,
            )

        if options is None:
            verify_signature = True
        else:
            verify_signature = options.get("verify_signature", True)

        # If the user has set the legacy `verify` argument, and it doesn't match
        # what the relevant `options` entry for the argument is, inform the user
        # that they're likely making a mistake.
        if verify is not None and verify != verify_signature:
            warnings.warn(
                "The `verify` argument to `decode` does nothing in PyJWT 2.0 and newer. "
                "The equivalent is setting `verify_signature` to False in the `options` dictionary. "
                "This invocation has a mismatch between the kwarg and the option entry.",
                category=DeprecationWarning,
                stacklevel=2,
            )

        merged_options = self._merge_options(options)

        sig_options: SigOptions = {
            "verify_signature": verify_signature,
        }
        decoded = self._jws.decode_complete(
            jwt,
            key=key,
            algorithms=algorithms,
            options=sig_options,
            detached_payload=detached_payload,
        )

        payload = self._decode_payload(decoded)

        self._validate_claims(
            payload,
            merged_options,
            audience=audience,
            issuer=issuer,
            leeway=leeway,
            subject=subject,
        )

        decoded["payload"] = payload
        return decoded

    def _decode_payload(self, decoded: dict[str, Any]) -> dict[str, Any]:
        """
        Decode the payload from a JWS dictionary (payload, signature, header).

        This method is intended to be overridden by subclasses that need to
        decode the payload in a different way, e.g. decompress compressed
        payloads.
        """
        try:
            payload: dict[str, Any] = json.loads(decoded["payload"])
        except ValueError as e:
            raise DecodeError(f"Invalid payload string: {e}") from e
        if not isinstance(payload, dict):
            raise DecodeError("Invalid payload string: must be a json object")
        return payload

    def decode(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
        options: Options | None = None,
        # deprecated arg, remove in pyjwt3
        verify: bool | None = None,
        # could be used as passthrough to api_jws, consider removal in pyjwt3
        detached_payload: bytes | None = None,
        # passthrough arguments to _validate_claims
        # consider putting in options
        audience: str | Iterable[str] | None = None,
        subject: str | None = None,
        issuer: str | Container[str] | None = None,
        leeway: float | timedelta = 0,
        # kwargs
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Verify the ``jwt`` token signature and return the token claims.

        :param jwt: the token to be decoded
        :type jwt: str or bytes
        :param key: the key suitable for the allowed algorithm
        :type key: str or bytes or PyJWK or :py:class:`jwt.algorithms.AllowedPublicKeys`

        :param algorithms: allowed algorithms, e.g. ``["ES256"]``
            If ``key`` is a :class:`PyJWK` object, allowed algorithms will default to the key algorithm.

            .. warning::

               Do **not** compute the ``algorithms`` parameter based on
               the ``alg`` from the token itself, or on any other data
               that an attacker may be able to influence, as that might
               expose you to various vulnerabilities (see `RFC 8725 §2.1
               <https://www.rfc-editor.org/rfc/rfc8725.html#section-2.1>`_). Instead,
               either hard-code a fixed value for ``algorithms``, or
               configure it in the same place you configure the
               ``key``. Make sure not to mix symmetric and asymmetric
               algorithms that interpret the ``key`` in different ways
               (e.g. HS\\* and RS\\*).
        :type algorithms: typing.Sequence[str] or None

        :param jwt.types.Options options: extended decoding and validation options
            Refer to :py:class:`jwt.types.Options` for more information.

        :param audience: optional, the value for ``verify_aud`` check
        :type audience: str or typing.Iterable[str] or None
        :param subject: optional, the value for ``verify_sub`` check
        :type subject: str or None
        :param issuer: optional, the value for ``verify_iss`` check
        :type issuer: str or typing.Container[str] or None
        :param leeway: a time margin in seconds for the expiration check
        :type leeway: float or datetime.timedelta
        :rtype: dict[str, typing.Any]
        :returns: the JWT claims
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                stacklevel=2,
            )
        decoded = self.decode_complete(
            jwt,
            key,
            algorithms,
            options,
            verify=verify,
            detached_payload=detached_payload,
            audience=audience,
            subject=subject,
            issuer=issuer,
            leeway=leeway,
        )
        return cast(dict[str, Any], decoded["payload"])

    def _validate_claims(
        self,
        payload: dict[str, Any],
        options: FullOptions,
        audience: Iterable[str] | str | None = None,
        issuer: Container[str] | str | None = None,
        subject: str | None = None,
        leeway: float | timedelta = 0,
    ) -> None:
        if isinstance(leeway, timedelta):
            leeway = leeway.total_seconds()

        if audience is not None and not isinstance(audience, (str, Iterable)):
            raise TypeError("audience must be a string, iterable or None")

        self._validate_required_claims(payload, options["require"])

        now = datetime.now(tz=timezone.utc).timestamp()

        if "iat" in payload and options["verify_iat"]:
            self._validate_iat(payload, now, leeway)

        if "nbf" in payload and options["verify_nbf"]:
            self._validate_nbf(payload, now, leeway)

        if "exp" in payload and options["verify_exp"]:
            self._validate_exp(payload, now, leeway)

        if options["verify_iss"]:
            self._validate_iss(payload, issuer)

        if options["verify_aud"]:
            self._validate_aud(
                payload, audience, strict=options.get("strict_aud", False)
            )

        if options["verify_sub"]:
            self._validate_sub(payload, subject)

        if options["verify_jti"]:
            self._validate_jti(payload)

    def _validate_required_claims(
        self,
        payload: dict[str, Any],
        claims: Iterable[str],
    ) -> None:
        for claim in claims:
            if payload.get(claim) is None:
                raise MissingRequiredClaimError(claim)

    def _validate_sub(
        self, payload: dict[str, Any], subject: str | None = None
    ) -> None:
        """
        Checks whether "sub" if in the payload is valid or not.
        This is an Optional claim

        :param payload(dict): The payload which needs to be validated
        :param subject(str): The subject of the token
        """

        if "sub" not in payload:
            return

        if not isinstance(payload["sub"], str):
            raise InvalidSubjectError("Subject must be a string")

        if subject is not None:
            if payload.get("sub") != subject:
                raise InvalidSubjectError("Invalid subject")

    def _validate_jti(self, payload: dict[str, Any]) -> None:
        """
        Checks whether "jti" if in the payload is valid or not
        This is an Optional claim

        :param payload(dict): The payload which needs to be validated
        """

        if "jti" not in payload:
            return

        if not isinstance(payload.get("jti"), str):
            raise InvalidJTIError("JWT ID must be a string")

    def _validate_iat(
        self,
        payload: dict[str, Any],
        now: float,
        leeway: float,
    ) -> None:
        try:
            iat = int(payload["iat"])
        except ValueError:
            raise InvalidIssuedAtError(
                "Issued At claim (iat) must be an integer."
            ) from None
        if iat > (now + leeway):
            raise ImmatureSignatureError("The token is not yet valid (iat)")

    def _validate_nbf(
        self,
        payload: dict[str, Any],
        now: float,
        leeway: float,
    ) -> None:
        try:
            nbf = int(payload["nbf"])
        except ValueError:
            raise DecodeError("Not Before claim (nbf) must be an integer.") from None

        if nbf > (now + leeway):
            raise ImmatureSignatureError("The token is not yet valid (nbf)")

    def _validate_exp(
        self,
        payload: dict[str, Any],
        now: float,
        leeway: float,
    ) -> None:
        try:
            exp = int(payload["exp"])
        except ValueError:
            raise DecodeError(
                "Expiration Time claim (exp) must be an integer."
            ) from None

        if exp <= (now - leeway):
            raise ExpiredSignatureError("Signature has expired")

    def _validate_aud(
        self,
        payload: dict[str, Any],
        audience: str | Iterable[str] | None,
        *,
        strict: bool = False,
    ) -> None:
        if audience is None:
            if "aud" not in payload or not payload["aud"]:
                return
            # Application did not specify an audience, but
            # the token has the 'aud' claim
            raise InvalidAudienceError("Invalid audience")

        if "aud" not in payload or not payload["aud"]:
            # Application specified an audience, but it could not be
            # verified since the token does not contain a claim.
            raise MissingRequiredClaimError("aud")

        audience_claims = payload["aud"]

        # In strict mode, we forbid list matching: the supplied audience
        # must be a string, and it must exactly match the audience claim.
        if strict:
            # Only a single audience is allowed in strict mode.
            if not isinstance(audience, str):
                raise InvalidAudienceError("Invalid audience (strict)")

            # Only a single audience claim is allowed in strict mode.
            if not isinstance(audience_claims, str):
                raise InvalidAudienceError("Invalid claim format in token (strict)")

            if audience != audience_claims:
                raise InvalidAudienceError("Audience doesn't match (strict)")

            return

        if isinstance(audience_claims, str):
            audience_claims = [audience_claims]
        if not isinstance(audience_claims, list):
            raise InvalidAudienceError("Invalid claim format in token")
        if any(not isinstance(c, str) for c in audience_claims):
            raise InvalidAudienceError("Invalid claim format in token")

        if isinstance(audience, str):
            audience = [audience]

        if all(aud not in audience_claims for aud in audience):
            raise InvalidAudienceError("Audience doesn't match")

    def _validate_iss(
        self, payload: dict[str, Any], issuer: Container[str] | str | None
    ) -> None:
        if issuer is None:
            return

        if "iss" not in payload:
            raise MissingRequiredClaimError("iss")

        iss = payload["iss"]
        if not isinstance(iss, str):
            raise InvalidIssuerError("Payload Issuer (iss) must be a string")

        if isinstance(issuer, str):
            if iss != issuer:
                raise InvalidIssuerError("Invalid issuer")
        else:
            try:
                if iss not in issuer:
                    raise InvalidIssuerError("Invalid issuer")
            except TypeError:
                raise InvalidIssuerError(
                    'Issuer param must be "str" or "Container[str]"'
                ) from None


# Module-level convenience API: one shared PyJWT instance (default options)
# whose bound methods are exported as ``encode``, ``decode_complete`` and
# ``decode``.
_jwt_global_obj = PyJWT()
# Replace the PyJWS instance created in ``PyJWT.__init__`` with the
# module-level one imported from ``.api_jws``, so both modules use the same
# JWS object.
_jwt_global_obj._jws = _jws_global_obj
encode = _jwt_global_obj.encode
decode_complete = _jwt_global_obj.decode_complete
decode = _jwt_global_obj.decode