File: token.py

package info (click to toggle)
python-redis 6.4.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 9,432 kB
  • sloc: python: 60,318; sh: 179; makefile: 128
file content (130 lines) | stat: -rw-r--r-- 3,317 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from abc import ABC, abstractmethod
from datetime import datetime, timezone

from redis.auth.err import InvalidTokenSchemaErr


class TokenInterface(ABC):
    """Abstract contract every authentication token type must fulfil.

    Subclasses cannot be instantiated until all six methods below are
    implemented.
    """

    @abstractmethod
    def is_expired(self) -> bool:
        """Report whether the token is no longer valid."""

    @abstractmethod
    def ttl(self) -> float:
        """Return how much lifetime the token has left."""

    @abstractmethod
    def try_get(self, key: str) -> str:
        """Look up an auxiliary attribute (claim) of the token by *key*."""

    @abstractmethod
    def get_value(self) -> str:
        """Return the raw token string."""

    @abstractmethod
    def get_expires_at_ms(self) -> float:
        """Return the expiry timestamp of the token, in milliseconds."""

    @abstractmethod
    def get_received_at_ms(self) -> float:
        """Return the timestamp at which the token was received, in milliseconds."""


class TokenResponse:
    """Thin holder for a token, exposing the token's total lifetime."""

    def __init__(self, token: TokenInterface):
        self._token = token

    def get_token(self) -> TokenInterface:
        """Return the wrapped token unchanged."""
        return self._token

    def get_ttl_ms(self) -> float:
        """Return the token's lifetime: its expiry timestamp minus its receipt timestamp.

        Both values are read from the wrapped token (expiry first, then
        receipt), so the unit is whatever those getters report.
        """
        wrapped = self._token
        expires_at_ms = wrapped.get_expires_at_ms()
        received_at_ms = wrapped.get_received_at_ms()
        return expires_at_ms - received_at_ms


class SimpleToken(TokenInterface):
    """Token whose value, timestamps and claims are all supplied by the caller.

    :param value: raw token string, returned verbatim by ``get_value``.
    :param expires_at_ms: expiry time in milliseconds since the Unix epoch,
        or ``-1`` meaning the token never expires.
    :param received_at_ms: time the token was obtained, in milliseconds.
    :param claims: mapping consulted by ``try_get``.
    """

    def __init__(
        self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict
    ) -> None:
        self.value = value
        self.expires_at = expires_at_ms
        self.received_at = received_at_ms
        self.claims = claims

    def ttl(self) -> float:
        """Return milliseconds left before expiry (negative once past), or -1 if it never expires."""
        if self.expires_at != -1:
            now_ms = datetime.now(timezone.utc).timestamp() * 1000
            return self.expires_at - now_ms
        return -1

    def is_expired(self) -> bool:
        """Return True once the expiry time has been reached; never for the -1 sentinel."""
        # The sentinel must be excluded up front: ttl() reports it as -1,
        # which would otherwise look like an elapsed lifetime.
        return self.expires_at != -1 and self.ttl() <= 0

    def try_get(self, key: str) -> str:
        """Return the claim stored under *key*, or None when it is absent."""
        claims = self.claims
        return claims.get(key)

    def get_value(self) -> str:
        """Return the raw token string."""
        return self.value

    def get_expires_at_ms(self) -> float:
        """Return the expiry timestamp exactly as supplied (may be -1)."""
        return self.expires_at

    def get_received_at_ms(self) -> float:
        """Return the receipt timestamp exactly as supplied."""
        return self.received_at


class JWToken(TokenInterface):
    """Token wrapping an encoded JSON Web Token string.

    The payload is decoded with signature verification disabled and is only
    inspected for its claims; the ``exp`` claim (seconds since the epoch, or
    ``-1`` for "never expires") is mandatory.

    :param token: the encoded JWT.
    :raises ImportError: if the PyJWT library is not installed.
    :raises InvalidTokenSchemaErr: if a field in ``REQUIRED_FIELDS`` is missing.
    """

    REQUIRED_FIELDS = {"exp"}

    def __init__(self, token: str):
        try:
            import jwt
        except ImportError as ie:
            raise ImportError(
                f"The PyJWT library is required for {self.__class__.__name__}.",
            ) from ie
        # Record the receipt time once, so that get_received_at_ms() is a
        # stable property of this token rather than "now" on every call
        # (a moving value would make TokenResponse.get_ttl_ms() drift).
        self._received_at_ms = datetime.now(timezone.utc).timestamp() * 1000
        self._value = token
        # NOTE(review): the signature is not verified here, so the claims
        # are not authenticated by this class.
        self._decoded = jwt.decode(
            self._value,
            options={"verify_signature": False},
            algorithms=[jwt.get_unverified_header(self._value).get("alg")],
        )
        self._validate_token()

    def is_expired(self) -> bool:
        """Return True once ``exp`` has passed; always False when ``exp == -1``."""
        if self._decoded["exp"] == -1:
            return False
        return self.ttl() <= 0

    def ttl(self) -> float:
        """Return milliseconds until ``exp`` (negative once past), or -1 if it never expires."""
        exp = self._decoded["exp"]
        if exp == -1:
            return -1
        # ``exp`` is in seconds; convert both sides to milliseconds.
        return exp * 1000 - datetime.now(timezone.utc).timestamp() * 1000

    def try_get(self, key: str) -> str:
        """Return the decoded claim under *key*, or None when it is absent."""
        return self._decoded.get(key)

    def get_value(self) -> str:
        """Return the original encoded token string."""
        return self._value

    def get_expires_at_ms(self) -> float:
        """Return the ``exp`` claim converted to milliseconds.

        NOTE(review): for the ``exp == -1`` sentinel this yields -1000.0,
        not -1; callers comparing against -1 should be aware.
        """
        return float(self._decoded["exp"] * 1000)

    def get_received_at_ms(self) -> float:
        """Return the time (ms since the epoch) at which this token was constructed."""
        return self._received_at_ms

    def _validate_token(self):
        """Raise InvalidTokenSchemaErr listing any required claims absent from the payload."""
        missing = self.REQUIRED_FIELDS - set(self._decoded)
        if missing:
            raise InvalidTokenSchemaErr(missing)