File: authenticator_common.py

package info (click to toggle)
python-yalexs 9.2.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,120 kB
  • sloc: python: 7,916; makefile: 3; sh: 2
file content (172 lines) | stat: -rw-r--r-- 5,467 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
from __future__ import annotations

import json
import logging
import uuid
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any

import jwt

from .api_common import ApiCommon
from .const import HEADER_ACCESS_TOKEN, HEADER_AUGUST_ACCESS_TOKEN
from .time import parse_datetime

_LOGGER = logging.getLogger(__name__)

# How long before a token's expiration it becomes eligible for renewal (one week).
DEFAULT_RENEWAL_THRESHOLD = timedelta(weeks=1)


def to_authentication_json(authentication):
    """Serialize an ``Authentication`` to a JSON string.

    ``None`` serializes to an empty JSON object (``"{}"``). Otherwise the
    object holds, in this key order, the install id, the access token, the
    raw (unparsed) expiration string and the ``value`` of the state enum.
    """
    payload = {}
    if authentication is not None:
        payload["install_id"] = authentication.install_id
        payload["access_token"] = authentication.access_token
        payload["access_token_expires"] = authentication.access_token_expires
        payload["state"] = authentication.state.value
    return json.dumps(payload)


def from_authentication_json(data):
    """Rebuild an ``Authentication`` from a decoded JSON mapping.

    Returns ``None`` when *data* is ``None``. Otherwise *data* must contain
    ``install_id``, ``access_token``, ``access_token_expires`` and ``state``
    (a missing key raises ``KeyError``). ``state`` is converted through
    ``AuthenticationState``.
    """
    if data is not None:
        # Keys are read in the same order as always, so the first missing
        # key is the one reported in the KeyError.
        install_id, access_token, access_token_expires = (
            data[key]
            for key in ("install_id", "access_token", "access_token_expires")
        )
        return Authentication(
            AuthenticationState(data["state"]),
            install_id,
            access_token,
            access_token_expires,
        )
    return None


class Authentication:
    """Authentication record: state, install id, access token and expiry.

    The expiration is kept verbatim (it is what gets serialized). A parsed
    copy is cached for expiry comparisons.
    """

    def __init__(
        self, state, install_id=None, access_token=None, access_token_expires=None
    ):
        """Create an authentication record.

        :param state: the current ``AuthenticationState``.
        :param install_id: install identifier; a random UUID4 string is
            generated when it is ``None``.
        :param access_token: the access token, if any.
        :param access_token_expires: expiration timestamp string as received;
            parsed with ``parse_datetime`` only when truthy.
        """
        self._state = state
        self._install_id = str(uuid.uuid4()) if install_id is None else install_id
        self._access_token = access_token
        self._access_token_expires = access_token_expires
        self._parsed_expiration_time = None
        if access_token_expires:
            self._parsed_expiration_time = parse_datetime(access_token_expires)

    @property
    def install_id(self):
        """The install identifier (generated if none was supplied)."""
        return self._install_id

    @property
    def access_token(self):
        """The access token, or ``None``."""
        return self._access_token

    @property
    def access_token_expires(self):
        """The raw, unparsed expiration string, or ``None``."""
        return self._access_token_expires

    @property
    def state(self):
        """The current ``AuthenticationState``."""
        return self._state

    @state.setter
    def state(self, value):
        self._state = value

    def parsed_expiration_time(self):
        """Return the parsed expiration datetime, or ``None`` if none was given."""
        return self._parsed_expiration_time

    def is_expired(self):
        """Return True if the access token is past its expiration time.

        A record without an expiration time (e.g. one created before any
        login) is treated as expired. Previously this case raised
        ``TypeError`` by comparing ``None`` with a datetime.
        """
        expiration = self._parsed_expiration_time
        if expiration is None:
            return True
        # NOTE(review): assumes parse_datetime yields a timezone-aware value;
        # a naive datetime would not compare with an aware "now".
        return expiration < datetime.now(timezone.utc)


class AuthenticationState(Enum):
    """Status of an ``Authentication`` record.

    The string values are what the JSON cache stores: serialization writes
    ``state.value`` and deserialization calls ``AuthenticationState(...)``.
    """

    REQUIRES_AUTHENTICATION = "requires_authentication"
    # Set when the session response has a falsy ``vInstallId``.
    REQUIRES_VALIDATION = "requires_validation"
    # Set when both ``vPassword`` and ``vInstallId`` are truthy.
    AUTHENTICATED = "authenticated"
    # Set when the session response has a falsy ``vPassword``.
    # NOTE(review): "# nosec" presumably silences a security linter's
    # hardcoded-password warning on this literal.
    BAD_PASSWORD = "bad_password"  # nosec


