"""Advanced database operations for SQLAlchemy.
This module provides high-performance database operations that extend beyond basic CRUD
functionality. It implements specialized database operations optimized for bulk data
handling and schema management.
The operations module is designed to work seamlessly with SQLAlchemy Core and ORM,
providing efficient implementations for common database operations patterns.
Features
--------
- Cross-database ON CONFLICT/ON DUPLICATE KEY UPDATE operations
- MERGE statement support for Oracle and PostgreSQL 15+
Security
--------
This module constructs SQL statements using database identifiers (table and column names)
that MUST come from trusted sources only. All identifiers should originate from:
- SQLAlchemy model metadata (e.g., Model.__table__)
- Hardcoded strings in application code
- Validated configuration files
Never pass user input directly as table names, column names, or other SQL identifiers.
Data values are properly parameterized using bindparam() to prevent SQL injection.
Notes:
------
This module is designed to be database-agnostic where possible, with specialized
optimizations for specific database backends where appropriate.
See Also:
---------
- :mod:`sqlalchemy.sql.expression` : SQLAlchemy Core expression language
- :mod:`sqlalchemy.orm` : SQLAlchemy ORM functionality
- :mod:`advanced_alchemy.extensions` : Additional database extensions
"""
from __future__ import annotations
import re
from typing import TYPE_CHECKING, Any, cast
from uuid import UUID
from sqlalchemy import Insert, Table, bindparam, literal_column, select, text
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.expression import Executable
if TYPE_CHECKING: # pragma: no cover - typing only
from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.elements import ColumnElement
__all__ = ("MergeStatement", "OnConflictUpsert", "validate_identifier")
# Pattern for valid SQL identifiers (conservative - ASCII letters, digits, and
# underscores only; must start with a letter or underscore). Anchored at both
# ends so the entire name must match; used by validate_identifier().
_IDENTIFIER_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
def validate_identifier(name: str, identifier_type: str = "identifier") -> str:
    """Validate a SQL identifier to ensure it's safe for use in SQL statements.

    Identifiers (table names, column names, etc.) must be non-empty and consist
    solely of ASCII letters, digits, and underscores, starting with a letter or
    an underscore. SQL keywords (like ``select`` or ``insert``) are accepted
    because SQLAlchemy can properly quote/escape them when they are used as
    identifiers. While callers in this module should only pass identifiers from
    trusted sources, this check adds an extra layer of defense.

    Args:
        name: The identifier to validate
        identifier_type: Type of identifier for error messages (e.g., "column", "table")

    Returns:
        The validated identifier, unchanged

    Raises:
        ValueError: If the identifier is empty or contains invalid characters

    Examples:
        >>> validate_identifier("user_id")
        'user_id'
        >>> validate_identifier("users_table", "table")
        'users_table'
        >>> validate_identifier("select")  # SQL keywords are allowed
        'select'
    """
    if not name:
        msg = f"Empty {identifier_type} name provided"
        raise ValueError(msg)
    # fullmatch requires the whole string to satisfy the identifier grammar.
    if re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", name) is None:
        msg = f"Invalid {identifier_type} name: '{name}'. Only alphanumeric characters and underscores are allowed."
        raise ValueError(msg)
    return name
class MergeStatement(Executable, ClauseElement):
    """A MERGE statement for Oracle and PostgreSQL 15+.

    Provides a high-level interface for MERGE operations covering both the
    WHEN MATCHED (update) and WHEN NOT MATCHED (insert) branches. Actual SQL
    generation is handled by dialect-specific ``@compiles`` hooks.
    """

    # Participate in SQLAlchemy's compiled-statement cache.
    inherit_cache = True

    def __init__(
        self,
        table: Table,
        source: ClauseElement | str,
        on_condition: ClauseElement,
        when_matched_update: dict[str, Any] | None = None,
        when_not_matched_insert: dict[str, Any] | None = None,
    ) -> None:
        """Initialize a MERGE statement.

        Args:
            table: Target table for the merge operation
            source: Source data (a subquery/table clause, or a raw SQL string)
            on_condition: Condition used to match target rows against the source
            when_matched_update: Column-to-value mapping applied when rows match
            when_not_matched_insert: Column-to-value mapping inserted when rows don't match
        """
        self.table = table
        self.source = source
        self.on_condition = on_condition
        # Normalize falsy mappings (None/empty) to {} so compilers can test truthiness.
        self.when_matched_update = when_matched_update if when_matched_update else {}
        self.when_not_matched_insert = when_not_matched_insert if when_not_matched_insert else {}
