File: plugin.py

package info (click to toggle)
python-advanced-alchemy 1.8.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 4,904 kB
  • sloc: python: 36,227; makefile: 153; sh: 4
file content (163 lines) | stat: -rw-r--r-- 6,824 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import contextlib
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Union, cast

from litestar.di import Provide
from litestar.dto import DTOData
from litestar.params import Dependency, Parameter
from litestar.plugins import CLIPlugin, InitPluginProtocol
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
from sqlalchemy.orm import Session, scoped_session

from advanced_alchemy.exceptions import ImproperConfigurationError, RepositoryError
from advanced_alchemy.extensions.litestar.exception_handler import exception_to_http_response
from advanced_alchemy.extensions.litestar.plugins import _slots_base
from advanced_alchemy.filters import (
    BeforeAfter,
    CollectionFilter,
    ComparisonFilter,
    ExistsFilter,
    FilterGroup,
    FilterMap,
    FilterTypes,
    InAnyFilter,
    LimitOffset,
    LogicalOperatorMap,
    MultiFilter,
    NotExistsFilter,
    NotInCollectionFilter,
    NotInSearchFilter,
    OnBeforeAfter,
    OrderBy,
    SearchFilter,
    StatementFilter,
    StatementTypeT,
)
from advanced_alchemy.service import ModelDictListT, ModelDictT, ModelDTOT, ModelOrRowMappingT, ModelT, OffsetPagination

if TYPE_CHECKING:
    from click import Group
    from litestar.config.app import AppConfig
    from litestar.types import BeforeMessageSendHookHandler

    from advanced_alchemy.extensions.litestar.plugins.init.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig

__all__ = ("SQLAlchemyInitPlugin",)

# Names exposed to Litestar when it resolves handler/dependency signatures.
# ``SQLAlchemyInitPlugin.on_app_init`` merges this mapping into
# ``AppConfig.signature_namespace`` (and also adds entries to it at runtime).
signature_namespace_values: dict[str, Any] = dict(
    BeforeAfter=BeforeAfter,
    OnBeforeAfter=OnBeforeAfter,
    CollectionFilter=CollectionFilter,
    LimitOffset=LimitOffset,
    OrderBy=OrderBy,
    SearchFilter=SearchFilter,
    NotInCollectionFilter=NotInCollectionFilter,
    NotInSearchFilter=NotInSearchFilter,
    FilterTypes=FilterTypes,
    OffsetPagination=OffsetPagination,
    ExistsFilter=ExistsFilter,
    Parameter=Parameter,
    Dependency=Dependency,
    DTOData=DTOData,
    Sequence=Sequence,
    ModelT=ModelT,
    ModelDictT=ModelDictT,
    ModelDTOT=ModelDTOT,
    ModelDictListT=ModelDictListT,
    ModelOrRowMappingT=ModelOrRowMappingT,
    Session=Session,
    scoped_session=scoped_session,
    AsyncSession=AsyncSession,
    async_scoped_session=async_scoped_session,
    FilterGroup=FilterGroup,
    NotExistsFilter=NotExistsFilter,
    MultiFilter=MultiFilter,
    ComparisonFilter=ComparisonFilter,
    StatementTypeT=StatementTypeT,
    StatementFilter=StatementFilter,
    LogicalOperatorMap=LogicalOperatorMap,
    InAnyFilter=InAnyFilter,
    FilterMap=FilterMap,
)


