File: token_manager.py

package info (click to toggle)
python-redis 6.4.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 9,432 kB
  • sloc: python: 60,318; sh: 179; makefile: 128
file content (370 lines) | stat: -rw-r--r-- 12,018 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
import asyncio
import inspect
import logging
import threading
from datetime import datetime, timezone
from time import sleep
from typing import Any, Awaitable, Callable, Union

from redis.auth.err import RequestTokenErr, TokenRenewalErr
from redis.auth.idp import IdentityProviderInterface
from redis.auth.token import TokenResponse

logger = logging.getLogger(__name__)


class CredentialsListener:
    """
    Listener notified on events related to credentials.

    Callbacks may be plain functions or coroutine functions. Which kind is
    required depends on whether the token manager runs in synchronous or
    asynchronous mode.

    :param on_next: optional callback that receives each newly acquired token.
    :param on_error: optional callback that receives exceptions raised while
        renewing a token.
    """

    def __init__(
        self,
        on_next: Union[Callable[[Any], Union[None, Awaitable[None]]], None] = None,
        on_error: Union[
            Callable[[Exception], Union[None, Awaitable[None]]], None
        ] = None,
    ):
        # None means "no callback registered".
        self._on_next = on_next
        self._on_error = on_error

    @property
    def on_next(self) -> Union[Callable[[Any], Union[None, Awaitable[None]]], None]:
        """Callback invoked with a newly acquired token, or None."""
        return self._on_next

    @on_next.setter
    def on_next(
        self, callback: Union[Callable[[Any], Union[None, Awaitable[None]]], None]
    ) -> None:
        self._on_next = callback

    @property
    def on_error(
        self,
    ) -> Union[Callable[[Exception], Union[None, Awaitable[None]]], None]:
        """Callback invoked with a renewal exception, or None."""
        return self._on_error

    @on_error.setter
    def on_error(
        self,
        callback: Union[Callable[[Exception], Union[None, Awaitable[None]]], None],
    ) -> None:
        self._on_error = callback


class RetryPolicy:
    """
    Retry settings for token requests.

    :param max_attempts: maximum number of retries made after a failed token
        request before the error is re-raised.
    :param delay_in_ms: pause between retries, in milliseconds.
    """

    def __init__(self, max_attempts: int, delay_in_ms: float):
        self.max_attempts = max_attempts
        self.delay_in_ms = delay_in_ms

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(max_attempts={self.max_attempts!r}, "
            f"delay_in_ms={self.delay_in_ms!r})"
        )

    def get_max_attempts(self) -> int:
        """
        Retry attempts before exception will be thrown.

        :return: int
        """
        return self.max_attempts

    def get_delay_in_ms(self) -> float:
        """
        Delay between retries in milliseconds.

        :return: float
        """
        return self.delay_in_ms


class TokenManagerConfig:
    """
    Settings read by TokenManager when requesting and scheduling tokens.

    All values are stored as given and exposed through the ``get_*`` accessors.
    """

    def __init__(
        self,
        expiration_refresh_ratio: float,
        lower_refresh_bound_millis: int,
        token_request_execution_timeout_in_ms: int,
        retry_policy: RetryPolicy,
    ):
        self._retry_policy = retry_policy
        self._token_request_execution_timeout_in_ms = (
            token_request_execution_timeout_in_ms
        )
        self._lower_refresh_bound_millis = lower_refresh_bound_millis
        self._expiration_refresh_ratio = expiration_refresh_ratio

    def get_expiration_refresh_ratio(self) -> float:
        """
        Fraction of a token's lifetime after which it should be refreshed.

        A ratio of 0.75, for example, asks for a refresh once three quarters
        of the lifetime have elapsed, i.e. when a quarter of it remains.

        :return: float
        """
        return self._expiration_refresh_ratio

    def get_lower_refresh_bound_millis(self) -> int:
        """
        Minimum time before expiration, in milliseconds, at which a refresh is
        triggered, independent of the token's total lifetime.

        A value of 0 disables this bound, leaving the refresh ratio as the only
        trigger.

        :return: int
        """
        return self._lower_refresh_bound_millis

    def get_token_request_execution_timeout_in_ms(self) -> int:
        """
        Upper limit, in milliseconds, for how long a token request may take.

        NOTE(review): TokenManager in this file never reads this value;
        confirm where (or whether) it is enforced.

        :return: int
        """
        return self._token_request_execution_timeout_in_ms

    def get_retry_policy(self) -> RetryPolicy:
        """
        Retry behaviour applied to failed token requests.

        :return: RetryPolicy
        """
        return self._retry_policy


