File: operations.py

"""Advanced database operations for SQLAlchemy.

This module provides high-performance database operations that extend beyond basic CRUD
functionality. It implements specialized database operations optimized for bulk data
handling and schema management.

The operations module is designed to work seamlessly with SQLAlchemy Core and ORM,
providing efficient implementations for common database operations patterns.

Features
--------
- Cross-database ON CONFLICT/ON DUPLICATE KEY UPDATE operations
- MERGE statement support for Oracle and PostgreSQL 15+
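
Example
-------
A minimal sketch of a cross-database upsert (the ``users`` table and the column
names here are illustrative assumptions, not part of the library)::

    from sqlalchemy import Column, Integer, MetaData, String, Table

    users = Table(
        "users",
        MetaData(),
        Column("id", Integer, primary_key=True),
        Column("email", String(255), unique=True),
        Column("name", String(255)),
    )
    stmt = OnConflictUpsert.create_upsert(
        table=users,
        values={"email": "a@example.com", "name": "Alice"},
        conflict_columns=["email"],
        dialect_name="postgresql",
    )
    # Executing stmt inserts the row, or updates "name" when a row with the
    # same "email" already exists.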

Security
--------
This module constructs SQL statements using database identifiers (table and column names)
that MUST come from trusted sources only. All identifiers should originate from:

- SQLAlchemy model metadata (e.g., Model.__table__)
- Hardcoded strings in application code
- Validated configuration files

Never pass user input directly as table names, column names, or other SQL identifiers.
Data values are properly parameterized using bindparam() to prevent SQL injection.

Notes
-----
This module is designed to be database-agnostic where possible, with specialized
optimizations for specific database backends where appropriate.

See Also
--------
- :mod:`sqlalchemy.sql.expression` : SQLAlchemy Core expression language
- :mod:`sqlalchemy.orm` : SQLAlchemy ORM functionality
- :mod:`advanced_alchemy.extensions` : Additional database extensions
"""

from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any, cast
from uuid import UUID

from sqlalchemy import Insert, Table, bindparam, literal_column, select, text
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.expression import Executable

if TYPE_CHECKING:  # pragma: no cover - typing only
    from sqlalchemy.sql.compiler import SQLCompiler
    from sqlalchemy.sql.elements import ColumnElement

__all__ = ("MergeStatement", "OnConflictUpsert", "validate_identifier")

# Pattern for valid SQL identifiers (conservative - alphanumeric and underscore only)
_IDENTIFIER_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")


def validate_identifier(name: str, identifier_type: str = "identifier") -> str:
    """Validate a SQL identifier to ensure it's safe for use in SQL statements.

    This function provides validation for SQL identifiers
    (table names, column names, etc.) to ensure they contain only safe characters.
    While the operations in this module should only receive identifiers from
    trusted sources, this validation adds an extra layer of security.

    Note: SQL keywords (like 'select', 'insert', etc.) are allowed as they can
    be properly quoted/escaped by SQLAlchemy when used as identifiers.

    Args:
        name: The identifier to validate
        identifier_type: Type of identifier for error messages (e.g., "column", "table")

    Returns:
        The validated identifier

    Raises:
        ValueError: If the identifier is empty or contains invalid characters

    Examples:
        >>> validate_identifier("user_id")
        'user_id'
        >>> validate_identifier("users_table", "table")
        'users_table'
        >>> validate_identifier("select")  # SQL keywords are allowed
        'select'
        >>> validate_identifier("drop table users; --")  # contains invalid characters
        Traceback (most recent call last):
            ...
        ValueError: Invalid identifier name: 'drop table users; --'. Only alphanumeric characters and underscores are allowed.
    """
    if not name:
        msg = f"Empty {identifier_type} name provided"
        raise ValueError(msg)

    if not _IDENTIFIER_PATTERN.match(name):
        msg = f"Invalid {identifier_type} name: '{name}'. Only alphanumeric characters and underscores are allowed."
        raise ValueError(msg)

    return name


class MergeStatement(Executable, ClauseElement):
    """A MERGE statement for Oracle and PostgreSQL 15+.

    This provides a high-level interface for MERGE operations that
    can handle both matched and unmatched conditions.
    """

    inherit_cache = True

    def __init__(
        self,
        table: Table,
        source: ClauseElement | str,
        on_condition: ClauseElement,
        when_matched_update: dict[str, Any] | None = None,
        when_not_matched_insert: dict[str, Any] | None = None,
    ) -> None:
        """Initialize a MERGE statement.

        Args:
            table: Target table for the merge operation
            source: Source data (can be a subquery or table)
            on_condition: Condition for matching rows
            when_matched_update: Values to update when rows match
            when_not_matched_insert: Values to insert when rows don't match
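
        Example:
            A minimal sketch (the ``users`` table and the source SELECT are
            illustrative assumptions, not fixed API requirements)::

                from sqlalchemy import Column, Integer, MetaData, String, Table, text

                users = Table(
                    "users",
                    MetaData(),
                    Column("id", Integer, primary_key=True),
                    Column("email", String(255)),
                )
                stmt = MergeStatement(
                    table=users,
                    source="SELECT 1 AS id, 'a@example.com' AS email",
                    on_condition=text("tgt.id = src.id"),
                    when_matched_update={"email": text("src.email")},
                    when_not_matched_insert={
                        "id": text("src.id"),
                        "email": text("src.email"),
                    },
                )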
        """
        self.table = table
        self.source = source
        self.on_condition = on_condition
        self.when_matched_update = when_matched_update or {}
        self.when_not_matched_insert = when_not_matched_insert or {}


# PostgreSQL version constant
POSTGRES_MERGE_VERSION = 15


@compiles(MergeStatement)
def compile_merge_default(element: MergeStatement, compiler: SQLCompiler, **kwargs: Any) -> str:
    """Default compilation - raises error for unsupported dialects."""
    _ = element, kwargs  # Unused parameters
    dialect_name = compiler.dialect.name
    msg = f"MERGE statement not supported for dialect '{dialect_name}'"
    raise NotImplementedError(msg)


@compiles(MergeStatement, "oracle")
def compile_merge_oracle(element: MergeStatement, compiler: SQLCompiler, **kwargs: Any) -> str:
    """Compile MERGE statement for Oracle."""
    table_name = element.table.name

    if isinstance(element.source, str):
        source_str = element.source
        if source_str.upper().startswith("SELECT") and "FROM DUAL" not in source_str.upper():
            source_str = f"{source_str} FROM DUAL"
        source_clause = f"({source_str})"
    else:
        compiled_source = compiler.process(element.source, **kwargs)
        source_clause = f"({compiled_source})"

    merge_sql = f"MERGE INTO {table_name} tgt USING {source_clause} src ON ("
    merge_sql += compiler.process(element.on_condition, **kwargs)
    merge_sql += ")"

    if element.when_matched_update:
        merge_sql += " WHEN MATCHED THEN UPDATE SET "
        updates = []
        for column, value in element.when_matched_update.items():
            if hasattr(value, "_compiler_dispatch"):
                compiled_value = compiler.process(value, **kwargs)
            else:
                compiled_value = compiler.process(value, **kwargs)
            updates.append(f"{column} = {compiled_value}")  # pyright: ignore
        merge_sql += ", ".join(updates)  # pyright: ignore

    if element.when_not_matched_insert:
        columns = list(element.when_not_matched_insert.keys())
        values = list(element.when_not_matched_insert.values())

        merge_sql += " WHEN NOT MATCHED THEN INSERT ("
        merge_sql += ", ".join(columns)
        merge_sql += ") VALUES ("

        compiled_values = []
        for value in values:
            if hasattr(value, "_compiler_dispatch"):
                compiled_value = compiler.process(value, **kwargs)
            else:
                compiled_value = compiler.process(value, **kwargs)
            compiled_values.append(compiled_value)  # pyright: ignore
        merge_sql += ", ".join(compiled_values)  # pyright: ignore
        merge_sql += ")"

    return merge_sql


@compiles(MergeStatement, "postgresql")
def compile_merge_postgresql(element: MergeStatement, compiler: SQLCompiler, **kwargs: Any) -> str:
    """Compile MERGE statement for PostgreSQL 15+."""
    dialect = compiler.dialect
    if (
        hasattr(dialect, "server_version_info")
        and dialect.server_version_info
        and dialect.server_version_info[0] < POSTGRES_MERGE_VERSION
    ):
        msg = "MERGE statement requires PostgreSQL 15 or higher"
        raise NotImplementedError(msg)

    table_name = element.table.name

    if isinstance(element.source, str):
        # Wrap raw string source and alias as src
        source_clause = f"({element.source}) AS src"
    else:
        # Ensure the compiled source is parenthesized and has a stable alias 'src'
        compiled_source = compiler.process(element.source, **kwargs)
        compiled_trim = compiled_source.strip()
        if compiled_trim.startswith("("):
            # Already parenthesized; check for alias after closing paren
            has_outer_alias = (
                re.search(r"\)\s+(AS\s+)?[a-zA-Z_][a-zA-Z0-9_]*\s*$", compiled_trim, re.IGNORECASE) is not None
            )
            source_clause = compiled_trim if has_outer_alias else f"{compiled_trim} AS src"
        else:
            # Not parenthesized: wrap and alias
            source_clause = f"({compiled_trim}) AS src"

    merge_sql = f"MERGE INTO {table_name} AS tgt USING {source_clause} ON ("
    merge_sql += compiler.process(element.on_condition, **kwargs)
    merge_sql += ")"

    if element.when_matched_update:
        merge_sql += " WHEN MATCHED THEN UPDATE SET "
        updates = []
        for column, value in element.when_matched_update.items():
            if hasattr(value, "_compiler_dispatch"):
                compiled_value = compiler.process(value, **kwargs)
            else:
                compiled_value = compiler.process(value, **kwargs)
            updates.append(f"{column} = {compiled_value}")  # pyright: ignore
        merge_sql += ", ".join(updates)  # pyright: ignore

    if element.when_not_matched_insert:
        columns = list(element.when_not_matched_insert.keys())
        values = list(element.when_not_matched_insert.values())

        merge_sql += " WHEN NOT MATCHED THEN INSERT ("
        merge_sql += ", ".join(columns)
        merge_sql += ") VALUES ("

        compiled_values = []
        for value in values:
            if hasattr(value, "_compiler_dispatch"):
                compiled_value = compiler.process(value, **kwargs)
            else:
                compiled_value = compiler.process(value, **kwargs)
            compiled_values.append(compiled_value)  # pyright: ignore
        merge_sql += ", ".join(compiled_values)  # pyright: ignore
        merge_sql += ")"

    return merge_sql


class OnConflictUpsert:
    """Cross-database upsert operation using dialect-specific constructs.

    This class provides a unified interface for upsert operations across
    different database backends using their native ON CONFLICT or
    ON DUPLICATE KEY UPDATE mechanisms.
    """

    @staticmethod
    def supports_native_upsert(dialect_name: str) -> bool:
        """Check if the dialect supports native upsert operations.

        Args:
            dialect_name: Name of the database dialect

        Returns:
            True if native upsert is supported, False otherwise
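
        Examples:
            Dialect names are compared against the known set of backends with
            native upsert support:

            >>> OnConflictUpsert.supports_native_upsert("postgresql")
            True
            >>> OnConflictUpsert.supports_native_upsert("oracle")
            False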
        """
        return dialect_name in {"postgresql", "cockroachdb", "sqlite", "mysql", "mariadb", "duckdb"}

    @staticmethod
    def create_upsert(
        table: Table,
        values: dict[str, Any],
        conflict_columns: list[str],
        update_columns: list[str] | None = None,
        dialect_name: str | None = None,
        validate_identifiers: bool = False,
    ) -> Insert:
        """Create a dialect-specific upsert statement.

        Args:
            table: Target table for the upsert
            values: Values to insert/update
            conflict_columns: Columns that define the conflict condition
            update_columns: Columns to update on conflict (defaults to all non-conflict columns)
            dialect_name: Database dialect name (auto-detected if not provided)
            validate_identifiers: If True, validate column names for safety (default: False)

        Returns:
            A SQLAlchemy Insert statement with upsert logic

        Raises:
            NotImplementedError: If the dialect doesn't support native upsert
            ValueError: If validate_identifiers is True and invalid identifiers are found
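
        Example:
            A minimal sketch, reusing the illustrative ``users`` table from the
            module docstring (keyed on a unique ``email`` column)::

                stmt = OnConflictUpsert.create_upsert(
                    table=users,
                    values={"email": "a@example.com", "name": "Alice"},
                    conflict_columns=["email"],
                    dialect_name="postgresql",
                )
                # Executing the statement inserts the row, or updates "name"
                # when a row with the same "email" already exists.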
        """
        if validate_identifiers:
            for col in conflict_columns:
                validate_identifier(col, "conflict column")
            if update_columns:
                for col in update_columns:
                    validate_identifier(col, "update column")
            for col in values:
                validate_identifier(col, "column")

        if update_columns is None:
            update_columns = [col for col in values if col not in conflict_columns]

        if dialect_name in {"postgresql", "sqlite", "duckdb"}:
            from sqlalchemy.dialects.postgresql import insert as pg_insert

            pg_insert_stmt = pg_insert(table).values(values)
            return pg_insert_stmt.on_conflict_do_update(
                index_elements=conflict_columns, set_={col: pg_insert_stmt.excluded[col] for col in update_columns}
            )
        if dialect_name == "cockroachdb":
            from sqlalchemy.dialects.postgresql import insert as pg_insert

            pg_insert_stmt = pg_insert(table).values(values)
            return pg_insert_stmt.on_conflict_do_update(
                index_elements=conflict_columns, set_={col: pg_insert_stmt.excluded[col] for col in update_columns}
            )

        if dialect_name in {"mysql", "mariadb"}:
            from sqlalchemy.dialects.mysql import insert as mysql_insert

            mysql_insert_stmt = mysql_insert(table).values(values)
            return mysql_insert_stmt.on_duplicate_key_update(
                **{col: mysql_insert_stmt.inserted[col] for col in update_columns}
            )

        msg = f"Native upsert not supported for dialect '{dialect_name}'"
        raise NotImplementedError(msg)

    @staticmethod
    def create_merge_upsert(  # noqa: C901, PLR0915
        table: Table,
        values: dict[str, Any],
        conflict_columns: list[str],
        update_columns: list[str] | None = None,
        dialect_name: str | None = None,
        validate_identifiers: bool = False,
    ) -> tuple[MergeStatement, dict[str, Any]]:
        """Create a MERGE-based upsert for Oracle/PostgreSQL 15+.

        For Oracle databases, this method automatically generates values for primary key
        columns that have callable defaults (such as UUID generation functions). This is
        necessary because Oracle MERGE statements cannot use Python callable defaults
        directly in the INSERT clause.

        Args:
            table: Target table for the upsert
            values: Values to insert/update
            conflict_columns: Columns that define the matching condition
            update_columns: Columns to update on match (defaults to all non-conflict columns)
            dialect_name: Database dialect name (used to determine Oracle-specific syntax)
            validate_identifiers: If True, validate column names for safety (default: False)

        Returns:
            A tuple of (MergeStatement, additional_params) where additional_params
            contains any generated values (like Oracle UUID primary keys)

        Raises:
            ValueError: If validate_identifiers is True and invalid identifiers are found
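
        Example:
            A minimal sketch, again with the illustrative ``users`` table from
            the module docstring::

                merge_stmt, extra_params = OnConflictUpsert.create_merge_upsert(
                    table=users,
                    values={"email": "a@example.com", "name": "Alice"},
                    conflict_columns=["email"],
                    dialect_name="postgresql",
                )
                # session.execute(merge_stmt, extra_params) runs the MERGE;
                # extra_params carries any generated values (e.g. Oracle UUID
                # primary keys) that must accompany the statement.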
        """
        if validate_identifiers:
            for col in conflict_columns:
                validate_identifier(col, "conflict column")
            if update_columns:
                for col in update_columns:
                    validate_identifier(col, "update column")
            for col in values:
                validate_identifier(col, "column")

        if update_columns is None:
            update_columns = [col for col in values if col not in conflict_columns]

        additional_params: dict[str, Any] = {}
        source: ClauseElement | str
        insert_columns: list[str]
        when_not_matched_insert: dict[str, Any]

        if dialect_name == "oracle":
            labeled_columns: list[ColumnElement[Any]] = []
            for key, value in values.items():
                column = table.c[key]
                labeled_columns.append(bindparam(key, value=value, type_=column.type).label(key))

            pk_col_with_seq = None
            for pk_column in table.primary_key.columns:
                if pk_column.name in values or pk_column.default is None:
                    continue
                if callable(getattr(pk_column.default, "arg", None)):
                    try:
                        default_value = pk_column.default.arg(None)  # type: ignore[attr-defined]
                        if isinstance(default_value, UUID):
                            default_value = default_value.hex
                        additional_params[pk_column.name] = default_value
                        labeled_columns.append(
                            bindparam(pk_column.name, value=default_value, type_=pk_column.type).label(pk_column.name)
                        )
                    except (TypeError, AttributeError, ValueError):
                        continue
                elif hasattr(pk_column.default, "next_value"):
                    pk_col_with_seq = pk_column

            # Oracle requires a FROM clause, so the source SELECT reads from DUAL
            source_query = select(*labeled_columns).select_from(text("DUAL"))
            source = source_query.subquery("src")
            insert_columns = [label_col.name for label_col in labeled_columns]
            when_not_matched_insert = {col_name: literal_column(f"src.{col_name}") for col_name in insert_columns}
            if pk_col_with_seq is not None:
                insert_columns.append(pk_col_with_seq.name)
                when_not_matched_insert[pk_col_with_seq.name] = cast("Any", pk_col_with_seq.default).next_value()

        elif dialect_name in {"postgresql", "cockroachdb"}:
            labeled_columns = []
            for key, value in values.items():
                column = table.c[key]
                bp = bindparam(f"src_{key}", value=value, type_=column.type)
                labeled_columns.append(bp.label(key))
            source = select(*labeled_columns).subquery("src")
            insert_columns = list(values.keys())
            when_not_matched_insert = {col: literal_column(f"src.{col}") for col in insert_columns}
        else:
            placeholders = ", ".join([f"%({key})s" for key in values])
            col_names = ", ".join(values.keys())
            source = f"(SELECT * FROM (VALUES ({placeholders})) AS src({col_names}))"  # noqa: S608
            insert_columns = list(values.keys())
            when_not_matched_insert = {col: bindparam(col) for col in insert_columns}

        on_conditions = [f"tgt.{col} = src.{col}" for col in conflict_columns]
        on_condition = text(" AND ".join(on_conditions))

        if dialect_name in {"postgresql", "cockroachdb", "oracle"}:
            when_matched_update: dict[str, Any] = {
                col: literal_column(f"src.{col}") for col in update_columns if col in values
            }
        else:
            when_matched_update = {col: bindparam(col) for col in update_columns if col in values}

        # For Oracle, we need to ensure the keys in when_not_matched_insert match the insert_columns
        if dialect_name == "oracle":
            final_insert_mapping = {}
            for col_name in insert_columns:
                if col_name in when_not_matched_insert:
                    final_insert_mapping[col_name] = when_not_matched_insert[col_name]
            when_not_matched_insert = final_insert_mapping

        merge_stmt = MergeStatement(
            table=table,
            source=source,
            on_condition=on_condition,
            when_matched_update=when_matched_update,
            when_not_matched_insert=when_not_matched_insert,
        )

        return merge_stmt, additional_params  # pyright: ignore[reportUnknownVariableType]


# Note: Oracle-specific helper removed; inline logic now handles defaults