1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
|
from __future__ import annotations
import json
import logging
import uuid
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any
import jwt
from .api_common import ApiCommon
from .const import HEADER_ACCESS_TOKEN, HEADER_AUGUST_ACCESS_TOKEN
from .time import parse_datetime
_LOGGER = logging.getLogger(__name__)

# How long before a token expires it becomes eligible for a refresh
# (used as the default for AuthenticatorCommon's renewal threshold).
DEFAULT_RENEWAL_THRESHOLD = timedelta(weeks=1)
def to_authentication_json(authentication):
    """Serialize *authentication* to a JSON string.

    The payload holds the install id, the access token, the token's
    expiration string and the enum ``value`` of its state, in that key
    order. Passing ``None`` yields an empty JSON object (``"{}"``).
    """
    payload = {}
    if authentication is not None:
        payload = {
            "install_id": authentication.install_id,
            "access_token": authentication.access_token,
            "access_token_expires": authentication.access_token_expires,
            "state": authentication.state.value,
        }
    return json.dumps(payload)
def from_authentication_json(data):
    """Rebuild an ``Authentication`` from a decoded JSON mapping.

    This is the inverse of ``to_authentication_json``. Both ``None`` and an
    empty mapping return ``None``; the empty case matters because
    ``to_authentication_json(None)`` writes ``"{}"``, and decoding that must
    not fail with ``KeyError``.

    Args:
        data: Mapping with ``install_id``, ``access_token``,
            ``access_token_expires`` and ``state`` keys, or ``None``/empty.

    Returns:
        An ``Authentication`` instance, or ``None`` for missing/empty data.

    Raises:
        KeyError: if a non-empty mapping lacks one of the required keys.
        ValueError: if ``state`` is not a valid ``AuthenticationState`` value.
    """
    if not data:
        return None
    # Keys are read in the same order as before so the KeyError raised for a
    # partially filled mapping names the same missing key.
    install_id = data["install_id"]
    access_token = data["access_token"]
    access_token_expires = data["access_token_expires"]
    state = AuthenticationState(data["state"])
    return Authentication(state, install_id, access_token, access_token_expires)
class Authentication:
    """Record of one authentication session: state, install id and token."""

    def __init__(
        self,
        state: AuthenticationState,
        install_id: str | None = None,
        access_token: str | None = None,
        access_token_expires: str | None = None,
    ) -> None:
        """Create an authentication record.

        Args:
            state: Current ``AuthenticationState`` of the session.
            install_id: Client install identifier; a random UUID4 string is
                generated when omitted.
            access_token: The access token, if one has been issued.
            access_token_expires: Expiration timestamp string of the token;
                when truthy it is parsed once with ``parse_datetime``.
        """
        self._state = state
        self._install_id = str(uuid.uuid4()) if install_id is None else install_id
        self._access_token = access_token
        self._access_token_expires = access_token_expires
        # Parsed once here so expiry checks don't re-parse the raw string.
        self._parsed_expiration_time = (
            parse_datetime(access_token_expires) if access_token_expires else None
        )

    @property
    def install_id(self) -> str:
        """Client install identifier (generated when not supplied)."""
        return self._install_id

    @property
    def access_token(self) -> str | None:
        """The access token, or ``None`` if none was supplied."""
        return self._access_token

    @property
    def access_token_expires(self) -> str | None:
        """The raw (unparsed) expiration string, as supplied."""
        return self._access_token_expires

    @property
    def state(self) -> AuthenticationState:
        """Current authentication state; writable."""
        return self._state

    @state.setter
    def state(self, value: AuthenticationState) -> None:
        self._state = value

    def parsed_expiration_time(self) -> datetime | None:
        """Return the parsed expiration time, or ``None`` if it is unknown."""
        return self._parsed_expiration_time

    def is_expired(self) -> bool:
        """Return True if the token's expiration time has passed.

        A record with no known expiration time is treated as expired: nothing
        shows its token is still valid, and comparing ``None`` against a
        datetime would otherwise raise ``TypeError``.
        """
        if self._parsed_expiration_time is None:
            return True
        # NOTE(review): assumes parse_datetime returns a timezone-aware
        # datetime; comparing a naive one with an aware one raises TypeError.
        return self._parsed_expiration_time < datetime.now(timezone.utc)
class AuthenticationState(Enum):
    """State of an ``Authentication`` record.

    The string values are what ``to_authentication_json`` stores under
    ``"state"`` and what ``from_authentication_json`` parses back.
    """

    # Presumably: no usable session; credentials must be submitted
    # (not assigned anywhere in this chunk -- confirm against callers).
    REQUIRES_AUTHENTICATION = "requires_authentication"
    # Session response verified the password but not the install id
    # (falsy "vInstallId").
    REQUIRES_VALIDATION = "requires_validation"
    # Session response verified both the password and the install id.
    AUTHENTICATED = "authenticated"
    # Session response did not verify the password (falsy "vPassword").
    # The nosec marker presumably silences a bandit hardcoded-password
    # false positive on this literal.
    BAD_PASSWORD = "bad_password" # nosec