class TokenManager:
    """
    Acquires tokens from an identity provider and keeps them fresh.

    Each renewal requests a token with ``force_refresh=True`` and passes it to
    the registered CredentialsListener. It then schedules the next renewal on
    an asyncio event loop, using a delay derived from the token's lifetime and
    the TokenManagerConfig.
    """

    def __init__(
        self, identity_provider: IdentityProviderInterface, config: TokenManagerConfig
    ):
        self._idp = identity_provider
        self._config = config
        # Handle of the next scheduled renewal; cancelled by stop().
        self._next_timer = None
        self._listener = None
        # Handle of the first scheduled renewal; cancelled by stop().
        self._init_timer = None
        # Retries consumed by the token request currently in progress.
        self._retries = 0

    def __del__(self):
        logger.info("Token manager are disposed")
        self.stop()

    def start(
        self,
        listener: CredentialsListener,
        skip_initial: bool = False,
    ) -> Callable[[], None]:
        """
        Start token renewal from synchronous code and block until the first
        renewal attempt has finished.

        If a loop is already running in the calling thread, the first renewal
        runs inline on that loop. Exceptions not handled by
        ``listener.on_error`` then propagate to the caller. Otherwise a
        dedicated loop is started in a daemon thread. This class never stops
        that loop.

        :param listener: receives new tokens (``on_next``) and errors (``on_error``).
        :param skip_initial: if True, the first token is requested but not
            passed to ``listener.on_next``.
        :return: ``self.stop``, which cancels pending renewals.
        """
        self._listener = listener

        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None

        if running_loop is not None:
            # Blocking on a loop that runs in this very thread would deadlock:
            # the loop could never execute the renewal we wait for. Run the
            # first renewal inline; it schedules follow-ups on this loop.
            logger.info("Token manager started")
            self._renew_token(skip_initial)
            return self.stop

        # Run loop in a separate thread to unblock main thread.
        loop = asyncio.new_event_loop()
        thread = threading.Thread(
            target=_start_event_loop_in_thread, args=(loop,), daemon=True
        )
        thread.start()

        # The first renewal runs on the loop thread while this thread waits,
        # so a thread-safe event is used instead of an asyncio.Event.
        init_event = threading.Event()
        # loop.call_later() is not thread-safe. call_soon_threadsafe() is the
        # supported way to hand work to a loop running in another thread.
        self._init_timer = loop.call_soon_threadsafe(
            self._renew_token, skip_initial, init_event
        )
        logger.info("Token manager started")

        init_event.wait()
        return self.stop

    async def start_async(
        self,
        listener: CredentialsListener,
        block_for_initial: bool = False,
        initial_delay_in_ms: float = 0,
        skip_initial: bool = False,
    ) -> Callable[[], None]:
        """
        Start token renewal on the currently running event loop.

        :param listener: receives new tokens (``on_next``) and errors (``on_error``).
        :param block_for_initial: if True, wait until the first renewal
            attempt has finished before returning.
        :param initial_delay_in_ms: delay before the first renewal, in milliseconds.
        :param skip_initial: if True, the first token is requested but not
            passed to ``listener.on_next``.
        :return: ``self.stop``, which cancels pending renewals.
        """
        self._listener = listener

        loop = asyncio.get_running_loop()
        init_event = asyncio.Event()

        # Wraps the async callback with async wrapper to schedule with loop.call_later()
        wrapped = _async_to_sync_wrapper(
            loop, self._renew_token_async, skip_initial, init_event
        )
        self._init_timer = loop.call_later(initial_delay_in_ms / 1000, wrapped)
        logger.info("Token manager started")

        if block_for_initial:
            await init_event.wait()

        return self.stop

    def stop(self):
        """Cancel the pending initial and next scheduled renewals, if any."""
        if self._init_timer is not None:
            self._init_timer.cancel()
        if self._next_timer is not None:
            self._next_timer.cancel()

    def acquire_token(self, force_refresh=False) -> TokenResponse:
        """
        Request a token from the identity provider, retrying on RequestTokenErr.

        Between retries the thread sleeps for the retry policy's delay. The
        retry counter is reset on success and when the retry budget is
        exhausted, so every call starts with the full budget.

        :param force_refresh: forwarded to ``request_token``.
        :return: the token wrapped in a TokenResponse.
        :raises RequestTokenErr: when all retries have failed.
        """
        policy = self._config.get_retry_policy()
        while True:
            try:
                token = self._idp.request_token(force_refresh)
            except RequestTokenErr:
                if self._retries < policy.get_max_attempts():
                    self._retries += 1
                    sleep(policy.get_delay_in_ms() / 1000)
                    continue
                # Without this reset, later renewals would get no retries.
                self._retries = 0
                raise

            self._retries = 0
            return TokenResponse(token)

    async def acquire_token_async(self, force_refresh=False) -> TokenResponse:
        """
        Async counterpart of acquire_token(). It waits between retries with
        ``asyncio.sleep`` instead of blocking the thread.

        NOTE: ``self._idp.request_token`` itself is called synchronously. If it
        performs blocking I/O, the event loop is blocked while it runs.

        :param force_refresh: forwarded to ``request_token``.
        :return: the token wrapped in a TokenResponse.
        :raises RequestTokenErr: when all retries have failed.
        """
        policy = self._config.get_retry_policy()
        while True:
            try:
                token = self._idp.request_token(force_refresh)
            except RequestTokenErr:
                if self._retries < policy.get_max_attempts():
                    self._retries += 1
                    await asyncio.sleep(policy.get_delay_in_ms() / 1000)
                    continue
                # Without this reset, later renewals would get no retries.
                self._retries = 0
                raise

            self._retries = 0
            return TokenResponse(token)

    def _calculate_renewal_delay(self, expire_date: float, issue_date: float) -> float:
        """
        Compute the delay before the next renewal.

        :param expire_date: token expiry, epoch milliseconds.
        :param issue_date: token receive time, epoch milliseconds.
        :return: delay in seconds, the earlier of the two refresh triggers,
            clamped to 0.
        """
        delay_for_lower_refresh = self._delay_for_lower_refresh(expire_date)
        delay_for_ratio_refresh = self._delay_for_ratio_refresh(expire_date, issue_date)
        delay = min(delay_for_ratio_refresh, delay_for_lower_refresh)

        return 0 if delay < 0 else delay / 1000

    def _delay_for_lower_refresh(self, expire_date: float):
        """Milliseconds from now until ``lower_refresh_bound`` before expiry."""
        return (
            expire_date
            - self._config.get_lower_refresh_bound_millis()
            - (datetime.now(timezone.utc).timestamp() * 1000)
        )

    def _delay_for_ratio_refresh(self, expire_date: float, issue_date: float):
        """Milliseconds from now until the refresh-ratio point of the lifetime."""
        token_ttl = expire_date - issue_date
        refresh_before = token_ttl - (
            token_ttl * self._config.get_expiration_refresh_ratio()
        )

        return (
            expire_date
            - refresh_before
            - (datetime.now(timezone.utc).timestamp() * 1000)
        )

    @staticmethod
    async def _maybe_await(result: Any) -> Any:
        """Return ``result``, awaiting it first when a callback returned an awaitable."""
        if inspect.isawaitable(result):
            return await result
        return result

    def _renew_token(
        self,
        skip_initial: bool = False,
        init_event: Union[asyncio.Event, threading.Event, None] = None,
    ):
        """
        Task to renew token from identity provider.
        Schedules renewal tasks based on token TTL.

        The method must run on a running event loop, either as a loop callback
        or inline from ``start``. It uses ``asyncio.get_running_loop()`` to
        schedule the next call.

        :param skip_initial: if True, this token is not passed to ``on_next``.
        :param init_event: set when this attempt finishes, whether it succeeded
            or failed, so ``start`` can stop waiting.
        :return: the TokenResponse when a follow-up renewal was scheduled,
            otherwise None.
        """
        try:
            token_res = self.acquire_token(force_refresh=True)
            token = token_res.get_token()
            delay = self._calculate_renewal_delay(
                token.get_expires_at_ms(),
                token.get_received_at_ms(),
            )

            if token.is_expired():
                raise TokenRenewalErr("Requested token is expired")

            if self._listener.on_next is None:
                logger.warning(
                    "No registered callback for token renewal task. Renewal cancelled"
                )
                return

            if not skip_initial:
                try:
                    self._listener.on_next(token)
                except Exception as e:
                    raise TokenRenewalErr(e) from e

            # NOTE(review): a zero delay means the token is already inside its
            # refresh window. In that case no further renewal is scheduled.
            if delay <= 0:
                return

            loop = asyncio.get_running_loop()
            self._next_timer = loop.call_later(delay, self._renew_token)
            logger.info(f"Next token renewal scheduled in {delay} seconds")
            return token_res
        except Exception as e:
            if self._listener.on_error is None:
                raise

            self._listener.on_error(e)
        finally:
            if init_event:
                init_event.set()

    async def _renew_token_async(
        self,
        skip_initial: bool = False,
        init_event: Union[asyncio.Event, None] = None,
    ):
        """
        Async task to renew tokens from identity provider.
        Schedules renewal tasks based on token TTL.

        Listener callbacks may be sync or async. Their result is awaited only
        when it is awaitable.

        :param skip_initial: if True, this token is not passed to ``on_next``.
        :param init_event: set when this attempt finishes, whether it succeeded
            or failed.
        """
        try:
            token_res = await self.acquire_token_async(force_refresh=True)
            token = token_res.get_token()
            delay = self._calculate_renewal_delay(
                token.get_expires_at_ms(),
                token.get_received_at_ms(),
            )

            if token.is_expired():
                raise TokenRenewalErr("Requested token is expired")

            if self._listener.on_next is None:
                logger.warning(
                    "No registered callback for token renewal task. Renewal cancelled"
                )
                return

            if not skip_initial:
                try:
                    await self._maybe_await(self._listener.on_next(token))
                except Exception as e:
                    raise TokenRenewalErr(e) from e

            # NOTE(review): a zero delay ends the renewal cycle (see _renew_token).
            if delay <= 0:
                return

            loop = asyncio.get_running_loop()
            wrapped = _async_to_sync_wrapper(loop, self._renew_token_async)
            logger.info(f"Next token renewal scheduled in {delay} seconds")
            # Keep the handle so stop() can cancel the pending async renewal.
            self._next_timer = loop.call_later(delay, wrapped)
        except Exception as e:
            if self._listener.on_error is None:
                raise

            await self._maybe_await(self._listener.on_error(e))
        finally:
            if init_event:
                init_event.set()


# Strong references to tasks spawned by _async_to_sync_wrapper. The event loop
# keeps only weak references to tasks, so without this set a task could be
# garbage-collected before it finishes.
_background_tasks = set()


def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs):
    """
    Wraps an asynchronous function so it can be used with loop.call_later.

    Each call of the returned function schedules
    ``coro_func(*args, **kwargs)`` as a task on ``loop``. The task is held in
    ``_background_tasks`` until it is done.

    :param loop: The event loop in which the coroutine will be executed.
    :param coro_func: The coroutine function to wrap.
    :param args: Positional arguments to pass to the coroutine function.
    :param kwargs: Keyword arguments to pass to the coroutine function.
    :return: A regular function suitable for loop.call_later.
    """

    def wrapped():
        # Schedule the coroutine in the event loop
        task = asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop)
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)

    return wrapped


def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop):
    """
    Thread target that makes ``event_loop`` the current loop of the calling
    thread and runs it until ``event_loop.stop()`` is called.

    Other threads can then hand callbacks to the loop.

    :param event_loop: the loop to run; it must not already be running.
    :return: None, once the loop has been stopped.
    """
    loop = event_loop
    asyncio.set_event_loop(loop)
    # Blocks this thread until someone calls loop.stop().
    loop.run_forever()