class ValidationResult(Enum):
    """Result of a verification step.

    NOTE(review): not used anywhere in this module. From the member names it
    is presumably the outcome of submitting a verification code — confirm
    against the authenticator subclasses.
    """

    VALIDATED = "validated"
    INVALID_VERIFICATION_CODE = "invalid_verification_code"


class AuthenticatorCommon:
    """Authentication bookkeeping shared by concrete authenticators.

    Holds credentials and settings, and builds ``Authentication`` records
    from session responses and refreshed tokens.

    NOTE(review): no network I/O happens here; presumably subclasses perform
    the actual API calls and cache-file handling.
    """

    def __init__(
        self,
        api: ApiCommon,
        login_method: str | None,
        username: str | None,
        password: str | None,
        install_id: str | None = None,
        access_token_cache_file: str | None = None,
        access_token_renewal_threshold: timedelta = DEFAULT_RENEWAL_THRESHOLD,
    ) -> None:
        """Store the arguments; no authentication record exists yet.

        :param access_token_renewal_threshold: how long before expiry an
            authenticated token is considered due for refresh.
        """
        self._api = api
        self._login_method = login_method
        self._username = username
        self._password = password
        self._install_id = install_id
        self._access_token_cache_file = access_token_cache_file
        self._access_token_renewal_threshold = access_token_renewal_threshold
        self._authentication = None

    def _authentication_from_session_response(
        self,
        install_id: str,
        response_headers: dict[str, Any],
        json_dict: dict[str, Any],
    ) -> Authentication:
        """Build, store and return an ``Authentication`` from a session response.

        The token is taken from the ``HEADER_ACCESS_TOKEN`` header, falling
        back to ``HEADER_AUGUST_ACCESS_TOKEN`` when the former is missing or
        empty (``KeyError`` if the fallback is absent). ``json_dict`` must
        provide ``expiresAt``, ``vPassword`` and ``vInstallId``.
        """
        access_token = (
            response_headers.get(HEADER_ACCESS_TOKEN)
            or response_headers[HEADER_AUGUST_ACCESS_TOKEN]
        )
        access_token_expires = json_dict["expiresAt"]
        v_password = json_dict["vPassword"]
        v_install_id = json_dict["vInstallId"]

        # A failed password check takes precedence over an unverified install.
        if not v_password:
            state = AuthenticationState.BAD_PASSWORD
        elif not v_install_id:
            state = AuthenticationState.REQUIRES_VALIDATION
        else:
            state = AuthenticationState.AUTHENTICATED

        self._authentication = Authentication(
            state, install_id, access_token, access_token_expires
        )

        return self._authentication

    def should_refresh(self) -> bool:
        """Return True if the current token is due for renewal.

        This is true only for an ``AUTHENTICATED`` record whose expiration
        is closer than the renewal threshold. Returns False when no
        authentication record has been set yet.
        """
        authentication = self._authentication
        if authentication is None:
            return False
        return authentication.state == AuthenticationState.AUTHENTICATED and (
            (authentication.parsed_expiration_time() - datetime.now(timezone.utc))
            < self._access_token_renewal_threshold
        )

    def _process_refreshed_access_token(self, refreshed_token) -> Authentication:
        """Replace the stored authentication with one carrying *refreshed_token*.

        The JWT is decoded WITHOUT signature verification, only to read its
        ``exp`` claim. If the claim is missing, a warning is logged and the
        current authentication is returned unchanged. The state and install
        id are carried over from the current authentication.
        """
        jwt_claims = jwt.decode(refreshed_token, options={"verify_signature": False})

        if "exp" not in jwt_claims:
            _LOGGER.warning("Did not find expected `exp' claim in JWT")
            return self._authentication

        # Aware UTC datetime. This replaces datetime.utcfromtimestamp, which
        # is deprecated since Python 3.12. The format below has no %z, so the
        # formatted string is identical.
        new_expiration = datetime.fromtimestamp(jwt_claims["exp"], tz=timezone.utc)
        # The yale access api always returns expiresAt in the format
        # '%Y-%m-%dT%H:%M:%S.%fZ'
        # from the get_session api call
        # It is important we store access_token_expires formatted
        # the same way for compatibility
        self._authentication = Authentication(
            self._authentication.state,
            install_id=self._authentication.install_id,
            access_token=refreshed_token,
            access_token_expires=new_expiration.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
        )

        _LOGGER.info("Successfully refreshed access token")
        return self._authentication