# Minimum PostgreSQL major version that supports the native MERGE statement.
POSTGRES_MERGE_VERSION = 15
@compiles(MergeStatement)
def compile_merge_default(element: MergeStatement, compiler: SQLCompiler, **kwargs: Any) -> str:
    """Default compilation hook - raises an error for unsupported dialects."""
    _ = element, kwargs  # Unused parameters
    msg = f"MERGE statement not supported for dialect '{compiler.dialect.name}'"
    raise NotImplementedError(msg)
@compiles(MergeStatement, "oracle")
def compile_merge_oracle(element: MergeStatement, compiler: SQLCompiler, **kwargs: Any) -> str:
    """Compile a :class:`MergeStatement` into Oracle MERGE syntax.

    The target table is aliased ``tgt`` and the source ``src``. Raw string
    sources that are table-less SELECTs get ``FROM DUAL`` appended, which
    Oracle requires.

    Args:
        element: The MergeStatement to compile
        compiler: The active SQL compiler
        **kwargs: Passed through to nested ``compiler.process`` calls

    Returns:
        The compiled MERGE SQL string
    """
    table_name = element.table.name
    if isinstance(element.source, str):
        source_str = element.source
        # Oracle requires FROM DUAL for SELECT statements without a table.
        if source_str.upper().startswith("SELECT") and "FROM DUAL" not in source_str.upper():
            source_str = f"{source_str} FROM DUAL"
        source_clause = f"({source_str})"
    else:
        source_clause = f"({compiler.process(element.source, **kwargs)})"
    parts = [f"MERGE INTO {table_name} tgt USING {source_clause} src ON ("]
    parts.append(compiler.process(element.on_condition, **kwargs))
    parts.append(")")
    if element.when_matched_update:
        # All values (ClauseElements and bindparams alike) go through
        # compiler.process; the previous hasattr(_compiler_dispatch) branch
        # was dead code with identical if/else bodies.
        updates = ", ".join(
            f"{column} = {compiler.process(value, **kwargs)}"
            for column, value in element.when_matched_update.items()
        )
        parts.append(f" WHEN MATCHED THEN UPDATE SET {updates}")
    if element.when_not_matched_insert:
        columns = ", ".join(element.when_not_matched_insert)
        compiled_values = ", ".join(
            compiler.process(value, **kwargs) for value in element.when_not_matched_insert.values()
        )
        parts.append(f" WHEN NOT MATCHED THEN INSERT ({columns}) VALUES ({compiled_values})")
    return "".join(parts)
@compiles(MergeStatement, "postgresql")
def compile_merge_postgresql(element: MergeStatement, compiler: SQLCompiler, **kwargs: Any) -> str:
    """Compile a :class:`MergeStatement` into PostgreSQL 15+ MERGE syntax.

    The target table is aliased ``tgt``; the source is parenthesized and given
    a stable ``src`` alias unless the compiled source already carries one.

    Args:
        element: The MergeStatement to compile
        compiler: The active SQL compiler
        **kwargs: Passed through to nested ``compiler.process`` calls

    Returns:
        The compiled MERGE SQL string

    Raises:
        NotImplementedError: If the connected server predates PostgreSQL 15
    """
    version_info = getattr(compiler.dialect, "server_version_info", None)
    if version_info and version_info[0] < POSTGRES_MERGE_VERSION:
        msg = "MERGE statement requires PostgreSQL 15 or higher"
        raise NotImplementedError(msg)
    table_name = element.table.name
    if isinstance(element.source, str):
        # Wrap raw string source and alias as src
        source_clause = f"({element.source}) AS src"
    else:
        compiled_trim = compiler.process(element.source, **kwargs).strip()
        if compiled_trim.startswith("("):
            # Already parenthesized; only append an alias if none follows the
            # closing paren.
            has_outer_alias = (
                re.search(r"\)\s+(AS\s+)?[a-zA-Z_][a-zA-Z0-9_]*\s*$", compiled_trim, re.IGNORECASE) is not None
            )
            source_clause = compiled_trim if has_outer_alias else f"{compiled_trim} AS src"
        else:
            # Not parenthesized: wrap and alias
            source_clause = f"({compiled_trim}) AS src"
    parts = [f"MERGE INTO {table_name} AS tgt USING {source_clause} ON ("]
    parts.append(compiler.process(element.on_condition, **kwargs))
    parts.append(")")
    if element.when_matched_update:
        # All values go through compiler.process; the previous
        # hasattr(_compiler_dispatch) branch was dead code with identical
        # if/else bodies.
        updates = ", ".join(
            f"{column} = {compiler.process(value, **kwargs)}"
            for column, value in element.when_matched_update.items()
        )
        parts.append(f" WHEN MATCHED THEN UPDATE SET {updates}")
    if element.when_not_matched_insert:
        columns = ", ".join(element.when_not_matched_insert)
        compiled_values = ", ".join(
            compiler.process(value, **kwargs) for value in element.when_not_matched_insert.values()
        )
        parts.append(f" WHEN NOT MATCHED THEN INSERT ({columns}) VALUES ({compiled_values})")
    return "".join(parts)
