File: base.py

package info (click to toggle)
mautrix-python 0.20.7-1
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 1,812 kB
  • sloc: python: 19,103; makefile: 16
file content (248 lines) | stat: -rw-r--r-- 8,280 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Type, TypeVar, cast
from contextlib import contextmanager

from sqlalchemy import Constraint, Table
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.engine.base import Connection, Engine
from sqlalchemy.ext.declarative import as_declarative, declarative_base
from sqlalchemy.sql.base import ImmutableColumnCollection
from sqlalchemy.sql.expression import ClauseElement, Select, and_

if TYPE_CHECKING:
    from sqlalchemy.engine.result import ResultProxy, RowProxy

# Type variable bound to BaseClass, used in the ``cls: Type[T]`` classmethod signatures below so
# that e.g. scan() and _select_all() are typed as returning instances of the calling subclass.
T = TypeVar("T", bound="BaseClass")


class BaseClass:
    """
    Base class for SQLAlchemy models. Provides SQLAlchemy declarative base features and some
    additional utilities.

    Subclasses must be attached to an engine with :meth:`bind` before the query helpers are
    used, since those helpers read :attr:`db`, :attr:`t`, :attr:`c` and :attr:`column_names`,
    which are only set by :meth:`bind`.

    .. deprecated:: 0.15.0
       The :mod:`mautrix.util.async_db` utility is now recommended over SQLAlchemy.
    """

    # Name of the database table for this model.
    __tablename__: str

    # Engine that all queries are executed against; set by bind().
    db: Engine
    # Shorthand alias of __table__; set by bind().
    t: Table
    # The SQLAlchemy Table for this model; bind() copies it into ``t``.
    __table__: Table
    # The columns of ``t``; set by bind().
    c: ImmutableColumnCollection
    # Column keys of ``t`` in the order reported by ``c.keys()``; set by bind().
    # scan() relies on this order matching the column order of ``t.select()`` rows.
    column_names: List[str]

    @classmethod
    def bind(cls, db_engine: Engine) -> None:
        """
        Attach this model class to a database engine and cache the table metadata.

        Args:
            db_engine: The SQLAlchemy engine to execute queries with.
        """
        cls.db = db_engine
        cls.t = cls.__table__
        cls.c = cls.t.columns
        cls.column_names = cls.c.keys()

    @classmethod
    def copy(
        cls, bind: Optional[Engine] = None, rebase: Optional[declarative_base] = None
    ) -> Type[T]:
        """
        Create a new subclass of this model with the same class name.

        Args:
            bind: If given, the new class is bound to this engine with :meth:`bind`.
            rebase: If given, added as a second base class of the new class (after this class).

        Returns:
            The newly created class.
        """
        bases = (cls, rebase) if rebase else (cls,)
        new_cls = cast(Type[T], type(cls.__name__, bases, {}))
        if bind is not None:
            new_cls.bind(db_engine=bind)
        return new_cls

    @classmethod
    def _one_or_none(cls: Type[T], rows: "ResultProxy") -> Optional[T]:
        """
        Scan the first row from a ResultProxy, or return ``None`` if there are no rows.

        The result is closed afterwards (when it has a ``close()`` method), so a result with
        unread rows is not left open.

        Args:
            rows: The SQLAlchemy result to scan.

        Returns:
            The scanned object, or ``None`` if there were no rows.
        """
        try:
            row = next(rows)
        except StopIteration:
            return None
        finally:
            close = getattr(rows, "close", None)
            if close is not None:
                close()
        # scan() is deliberately outside the try, so that an error raised while scanning is
        # not misreported as "no rows".
        return cls.scan(row)

    @classmethod
    def _all(cls: Type[T], rows: "ResultProxy") -> Iterator[T]:
        """
        Scan all rows from a ResultProxy.

        Args:
            rows: The SQLAlchemy result to scan.

        Yields:
            Each row scanned with :meth:`scan`
        """
        for row in rows:
            yield cls.scan(row)

    @classmethod
    def scan(cls: Type[T], row: "RowProxy") -> T:
        """
        Read the data from a row into an object.

        Values are matched to :attr:`column_names` by position and passed to the constructor as
        keyword arguments.

        Args:
            row: The RowProxy object.

        Returns:
            An object containing the information in the row.
        """
        return cls(**dict(zip(cls.column_names, row)))

    @classmethod
    def _make_simple_select(cls: Type[T], *args: ClauseElement) -> Select:
        """
        Create a simple ``SELECT * FROM table WHERE <args>`` statement.

        Args:
            *args: The WHERE clauses. If there are many elements, they're joined with AND.

        Returns:
            The SQLAlchemy SELECT statement object.
        """
        if len(args) > 1:
            return cls.t.select().where(and_(*args))
        elif len(args) == 1:
            return cls.t.select().where(args[0])
        else:
            return cls.t.select()

    @classmethod
    def _select_all(cls: Type[T], *args: ClauseElement) -> Iterator[T]:
        """
        Select all rows with given conditions. This is intended to be used by table-specific
        select methods.

        Args:
            *args: The WHERE clauses. If there are many elements, they're joined with AND.

        Yields:
            The objects representing the rows read with :meth:`scan`
        """
        yield from cls._all(cls.db.execute(cls._make_simple_select(*args)))

    @classmethod
    def _select_one_or_none(cls: Type[T], *args: ClauseElement) -> Optional[T]:
        """
        Select one row with given conditions. If no row is found, return ``None``. This is intended
        to be used by table-specific select methods.

        Args:
            *args: The WHERE clauses. If there are many elements, they're joined with AND.

        Returns:
            The object representing the matched row read with :meth:`scan`, or ``None`` if no rows
            matched.
        """
        return cls._one_or_none(cls.db.execute(cls._make_simple_select(*args)))

    def _constraint_to_clause(self, constraint: Constraint) -> ClauseElement:
        """
        Build a WHERE clause matching this instance's values for every column in a constraint.

        Args:
            constraint: The constraint whose columns to match (e.g. the primary key).

        Returns:
            An AND of ``column == <this instance's value>`` for each column of the constraint.

        Raises:
            KeyError: If one of the constraint's columns has no value set on this instance.
        """
        return and_(
            *[column == self.__dict__[name] for name, column in constraint.columns.items()]
        )

    @property
    def _edit_identity(self: T) -> ClauseElement:
        """The SQLAlchemy WHERE clause used for editing and deleting individual rows.
        Usually AND of primary keys."""
        return self._constraint_to_clause(self.t.primary_key)

    def edit(self: T, *, _update_values: bool = True, **values) -> None:
        """
        Edit this row.

        Args:
            _update_values: Whether or not the values in memory should be updated as well as the
                            values in the database.
            **values: The values to change.
        """
        with self.db.begin() as conn:
            conn.execute(self.t.update().where(self._edit_identity).values(**values))
        # In-memory attributes are only updated after the database transaction has completed.
        if _update_values:
            for key, value in values.items():
                setattr(self, key, value)

    @contextmanager
    def edit_mode(self: T) -> Iterator[ClauseElement]:
        """
        Edit this row in a fancy context manager way. This stores the current edit identity, then
        yields to the context manager and finally puts the new values into the row using the old
        edit identity in the WHERE clause.

        If the ``with`` body raises, the exception propagates out of the yield and no UPDATE is
        executed.

        >>> class TableClass(Base):
        ...     ...
        >>> db_instance = TableClass(id="something")

        >>> with db_instance.edit_mode():
        ...     db_instance.id = "new_id"

        Yields:
            The edit identity (WHERE clause) captured before the body ran.
        """
        # Captured up front so that changes to key columns inside the body still target the
        # original row.
        old_identity = self._edit_identity
        yield old_identity
        with self.db.begin() as conn:
            conn.execute(self.t.update().where(old_identity).values(**self._insert_values))

    def delete(self: T) -> None:
        """Delete this row."""
        with self.db.begin() as conn:
            conn.execute(self.t.delete().where(self._edit_identity))

    @property
    def _insert_values(self: T) -> Dict[str, Any]:
        """Values for inserts. Generally you want all the values in the table.

        Only columns that have a value set on this instance are included."""
        return {
            column_name: self.__dict__[column_name]
            for column_name in self.column_names
            if column_name in self.__dict__
        }

    def insert(self) -> None:
        """Insert this object as a new row, using :attr:`_insert_values`."""
        with self.db.begin() as conn:
            conn.execute(self.t.insert().values(**self._insert_values))

    @property
    def _upsert_values(self: T) -> Dict[str, Any]:
        """The values to set when an upsert-insert conflicts and moves to the update part.

        Only used by the PostgreSQL upsert path."""
        return self._insert_values

    def _upsert_postgres(self: T, conn: Connection) -> None:
        """Upsert using ``INSERT ... ON CONFLICT (primary key) DO UPDATE``."""
        conn.execute(
            pg_insert(self.t)
            .values(**self._insert_values)
            .on_conflict_do_update(constraint=self.t.primary_key, set_=self._upsert_values)
        )

    def _upsert_sqlite(self: T, conn: Connection) -> None:
        """Upsert using ``INSERT OR REPLACE``."""
        # NOTE: this path writes _insert_values only; overrides of _upsert_values have no effect.
        conn.execute(self.t.insert().values(**self._insert_values).prefix_with("OR REPLACE"))

    def _upsert_generic(self: T, conn: Connection) -> None:
        """Upsert by deleting the row matching the edit identity, then inserting it again."""
        # NOTE: like the SQLite path, this writes _insert_values and ignores _upsert_values.
        conn.execute(self.t.delete().where(self._edit_identity))
        conn.execute(self.t.insert().values(**self._insert_values))

    def upsert(self: T) -> None:
        """
        Insert this row, or update it if a row with the same primary key exists.

        The strategy is chosen by ``self.db.dialect.name``: ON CONFLICT DO UPDATE for
        ``postgresql``, INSERT OR REPLACE for ``sqlite``, and delete-then-insert otherwise. All
        statements run inside a single ``self.db.begin()`` block.
        """
        with self.db.begin() as conn:
            if self.db.dialect.name == "postgresql":
                self._upsert_postgres(conn)
            elif self.db.dialect.name == "sqlite":
                self._upsert_sqlite(conn)
            else:
                self._upsert_generic(conn)

    def __iter__(self) -> Iterator[Any]:
        """
        Iterate over this instance's column values in :attr:`column_names` order.

        Raises:
            KeyError: If a column has no value set on this instance.
        """
        for key in self.column_names:
            yield self.__dict__[key]


@as_declarative()
class Base(BaseClass):
    """
    Declarative base for models, combining the :class:`BaseClass` helpers with SQLAlchemy's
    declarative mapping via ``as_declarative``.

    .. deprecated:: 0.15.0
       The :mod:`mautrix.util.async_db` utility is now recommended over SQLAlchemy.
    """