File: unique.py

package info (click to toggle)
python-advanced-alchemy 1.4.1-1
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 3,708 kB
  • sloc: python: 25,811; makefile: 162; javascript: 123; sh: 4
file content (161 lines) | stat: -rw-r--r-- 5,722 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional, Union

from sqlalchemy import ColumnElement, select
from sqlalchemy.orm import declarative_mixin
from typing_extensions import Self

from advanced_alchemy.exceptions import wrap_sqlalchemy_exception

if TYPE_CHECKING:
    from collections.abc import Hashable, Iterator

    from sqlalchemy import Select
    from sqlalchemy.ext.asyncio import AsyncSession
    from sqlalchemy.ext.asyncio.scoping import async_scoped_session
    from sqlalchemy.orm import Session
    from sqlalchemy.orm.scoping import scoped_session

# Names exported by ``from ... import *``; the mixin is this module's only public API.
__all__ = ("UniqueMixin",)


@declarative_mixin
class UniqueMixin:
    """Mixin for instantiating objects while ensuring uniqueness on some field(s).

    This is a slightly modified implementation derived from https://github.com/sqlalchemy/sqlalchemy/wiki/UniqueObject

    Subclasses implement :meth:`unique_hash` (the in-memory cache key) and
    :meth:`unique_filter` (the SQL condition that finds an existing row), then
    call :meth:`as_unique_sync` or :meth:`as_unique_async` instead of the
    constructor. Resolved instances are memoized per session in a
    ``_unique_cache`` dict attached to the session object.
    """

    @classmethod
    @contextmanager
    def _prevent_autoflush(
        cls,
        session: "Union[AsyncSession, async_scoped_session[AsyncSession], Session, scoped_session[Session]]",
    ) -> "Iterator[None]":
        """Run the enclosed block with autoflush disabled and SQLAlchemy errors wrapped.

        Disabling autoflush keeps the lookup query from flushing pending changes
        in the session before it executes. Exceptions raised inside the block are
        passed through ``wrap_sqlalchemy_exception``.

        Args:
            session: The session (or scoped-session proxy) whose autoflush is suspended.

        Yields:
            None
        """
        with session.no_autoflush, wrap_sqlalchemy_exception():
            yield

    @classmethod
    def _check_uniqueness(
        cls,
        cache: "Optional[dict[tuple[type[Self], Hashable], Self]]",
        session: "Union[AsyncSession, async_scoped_session[AsyncSession], Session, scoped_session[Session]]",
        key: "tuple[type[Self], Hashable]",
        *args: Any,
        **kwargs: Any,
    ) -> "tuple[dict[tuple[type[Self], Hashable], Self], Select[tuple[Self]], Optional[Self]]":
        """Prepare a uniqueness lookup.

        Args:
            cache: The session's existing ``_unique_cache`` dict, or ``None`` if it has
                none yet. When ``None``, a new empty dict is created and attached to
                ``session`` as ``_unique_cache``.
            session: The object the cache is (or will be) stored on.
            key: The ``(cls, unique_hash)`` cache key.
            *args: Forwarded to :meth:`unique_filter`.
            **kwargs: Forwarded to :meth:`unique_filter`.

        Returns:
            A ``(cache, statement, cached_obj)`` tuple. ``statement`` is a ``SELECT`` of
            ``cls`` rows matching :meth:`unique_filter`. It is limited to 2 rows so that
            ``scalar_one_or_none`` can still detect (and raise on) a non-unique match.
            ``cached_obj`` is the cached instance for ``key``, or ``None``.
        """
        if cache is None:
            cache = {}
            setattr(session, "_unique_cache", cache)
        statement = select(cls).where(cls.unique_filter(*args, **kwargs)).limit(2)
        return cache, statement, cache.get(key)

    @classmethod
    async def as_unique_async(
        cls,
        session: "Union[AsyncSession, async_scoped_session[AsyncSession]]",
        *args: Any,
        **kwargs: Any,
    ) -> Self:
        """Instantiate and return a unique object within the provided session based on the given arguments.

        Lookup order:

        1. The session's ``_unique_cache``, keyed by ``(cls, cls.unique_hash(*args, **kwargs))``.
        2. The database, via :meth:`unique_filter`, with autoflush disabled.
        3. Otherwise a new ``cls(*args, **kwargs)`` is constructed and added to the
           session (it is not flushed here).

        The resolved instance is cached before being returned. A scoped-session proxy
        is resolved to its current session first, so each concrete session gets its
        own cache.

        Args:
            session (AsyncSession | async_scoped_session[AsyncSession]): SQLAlchemy async session
            *args (Any): Values used to instantiate the instance if no duplicate exists
            **kwargs (Any): Values used to instantiate the instance if no duplicate exists

        Raises:
            NotImplementedError: If the subclass does not implement ``unique_hash`` / ``unique_filter``.

        Note:
            Errors raised while querying or adding the instance (e.g. more than one
            matching row) are passed through ``wrap_sqlalchemy_exception``.

        Returns:
            Self: The unique object instance.
        """
        from sqlalchemy.ext.asyncio import async_scoped_session

        # The scoped registry is one proxy object shared by every scope and it does not
        # forward arbitrary attributes. Storing the cache on it would leak instances
        # across tasks and past ``remove()``, so attach the cache to the concrete session.
        if isinstance(session, async_scoped_session):
            session = session()
        key = cls, cls.unique_hash(*args, **kwargs)
        cache, statement, obj = cls._check_uniqueness(
            getattr(session, "_unique_cache", None),
            session,
            key,
            *args,
            **kwargs,
        )
        # Identity check: a model that defines ``__bool__``/``__len__`` may be falsy,
        # and treating it as a miss would create a duplicate pending instance.
        if obj is not None:
            return obj
        with cls._prevent_autoflush(session):
            if (obj := (await session.execute(statement)).scalar_one_or_none()) is None:
                session.add(obj := cls(*args, **kwargs))
        cache[key] = obj
        return obj

    @classmethod
    def as_unique_sync(
        cls,
        session: "Union[Session, scoped_session[Session]]",
        *args: Any,
        **kwargs: Any,
    ) -> Self:
        """Instantiate and return a unique object within the provided session based on the given arguments.

        Lookup order:

        1. The session's ``_unique_cache``, keyed by ``(cls, cls.unique_hash(*args, **kwargs))``.
        2. The database, via :meth:`unique_filter`, with autoflush disabled.
        3. Otherwise a new ``cls(*args, **kwargs)`` is constructed and added to the
           session (it is not flushed here).

        The resolved instance is cached before being returned. A scoped-session proxy
        is resolved to its current session first, so each concrete session gets its
        own cache.

        Args:
            session (Session | scoped_session[Session]): SQLAlchemy sync session
            *args (Any): Values used to instantiate the instance if no duplicate exists
            **kwargs (Any): Values used to instantiate the instance if no duplicate exists

        Raises:
            NotImplementedError: If the subclass does not implement ``unique_hash`` / ``unique_filter``.

        Note:
            Errors raised while querying or adding the instance (e.g. more than one
            matching row) are passed through ``wrap_sqlalchemy_exception``.

        Returns:
            Self: The unique object instance.
        """
        from sqlalchemy.orm import scoped_session

        # The scoped registry is one proxy object shared by every scope and it does not
        # forward arbitrary attributes. Storing the cache on it would leak instances
        # across threads and past ``remove()``, so attach the cache to the concrete session.
        if isinstance(session, scoped_session):
            session = session()
        key = cls, cls.unique_hash(*args, **kwargs)
        cache, statement, obj = cls._check_uniqueness(
            getattr(session, "_unique_cache", None),
            session,
            key,
            *args,
            **kwargs,
        )
        # Identity check: a model that defines ``__bool__``/``__len__`` may be falsy,
        # and treating it as a miss would create a duplicate pending instance.
        if obj is not None:
            return obj
        with cls._prevent_autoflush(session):
            if (obj := session.execute(statement).scalar_one_or_none()) is None:
                session.add(obj := cls(*args, **kwargs))
        cache[key] = obj
        return obj

    @classmethod
    def unique_hash(cls, *args: Any, **kwargs: Any) -> "Hashable":
        """Generate a unique key based on the provided arguments.

        This method should be implemented in the subclass. The returned value is
        combined with the class to key the per-session instance cache.

        Args:
            *args (Any): Values passed to the alternate classmethod constructors
            **kwargs (Any): Values passed to the alternate classmethod constructors

        Raises:
            NotImplementedError: If not implemented in the subclass.

        Returns:
            Hashable: Any hashable object.
        """
        msg = "Implement this in subclass"
        raise NotImplementedError(msg)

    @classmethod
    def unique_filter(cls, *args: Any, **kwargs: Any) -> "ColumnElement[bool]":
        """Generate a filter condition for ensuring uniqueness.

        This method should be implemented in the subclass. The condition is used in
        the ``WHERE`` clause of the lookup for an existing row.

        Args:
            *args (Any): Values passed to the alternate classmethod constructors
            **kwargs (Any): Values passed to the alternate classmethod constructors

        Raises:
            NotImplementedError: If not implemented in the subclass.

        Returns:
            ColumnElement[bool]: Filter condition to establish the uniqueness.
        """
        msg = "Implement this in subclass"
        raise NotImplementedError(msg)