1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
|
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional, Union
from sqlalchemy import ColumnElement, select
from sqlalchemy.orm import declarative_mixin
from typing_extensions import Self
from advanced_alchemy.exceptions import wrap_sqlalchemy_exception
if TYPE_CHECKING:
from collections.abc import Hashable, Iterator
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio.scoping import async_scoped_session
from sqlalchemy.orm import Session
from sqlalchemy.orm.scoping import scoped_session
# Public names exported by this module (used by ``from ... import *``).
__all__ = ("UniqueMixin",)
@declarative_mixin
class UniqueMixin:
    """Mixin for instantiating objects while ensuring uniqueness on some field(s).

    This is a slightly modified implementation derived from https://github.com/sqlalchemy/sqlalchemy/wiki/UniqueObject

    Subclasses must implement :meth:`unique_hash` (the key for the per-session in-memory cache) and
    :meth:`unique_filter` (the SQL condition used to find an existing row). Instances should then be
    obtained through :meth:`as_unique_async` / :meth:`as_unique_sync` rather than the constructor.
    """

    @classmethod
    @contextmanager
    def _prevent_autoflush(
        cls,
        session: "Union[AsyncSession, async_scoped_session[AsyncSession], Session, scoped_session[Session]]",
    ) -> "Iterator[None]":
        """Disable autoflush on ``session`` and translate SQLAlchemy errors for the duration of the block.

        With autoflush off, the lookup ``SELECT`` issued inside the block does not first flush the
        session's pending objects.

        Args:
            session: The (possibly scoped) session whose ``no_autoflush`` context is entered.

        Yields:
            None
        """
        with session.no_autoflush, wrap_sqlalchemy_exception():
            yield

    @classmethod
    def _check_uniqueness(
        cls,
        cache: "Optional[dict[tuple[type[Self], Hashable], Self]]",
        session: "Union[AsyncSession, async_scoped_session[AsyncSession], Session, scoped_session[Session]]",
        key: "tuple[type[Self], Hashable]",
        *args: Any,
        **kwargs: Any,
    ) -> "tuple[dict[tuple[type[Self], Hashable], Self], Select[tuple[Self]], Optional[Self]]":
        """Prepare a uniqueness lookup: ensure a cache exists, build the query, and probe the cache.

        Args:
            cache: The existing unique cache, or ``None`` if ``session`` does not carry one yet.
            session: The object that owns the cache; a newly created cache is attached to it as
                ``_unique_cache``.
            key: Cache key, made of the class and its :meth:`unique_hash` value.
            *args: Values forwarded to :meth:`unique_filter`.
            **kwargs: Values forwarded to :meth:`unique_filter`.

        Returns:
            A ``(cache, statement, cached)`` tuple: the (possibly new) cache, a ``SELECT`` for rows
            matching :meth:`unique_filter`, and the cached instance for ``key`` (or ``None``).
        """
        if cache is None:
            cache = {}
            setattr(session, "_unique_cache", cache)
        # LIMIT 2 is enough for ``scalar_one_or_none()`` to detect a duplicate match without
        # fetching every matching row.
        statement = select(cls).where(cls.unique_filter(*args, **kwargs)).limit(2)
        return cache, statement, cache.get(key)

    @classmethod
    async def as_unique_async(
        cls,
        session: "Union[AsyncSession, async_scoped_session[AsyncSession]]",
        *args: Any,
        **kwargs: Any,
    ) -> Self:
        """Instantiate and return a unique object within the provided session based on the given arguments.

        The per-session cache is consulted first. On a miss, the database is queried with
        :meth:`unique_filter` while autoflush is disabled; if no row matches, a new instance is built
        from ``*args``/``**kwargs`` and added to the session. The result is cached either way.
        Nothing in this class ever clears the cache.

        Args:
            session (AsyncSession | async_scoped_session[AsyncSession]): SQLAlchemy async session
            *args (Any): Values used to instantiate the instance if no duplicate exists
            **kwargs (Any): Values used to instantiate the instance if no duplicate exists

        Raises:
            NotImplementedError: If :meth:`unique_hash` or :meth:`unique_filter` is not implemented.
            Exception: SQLAlchemy errors from the lookup or ``add`` (including the one raised by
                ``scalar_one_or_none()`` when more than one row matches), as translated by
                ``wrap_sqlalchemy_exception``.

        Returns:
            Self: The unique object instance.
        """
        # Runtime import: at module level this name is only imported for type checking.
        from sqlalchemy.ext.asyncio import async_scoped_session

        key = cls, cls.unique_hash(*args, **kwargs)
        # A scoped-session proxy is shared by every scope it serves, so the cache must live on the
        # scope-local session it resolves to; caching on the proxy would hand instances bound to one
        # scope's session to another scope.
        cache_owner = session() if isinstance(session, async_scoped_session) else session
        cache, statement, obj = cls._check_uniqueness(
            getattr(cache_owner, "_unique_cache", None),
            cache_owner,
            key,
            *args,
            **kwargs,
        )
        # Identity check, not truthiness: a model defining ``__bool__``/``__len__`` may be falsy, and
        # skipping the cache would create a duplicate of a pending (unflushed) instance.
        if obj is not None:
            return obj
        with cls._prevent_autoflush(session):
            if (obj := (await session.execute(statement)).scalar_one_or_none()) is None:
                session.add(obj := cls(*args, **kwargs))
        cache[key] = obj
        return obj

    @classmethod
    def as_unique_sync(
        cls,
        session: "Union[Session, scoped_session[Session]]",
        *args: Any,
        **kwargs: Any,
    ) -> Self:
        """Instantiate and return a unique object within the provided session based on the given arguments.

        The per-session cache is consulted first. On a miss, the database is queried with
        :meth:`unique_filter` while autoflush is disabled; if no row matches, a new instance is built
        from ``*args``/``**kwargs`` and added to the session. The result is cached either way.
        Nothing in this class ever clears the cache.

        Args:
            session (Session | scoped_session[Session]): SQLAlchemy sync session
            *args (Any): Values used to instantiate the instance if no duplicate exists
            **kwargs (Any): Values used to instantiate the instance if no duplicate exists

        Raises:
            NotImplementedError: If :meth:`unique_hash` or :meth:`unique_filter` is not implemented.
            Exception: SQLAlchemy errors from the lookup or ``add`` (including the one raised by
                ``scalar_one_or_none()`` when more than one row matches), as translated by
                ``wrap_sqlalchemy_exception``.

        Returns:
            Self: The unique object instance.
        """
        # Runtime import: at module level this name is only imported for type checking.
        from sqlalchemy.orm import scoped_session

        key = cls, cls.unique_hash(*args, **kwargs)
        # A scoped-session proxy is shared by every scope it serves, so the cache must live on the
        # scope-local session it resolves to; caching on the proxy would hand instances bound to one
        # scope's session to another scope.
        cache_owner = session() if isinstance(session, scoped_session) else session
        cache, statement, obj = cls._check_uniqueness(
            getattr(cache_owner, "_unique_cache", None),
            cache_owner,
            key,
            *args,
            **kwargs,
        )
        # Identity check, not truthiness: a model defining ``__bool__``/``__len__`` may be falsy, and
        # skipping the cache would create a duplicate of a pending (unflushed) instance.
        if obj is not None:
            return obj
        with cls._prevent_autoflush(session):
            if (obj := session.execute(statement).scalar_one_or_none()) is None:
                session.add(obj := cls(*args, **kwargs))
        cache[key] = obj
        return obj

    @classmethod
    def unique_hash(cls, *args: Any, **kwargs: Any) -> "Hashable":
        """Generate a unique key based on the provided arguments.

        This method should be implemented in the subclass. The returned value is combined with the
        class to form the key of the per-session cache, so equal arguments must hash equally.

        Args:
            *args (Any): Values passed to the alternate classmethod constructors
            **kwargs (Any): Values passed to the alternate classmethod constructors

        Raises:
            NotImplementedError: If not implemented in the subclass.

        Returns:
            Hashable: Any hashable object.
        """
        msg = "Implement this in subclass"
        raise NotImplementedError(msg)

    @classmethod
    def unique_filter(cls, *args: Any, **kwargs: Any) -> "ColumnElement[bool]":
        """Generate a filter condition for ensuring uniqueness.

        This method should be implemented in the subclass. The condition is used as the ``WHERE``
        clause of the lookup query; it should match at most one row, since a second match is
        reported as an error by the ``as_unique_*`` constructors.

        Args:
            *args (Any): Values passed to the alternate classmethod constructors
            **kwargs (Any): Values passed to the alternate classmethod constructors

        Raises:
            NotImplementedError: If not implemented in the subclass.

        Returns:
            ColumnElement[bool]: Filter condition to establish the uniqueness.
        """
        msg = "Implement this in subclass"
        raise NotImplementedError(msg)
|