1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337
|
"""
Functions for runtime-configuration of triggers, such as ignoring
them or dynamically setting the search path.
"""
from __future__ import annotations
import contextlib
import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, List, Union
from django.db import connections
from pgtrigger import registry, utils
if utils.psycopg_maj_version == 2:
import psycopg2.extensions
import psycopg2.sql as psycopg_sql
elif utils.psycopg_maj_version == 3:
import psycopg.pq
import psycopg.sql as psycopg_sql
else:
raise AssertionError
if TYPE_CHECKING:
from django.db.backends.utils import CursorWrapper
from typing_extensions import TypeAlias
from pgtrigger import Timing
_Query: "TypeAlias" = "str | bytes | psycopg_sql.SQL | psycopg_sql.Composed"
# All triggers currently being ignored
_ignore = threading.local()
# All schemas in the search path
_schema = threading.local()
def _query_to_str(query: _Query, cursor: CursorWrapper) -> str:
if isinstance(query, str):
return query
elif isinstance(query, bytes):
return query.decode()
elif isinstance(query, (psycopg_sql.SQL, psycopg_sql.Composed)):
return query.as_string(cursor.connection)
else: # pragma: no cover
raise TypeError(f"Unsupported query type: {type(query)}")
def _is_concurrent_statement(sql: _Query, cursor: CursorWrapper) -> bool:
"""
True if the sql statement is concurrent and cannot be ran in a transaction
"""
sql = _query_to_str(sql, cursor)
sql = sql.strip().lower() if sql else ""
return sql.startswith("create") and "concurrently" in sql
def _is_transaction_errored(cursor):
"""
True if the current transaction is in an errored state
"""
if utils.psycopg_maj_version == 2:
return (
cursor.connection.get_transaction_status()
== psycopg2.extensions.TRANSACTION_STATUS_INERROR
)
elif utils.psycopg_maj_version == 3:
return cursor.connection.info.transaction_status == psycopg.pq.TransactionStatus.INERROR
else:
raise AssertionError
def _can_inject_variable(cursor, sql):
"""True if we can inject a SQL variable into a statement.
A named cursor automatically prepends
"NO SCROLL CURSOR WITHOUT HOLD FOR" to the query, which
causes invalid SQL to be generated. There is no way
to override this behavior in psycopg, so ignoring triggers
cannot happen for named cursors. Django only names cursors
for iterators and other statements that read the database,
so it seems to be safe to ignore named cursors.
Concurrent index creation is also incompatible with local variable
setting. Ignore these cases for now.
"""
return (
not getattr(cursor, "name", None)
and not _is_concurrent_statement(sql, cursor)
and not _is_transaction_errored(cursor)
)
def _execute_wrapper(execute_result):
if utils.psycopg_maj_version == 3:
while execute_result is not None and execute_result.nextset():
pass
return execute_result
def _inject_pgtrigger_ignore(execute, sql, params, many, context):
"""
A connection execution wrapper that sets a pgtrigger.ignore
variable in the executed SQL. This lets other triggers know when
they should ignore execution
"""
if _can_inject_variable(context["cursor"], sql):
serialized_ignore = "{" + ",".join(_ignore.value) + "}"
sql = _query_to_str(sql, context["cursor"])
sql = f"SELECT set_config('pgtrigger.ignore', %s, true); {sql}"
params = [serialized_ignore, *(params or ())]
return _execute_wrapper(execute(sql, params, many, context))
@contextlib.contextmanager
def _set_ignore_session_state(database=None):
"""Starts a session where triggers can be ignored"""
connection = utils.connection(database)
if _inject_pgtrigger_ignore not in connection.execute_wrappers:
with connection.execute_wrapper(_inject_pgtrigger_ignore):
try:
yield
finally:
if connection.in_atomic_block:
# We've finished ignoring triggers and are in a transaction,
# so flush the local variable.
with connection.cursor() as cursor:
cursor.execute("SELECT set_config('pgtrigger.ignore', NULL, false);")
else:
yield
@contextlib.contextmanager
def _ignore_session(databases=None):
"""Starts a session where triggers can be ignored"""
with contextlib.ExitStack() as stack:
for database in utils.postgres_databases(databases):
stack.enter_context(_set_ignore_session_state(database=database))
yield
@contextlib.contextmanager
def _set_ignore_state(model, trigger):
"""
Manage state to ignore a single URI
"""
if not hasattr(_ignore, "value"):
_ignore.value = set()
pgid = trigger.get_pgid(model)
if pgid not in _ignore.value:
# In order to preserve backwards compatibiliy with older installations
# of the _pgtrigger_ignore func, we must set a full URI (old version)
# and trigger ID (new version).
# This will be removed in version 5
uri = f"{model._meta.db_table}:{pgid}"
try:
_ignore.value.add(uri)
_ignore.value.add(pgid)
yield
finally:
_ignore.value.remove(uri)
_ignore.value.remove(pgid)
else: # The trigger is already being ignored
yield
@contextlib.contextmanager
def ignore(*uris: str, databases: Union[List[str], None] = None) -> Generator[None]:
"""
Dynamically ignore registered triggers matching URIs from executing in
an individual thread.
If no URIs are provided, ignore all pgtriggers from executing in an
individual thread.
Args:
*uris: Trigger URIs to ignore. If none are provided, all
triggers will be ignored.
databases: The databases to use. If none, all postgres databases
will be used.
Example:
Ingore triggers in a context manager:
with pgtrigger.ignore("my_app.Model:trigger_name"):
# Do stuff while ignoring trigger
Example:
Ignore multiple triggers as a decorator:
@pgtrigger.ignore("my_app.Model:trigger_name", "my_app.Model:other_trigger")
def my_func():
# Do stuff while ignoring trigger
"""
with contextlib.ExitStack() as stack:
stack.enter_context(_ignore_session(databases=databases))
for model, trigger in registry.registered(*uris):
stack.enter_context(_set_ignore_state(model, trigger))
yield
ignore.session = _ignore_session
def _inject_schema(execute, sql, params, many, context):
"""
A connection execution wrapper that sets the schema
variable in the executed SQL.
"""
if _can_inject_variable(context["cursor"], sql) and _schema.value:
path = ", ".join(val if not val.startswith("$") else f'"{val}"' for val in _schema.value)
sql = f"SELECT set_config('search_path', %s, true); {sql}"
params = [path, *(params or ())]
return _execute_wrapper(execute(sql, params, many, context))
@contextlib.contextmanager
def _set_schema_session_state(database=None):
connection = utils.connection(database)
if _inject_schema not in connection.execute_wrappers:
if connection.in_atomic_block:
# If this is the first time we are setting the search path,
# register the pre_execute_hook and store a reference to the original
# search path. Note that we must use this approach because we cannot
# simply reset the search_path at the end. A user may have previously
# set it
with connection.cursor() as cursor:
cursor.execute("SELECT current_setting('search_path')")
initial_search_path = cursor.fetchall()[0][0]
with connection.execute_wrapper(_inject_schema):
try:
yield
finally:
if connection.in_atomic_block:
# We've finished modifying the search path and are in a transaction,
# so flush the local variable
with connection.cursor() as cursor:
cursor.execute(
"SELECT set_config('search_path', %s, false)", [initial_search_path]
)
else:
yield
@contextlib.contextmanager
def _schema_session(databases=None):
"""Starts a session where the search path can be modified"""
with contextlib.ExitStack() as stack:
for database in utils.postgres_databases(databases):
stack.enter_context(_set_schema_session_state(database=database))
yield
@contextlib.contextmanager
def _set_schema_state(*schemas):
if not hasattr(_schema, "value"):
# Use a list instead of a set because ordering is important to the search path
_schema.value = []
schemas = [s for s in schemas if s not in _schema.value]
try:
_schema.value.extend(schemas)
yield
finally:
for s in schemas:
_schema.value.remove(s)
@contextlib.contextmanager
def schema(*schemas: str, databases: Union[List[str], None] = None) -> Generator[None]:
"""
Sets the search path to the provided schemas.
If nested, appends the schemas to the search path if not already in it.
Args:
*schemas: Schemas that should be appended to the search path.
Schemas already in the search path from nested calls will not be
appended.
databases: The databases to set the search path. If none, all postgres
databases will be used.
"""
with contextlib.ExitStack() as stack:
stack.enter_context(_schema_session(databases=databases))
stack.enter_context(_set_schema_state(*schemas))
yield
schema.session = _schema_session
def constraints(timing: "Timing", *uris: str, databases: Union[List[str], None] = None) -> None:
"""
Set deferrable constraint timing for the given triggers, which
will persist until overridden or until end of transaction.
Must be in a transaction to run this.
Args:
timing: The timing value that overrides the default trigger timing.
*uris: Trigger URIs over which to set constraint timing.
If none are provided, all trigger constraint timing will
be set. All triggers must be deferrable.
databases: The databases on which to set constraints. If none, all
postgres databases will be used.
Raises:
RuntimeError: If the database of any triggers is not in a transaction.
ValueError: If any triggers are not deferrable.
"""
for model, trigger in registry.registered(*uris):
if not trigger.timing:
raise ValueError(
f"Trigger {trigger.name} on model {model._meta.label_lower} is not deferrable."
)
for database in utils.postgres_databases(databases):
if not connections[database].in_atomic_block:
raise RuntimeError(f'Database "{database}" is not in a transaction.')
names = ", ".join(trigger.get_pgid(model) for model, trigger in registry.registered(*uris))
with connections[database].cursor() as cursor:
cursor.execute(f"SET CONSTRAINTS {names} {timing}")
|