class SQLAlchemyInitPlugin(InitPluginProtocol, CLIPlugin, _slots_base.SlotsBase):
    """SQLAlchemy application lifecycle configuration.

    For every supplied config, registers the engine and session dependencies, the
    config's lifespan and its ``before_send`` handler on the Litestar application.
    Also adds the ``database_group`` commands to the Litestar CLI.
    """

    def __init__(
        self,
        config: Union[
            "SQLAlchemyAsyncConfig",
            "SQLAlchemySyncConfig",
            "Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]",
        ],
    ) -> None:
        """Initialize ``SQLAlchemyPlugin``.

        Args:
            config: configure DB connection and hook handlers and dependencies.
        """
        self._config = config

    @property
    def config(self) -> "Sequence[Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]]":
        """Return the configured SQLAlchemy configs, always as a sequence.

        A single config passed to ``__init__`` is wrapped in a one-element list.
        """
        return self._config if isinstance(self._config, Sequence) else [self._config]

    def on_cli_init(self, cli: "Group") -> None:
        """Register the Advanced Alchemy database commands on the Litestar CLI.

        Args:
            cli: The root ``click`` group of the Litestar CLI.
        """
        # Imported lazily, presumably so the CLI module is only loaded when the CLI is set up.
        from advanced_alchemy.extensions.litestar.cli import database_group

        cli.add_command(database_group)

    def _validate_config(self) -> None:
        """Ensure per-config keys do not collide when several configs are supplied.

        Raises:
            ImproperConfigurationError: If more than one config is given and any of
                ``session_scope_key``, ``engine_dependency_key`` or
                ``session_dependency_key`` is shared between configs.
        """
        configs = self.config
        if len(configs) <= 1:
            # A single config cannot collide with itself.
            return
        key_sets = (
            {config.session_scope_key for config in configs},
            {config.engine_dependency_key for config in configs},
            {config.session_dependency_key for config in configs},
        )
        # Fewer distinct keys than configs means at least two configs share a key.
        if any(len(keys) != len(configs) for keys in key_sets):
            raise ImproperConfigurationError(
                detail="When using multiple configurations, please ensure the `session_dependency_key` and `engine_dependency_key` settings are unique across all configs.  Additionally, if you are using a custom `before_send` handler, ensure `session_scope_key` is unique.",
            )

    def on_app_init(self, app_config: "AppConfig") -> "AppConfig":
        """Configure application for use with SQLAlchemy.

        This method:

        - validates the configs;
        - registers encoders/decoders and signature-namespace entries for
          ``asyncpg``'s ``pgproto.UUID`` and ``uuid_utils.UUID`` when those
          packages are importable;
        - wires each config's lifespan, engine/session dependencies and
          ``before_send`` handler;
        - installs the default ``RepositoryError`` exception handler when a config
          requests it and no handler for ``RepositoryError`` (or a subclass) exists.

        Args:
            app_config: The :class:`AppConfig <.config.app.AppConfig>` instance.

        Raises:
            ImproperConfigurationError: If multiple configs share a dependency or scope key.

        Returns:
            The same ``app_config`` instance, modified in place.
        """
        self._validate_config()
        # Optional UUID implementations: only registered if the package is installed.
        with contextlib.suppress(ImportError):
            from asyncpg.pgproto import pgproto  # pyright: ignore[reportMissingImports]

            signature_namespace_values.update({"pgproto.UUID": pgproto.UUID})
            # User-supplied encoders are unpacked last so they take precedence.
            app_config.type_encoders = {pgproto.UUID: str, **(app_config.type_encoders or {})}
        with contextlib.suppress(ImportError):
            import uuid_utils  # pyright: ignore[reportMissingImports]

            signature_namespace_values.update({"uuid_utils.UUID": uuid_utils.UUID})  # pyright: ignore[reportUnknownMemberType]
            app_config.type_encoders = {uuid_utils.UUID: str, **(app_config.type_encoders or {})}  # pyright: ignore[reportUnknownMemberType]
            app_config.type_decoders = [
                (lambda x: x is uuid_utils.UUID, lambda t, v: t(str(v))),  # pyright: ignore[reportUnknownMemberType]
                *(app_config.type_decoders or []),
            ]
        configure_exception_handler = False
        for config in self.config:
            if config.set_default_exception_handler:
                configure_exception_handler = True
            # NOTE(review): this mutates the module-level ``signature_namespace_values``
            # dict, so these entries are also merged into any app initialised later
            # in the same process.
            signature_namespace_values.update(config.signature_namespace)
            app_config.lifespan.append(config.lifespan)  # pyright: ignore[reportUnknownMemberType]

            app_config.dependencies.update(
                {
                    config.engine_dependency_key: Provide(config.provide_engine, sync_to_thread=False),
                    config.session_dependency_key: Provide(config.provide_session, sync_to_thread=False),
                },
            )
            app_config.before_send.append(cast("BeforeMessageSendHookHandler", config.before_send_handler))
        app_config.signature_namespace.update(signature_namespace_values)
        # ``exception_handlers`` keys are HTTP status codes (ints) or exception classes.
        # Only an existing handler for ``RepositoryError`` (or a subclass) should suppress
        # the default. A status-code handler such as 404 must not. The ``isinstance(exc,
        # type)`` guard keeps ``issubclass`` from raising ``TypeError`` on int keys.
        if configure_exception_handler and not any(
            isinstance(exc, type) and issubclass(exc, RepositoryError)
            for exc in app_config.exception_handlers  # pyright: ignore[reportUnknownMemberType]
        ):
            app_config.exception_handlers.update({RepositoryError: exception_to_http_response})  # pyright: ignore[reportUnknownMemberType]

        return app_config