File: async_getter_lock.py

package info (click to toggle)
mautrix-python 0.20.7-1
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 1,812 kB
  • sloc: python: 19,103; makefile: 16
file content (62 lines) | stat: -rw-r--r-- 1,965 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations

from typing import Any
import functools

from mautrix import __optional_imports__

# Typing-only helpers. Thanks to ``from __future__ import annotations`` the
# ``fn: Func`` / ``-> Func`` annotations below are stored as strings and never
# evaluated at runtime, so these names only need to exist when
# ``__optional_imports__`` is truthy (presumably during static type checking —
# TODO confirm against ``mautrix.__init__``).
if __optional_imports__:
    from typing import Awaitable, Callable, ParamSpec

    # Parameter specification of the decorated coroutine function.
    Param = ParamSpec("Param")
    # A callable accepting ``Param`` and returning an awaitable of anything.
    Func = Callable[Param, Awaitable[Any]]


def async_getter_lock(fn: Func) -> Func:
    """
    A utility decorator for locking async getters that have caches
    (preventing race conditions between cache check and e.g. async database actions).

    The class must have an ``_async_get_locks`` defaultdict that contains :class:`asyncio.Lock`s
    (see example for exact definition). The lock is selected by the tuple of positional
    arguments (excluding the first ``cls`` argument), so cache-affecting arguments must be
    passed positionally, and non-cache-affecting arguments should be only passed as
    keyword args.

    Args:
        fn: The function to decorate. Its first parameter receives the object whose
            ``_async_get_locks`` attribute is used (typically the class, when the
            decorator is stacked under ``@classmethod``).

    Returns:
        The decorated function. Calls with equal positional arguments are serialized;
        calls with different positional arguments use different locks and may run
        concurrently.

    Note:
        * Positional arguments must be hashable, because their tuple is used as a
          mapping key (a plain dict/defaultdict raises :class:`TypeError` otherwise).
        * This decorator never removes entries from ``_async_get_locks``, so the
          mapping grows by one lock per distinct positional-argument tuple.
        * ``cls`` is not part of the key, so subclasses that inherit the same
          ``_async_get_locks`` mapping share locks for equal arguments.

    Examples:
        >>> import asyncio
        >>> from collections import defaultdict
        >>> class User:
        ...   _async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
        ...   db: Any
        ...   cache: dict[str, User]
        ...   @classmethod
        ...   @async_getter_lock
        ...   async def get(cls, id: str, *, create: bool = False) -> User | None:
        ...     try:
        ...       return cls.cache[id]
        ...     except KeyError:
        ...       pass
        ...     user = await cls.db.fetch_user(id)
        ...     if not user and create:
        ...       user = await cls.db.create_user(id)
        ...     if not user:
        ...       return None
        ...     # Populating the cache while still holding the lock is what prevents a
        ...     # concurrent call for the same id from fetching/creating a duplicate.
        ...     cls.cache[id] = user
        ...     return user
    """

    @functools.wraps(fn)
    async def wrapper(cls, *args, **kwargs) -> Any:
        # Key the lock on positional arguments only: keyword arguments (e.g. a
        # ``create=True`` flag) must not split one cache key across several locks.
        async with cls._async_get_locks[args]:
            return await fn(cls, *args, **kwargs)

    return wrapper