class OnConflictUpsert:
    """Cross-database upsert operation using dialect-specific constructs.

    This class provides a unified interface for upsert operations across
    different database backends using their native ON CONFLICT or
    ON DUPLICATE KEY UPDATE mechanisms, plus a MERGE-based fallback for
    Oracle and PostgreSQL 15+.
    """

    @staticmethod
    def supports_native_upsert(dialect_name: str) -> bool:
        """Check if the dialect supports native upsert operations.

        Args:
            dialect_name: Name of the database dialect

        Returns:
            True if native upsert is supported, False otherwise
        """
        return dialect_name in {"postgresql", "cockroachdb", "sqlite", "mysql", "mariadb", "duckdb"}

    @staticmethod
    def create_upsert(
        table: Table,
        values: dict[str, Any],
        conflict_columns: list[str],
        update_columns: list[str] | None = None,
        dialect_name: str | None = None,
        validate_identifiers: bool = False,
    ) -> Insert:
        """Create a dialect-specific upsert statement.

        Args:
            table: Target table for the upsert
            values: Values to insert/update
            conflict_columns: Columns that define the conflict condition
            update_columns: Columns to update on conflict (defaults to all non-conflict columns)
            dialect_name: Database dialect name; must be one of the dialects accepted by
                :meth:`supports_native_upsert` (no auto-detection is performed here)
            validate_identifiers: If True, validate column names for safety (default: False)

        Returns:
            A SQLAlchemy Insert statement with upsert logic

        Raises:
            NotImplementedError: If the dialect doesn't support native upsert
            ValueError: If validate_identifiers is True and invalid identifiers are found
        """
        if validate_identifiers:
            for col in conflict_columns:
                validate_identifier(col, "conflict column")
            if update_columns:
                for col in update_columns:
                    validate_identifier(col, "update column")
            for col in values:
                validate_identifier(col, "column")
        if update_columns is None:
            update_columns = [col for col in values if col not in conflict_columns]
        # CockroachDB speaks the PostgreSQL wire protocol, so the same
        # ON CONFLICT construct covers all four dialects (previously this was
        # duplicated in a separate, identical cockroachdb branch).
        if dialect_name in {"postgresql", "cockroachdb", "sqlite", "duckdb"}:
            from sqlalchemy.dialects.postgresql import insert as pg_insert

            pg_insert_stmt = pg_insert(table).values(values)
            return pg_insert_stmt.on_conflict_do_update(
                index_elements=conflict_columns,
                set_={col: pg_insert_stmt.excluded[col] for col in update_columns},
            )
        if dialect_name in {"mysql", "mariadb"}:
            from sqlalchemy.dialects.mysql import insert as mysql_insert

            mysql_insert_stmt = mysql_insert(table).values(values)
            return mysql_insert_stmt.on_duplicate_key_update(
                **{col: mysql_insert_stmt.inserted[col] for col in update_columns}
            )
        msg = f"Native upsert not supported for dialect '{dialect_name}'"
        raise NotImplementedError(msg)

    @staticmethod
    def create_merge_upsert(  # noqa: C901, PLR0915
        table: Table,
        values: dict[str, Any],
        conflict_columns: list[str],
        update_columns: list[str] | None = None,
        dialect_name: str | None = None,
        validate_identifiers: bool = False,
    ) -> tuple[MergeStatement, dict[str, Any]]:
        """Create a MERGE-based upsert for Oracle/PostgreSQL 15+.

        For Oracle databases, this method automatically generates values for primary key
        columns that have callable defaults (such as UUID generation functions). This is
        necessary because Oracle MERGE statements cannot use Python callable defaults
        directly in the INSERT clause.

        Args:
            table: Target table for the upsert
            values: Values to insert/update
            conflict_columns: Columns that define the matching condition
            update_columns: Columns to update on match (defaults to all non-conflict columns)
            dialect_name: Database dialect name (used to determine Oracle-specific syntax)
            validate_identifiers: If True, validate column names for safety (default: False)

        Returns:
            A tuple of (MergeStatement, additional_params) where additional_params
            contains any generated values (like Oracle UUID primary keys)

        Raises:
            ValueError: If validate_identifiers is True and invalid identifiers are found
        """
        if validate_identifiers:
            for col in conflict_columns:
                validate_identifier(col, "conflict column")
            if update_columns:
                for col in update_columns:
                    validate_identifier(col, "update column")
            for col in values:
                validate_identifier(col, "column")
        if update_columns is None:
            update_columns = [col for col in values if col not in conflict_columns]
        additional_params: dict[str, Any] = {}
        source: ClauseElement | str
        insert_columns: list[str]
        when_not_matched_insert: dict[str, Any]
        if dialect_name == "oracle":
            # Bind each value as a labeled column so the source subquery
            # exposes stable column names for the MERGE clauses.
            labeled_columns: list[ColumnElement[Any]] = []
            for key, value in values.items():
                column = table.c[key]
                labeled_columns.append(bindparam(key, value=value, type_=column.type).label(key))
            pk_col_with_seq = None
            for pk_column in table.primary_key.columns:
                # Only synthesize values for PK columns the caller omitted.
                if pk_column.name in values or pk_column.default is None:
                    continue
                if callable(getattr(pk_column.default, "arg", None)):
                    try:
                        default_value = pk_column.default.arg(None)  # type: ignore[attr-defined]
                        if isinstance(default_value, UUID):
                            # Oracle stores UUIDs as hex strings.
                            default_value = default_value.hex
                        additional_params[pk_column.name] = default_value
                        labeled_columns.append(
                            bindparam(pk_column.name, value=default_value, type_=pk_column.type).label(pk_column.name)
                        )
                    except (TypeError, AttributeError, ValueError):
                        # Best-effort: skip defaults we cannot invoke here.
                        continue
                elif hasattr(pk_column.default, "next_value"):
                    # Sequence-backed PK: emit NEXTVAL in the INSERT clause instead.
                    pk_col_with_seq = pk_column
            # Oracle requires FROM DUAL for SELECT statements without tables
            source_query = select(*labeled_columns)
            source_query = source_query.select_from(text("DUAL"))
            source = source_query.subquery("src")
            insert_columns = [label_col.name for label_col in labeled_columns]
            when_not_matched_insert = {col_name: literal_column(f"src.{col_name}") for col_name in insert_columns}
            if pk_col_with_seq is not None:
                insert_columns.append(pk_col_with_seq.name)
                when_not_matched_insert[pk_col_with_seq.name] = cast("Any", pk_col_with_seq.default).next_value()
        elif dialect_name in {"postgresql", "cockroachdb"}:
            # Prefix bind names with src_ to avoid colliding with target columns.
            labeled_columns = []
            for key, value in values.items():
                column = table.c[key]
                bp = bindparam(f"src_{key}", value=value, type_=column.type)
                labeled_columns.append(bp.label(key))
            source = select(*labeled_columns).subquery("src")
            insert_columns = list(values.keys())
            when_not_matched_insert = {col: literal_column(f"src.{col}") for col in insert_columns}
        else:
            # Generic fallback: a raw VALUES-based source with bind placeholders.
            placeholders = ", ".join([f"%({key})s" for key in values])
            col_names = ", ".join(values.keys())
            source = f"(SELECT * FROM (VALUES ({placeholders})) AS src({col_names}))"  # noqa: S608
            insert_columns = list(values.keys())
            when_not_matched_insert = {col: bindparam(col) for col in insert_columns}
        on_conditions = [f"tgt.{col} = src.{col}" for col in conflict_columns]
        on_condition = text(" AND ".join(on_conditions))
        if dialect_name in {"postgresql", "cockroachdb", "oracle"}:
            # These dialects read update values from the aliased source.
            when_matched_update: dict[str, Any] = {
                col: literal_column(f"src.{col}") for col in update_columns if col in values
            }
        else:
            when_matched_update = {col: bindparam(col) for col in update_columns if col in values}
        # For Oracle, ensure the keys in when_not_matched_insert follow the
        # insert_columns order (and drop any stragglers).
        if dialect_name == "oracle":
            when_not_matched_insert = {
                col_name: when_not_matched_insert[col_name]
                for col_name in insert_columns
                if col_name in when_not_matched_insert
            }
        merge_stmt = MergeStatement(
            table=table,
            source=source,
            on_condition=on_condition,
            when_matched_update=when_matched_update,
            when_not_matched_insert=when_not_matched_insert,
        )
        return merge_stmt, additional_params  # pyright: ignore[reportUnknownVariableType]
# Note: Oracle-specific helper removed; inline logic now handles defaults