class ValidationResult(Enum):
    """Result of a validation attempt.

    Not used in this part of the file; presumably the outcome of submitting
    a verification code -- confirm against callers.
    """

    # The submitted code was accepted.
    VALIDATED = "validated"
    # The submitted code was rejected.
    INVALID_VERIFICATION_CODE = "invalid_verification_code"
class AuthenticatorCommon:
    """Shared authenticator state and token bookkeeping.

    Holds the API client, the login credentials and the current
    ``Authentication`` record. It builds that record from session responses
    and refreshed tokens, and decides when a token should be refreshed.
    """

    def __init__(
        self,
        api: ApiCommon,
        login_method: str | None,
        username: str | None,
        password: str | None,
        install_id: str | None = None,
        access_token_cache_file: str | None = None,
        access_token_renewal_threshold: timedelta = DEFAULT_RENEWAL_THRESHOLD,
    ) -> None:
        """Store configuration; no network or file access happens here.

        Args:
            api: API client used for authentication calls.
            login_method: Login method identifier, if any.
            username: Account username, if any.
            password: Account password, if any.
            install_id: Client install identifier, if already known.
            access_token_cache_file: Path of the token cache file, if any.
            access_token_renewal_threshold: How close to expiration a token
                must be before ``should_refresh`` returns True.
        """
        self._api = api
        self._login_method = login_method
        self._username = username
        self._password = password
        self._install_id = install_id
        self._access_token_cache_file = access_token_cache_file
        self._access_token_renewal_threshold = access_token_renewal_threshold
        # Set by _authentication_from_session_response and
        # _process_refreshed_access_token.
        self._authentication: Authentication | None = None

    def _authentication_from_session_response(
        self,
        install_id: str,
        response_headers: dict[str, Any],
        json_dict: dict[str, Any],
    ) -> Authentication:
        """Build, store and return an Authentication from a session response.

        The token is taken from the ``HEADER_ACCESS_TOKEN`` header, falling
        back to ``HEADER_AUGUST_ACCESS_TOKEN``. The body's ``vPassword`` and
        ``vInstallId`` flags decide the state: a falsy ``vPassword`` gives
        BAD_PASSWORD, a falsy ``vInstallId`` gives REQUIRES_VALIDATION, and
        otherwise the state is AUTHENTICATED.

        Raises:
            KeyError: if neither token header is usable, or the body lacks
                ``expiresAt``, ``vPassword`` or ``vInstallId``.
        """
        access_token = (
            response_headers.get(HEADER_ACCESS_TOKEN)
            or response_headers[HEADER_AUGUST_ACCESS_TOKEN]
        )
        access_token_expires = json_dict["expiresAt"]
        v_password = json_dict["vPassword"]
        v_install_id = json_dict["vInstallId"]
        # The password check takes precedence over install-id validation.
        if not v_password:
            state = AuthenticationState.BAD_PASSWORD
        elif not v_install_id:
            state = AuthenticationState.REQUIRES_VALIDATION
        else:
            state = AuthenticationState.AUTHENTICATED
        self._authentication = Authentication(
            state, install_id, access_token, access_token_expires
        )
        return self._authentication

    def should_refresh(self) -> bool:
        """Return True if the token is authenticated and close to expiring.

        "Close" means that less than ``access_token_renewal_threshold``
        remains before expiration; an already expired token also qualifies.
        Returns False when there is no authentication record yet, or when
        its expiration time is unknown (rather than raising).
        """
        authentication = self._authentication
        if (
            authentication is None
            or authentication.state != AuthenticationState.AUTHENTICATED
        ):
            return False
        expiration = authentication.parsed_expiration_time()
        if expiration is None:
            return False
        remaining = expiration - datetime.now(timezone.utc)
        return remaining < self._access_token_renewal_threshold

    def _process_refreshed_access_token(
        self, refreshed_token: str
    ) -> Authentication | None:
        """Replace the stored Authentication with one for *refreshed_token*.

        The new expiration comes from the JWT ``exp`` claim. State and install
        id are carried over from the current record. If the claim is missing,
        a warning is logged and the current record is returned unchanged.
        """
        # The signature is not verified: the claims are only read to learn
        # the token's expiry.
        jwt_claims = jwt.decode(refreshed_token, options={"verify_signature": False})
        if "exp" not in jwt_claims:
            _LOGGER.warning("Did not find expected `exp' claim in JWT")
            return self._authentication
        # Aware UTC datetime; datetime.utcfromtimestamp is deprecated since
        # Python 3.12. The format below has no %z, so the resulting string is
        # identical to the one the naive UTC datetime produced.
        new_expiration = datetime.fromtimestamp(jwt_claims["exp"], tz=timezone.utc)
        # The yale access api always returns expiresAt in the format
        # '%Y-%m-%dT%H:%M:%S.%fZ'
        # from the get_session api call
        # It is important we store access_token_expires formatted
        # the same way for compatibility
        self._authentication = Authentication(
            self._authentication.state,
            install_id=self._authentication.install_id,
            access_token=refreshed_token,
            access_token_expires=new_expiration.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
        )
        _LOGGER.info("Successfully refreshed access token")
        return self._authentication
|