1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
|
from contextlib import contextmanager
from typing import Generator, List, Optional, Union
from django.core.exceptions import SuspiciousOperation
from django.db import DEFAULT_DB_ALIAS, connections
@contextmanager
def postgres_set_local(
    *,
    using: str = DEFAULT_DB_ALIAS,
    **options: Optional[Union[str, int, float, List[str]]],
) -> Generator[None, None, None]:
    """Sets the specified PostgreSQL options using SET LOCAL so that they apply
    to the current transaction only.

    The effect is undone when the context manager exits normally.

    See https://www.postgresql.org/docs/current/runtime-config-client.html
    for an overview of all available options.

    Arguments:
        using:
            Alias of the database connection to run the statements on.

        options:
            Mapping of setting name to value. ``None`` emits
            ``SET LOCAL <name> TO DEFAULT``; a list/tuple is bound as a
            comma separated list of values; anything else is bound as a
            single value.

    Raises:
        SuspiciousOperation:
            When the connection is not inside an atomic block.
    """
    connection = connections[using]
    qn = connection.ops.quote_name

    if not connection.in_atomic_block:
        raise SuspiciousOperation(
            "SET LOCAL makes no sense outside a transaction. Start a transaction first."
        )

    # Nothing to change; executing an empty statement would be a driver error.
    if not options:
        yield
        return

    sql = []
    params: List[Union[str, int, float, List[str]]] = []
    for name, value in options.items():
        if value is None:
            sql.append(f"SET LOCAL {qn(name)} TO DEFAULT")
            continue

        # Settings that accept a list of values are actually
        # stored as string lists. We cannot just pass a list
        # of values. We have to create the comma separated
        # list of placeholders ourselves.
        if isinstance(value, (list, tuple)):
            placeholder = ", ".join(["%s" for _ in value])
            params.extend(value)
        else:
            placeholder = "%s"
            params.append(value)

        sql.append(f"SET LOCAL {qn(name)} = {placeholder}")

    with connection.cursor() as cursor:
        # Capture the currently effective values first so they can be
        # restored exactly afterwards.
        cursor.execute(
            "SELECT name, setting FROM pg_settings WHERE name = ANY(%s)",
            (list(options.keys()),),
        )
        original_values = dict(cursor.fetchall())
        cursor.execute("; ".join(sql), params)

    # NOTE(review): restoration below only runs when the body exits normally.
    # If the body raises, the SET LOCAL values stay in effect until the
    # enclosing transaction/savepoint ends (or is rolled back).
    yield

    # Put everything back to how it was. DEFAULT is
    # not good enough as an outer SET LOCAL might
    # have set a different value.
    restore_sql = []
    restore_params: List[str] = []
    for name in options:
        original_value = original_values.get(name)
        if original_value is None:
            # Not visible in pg_settings; fall back to the default.
            restore_sql.append(f"SET LOCAL {qn(name)} TO DEFAULT")
        else:
            # set_config(..., true) is equivalent to SET LOCAL but takes the
            # value as a bound string parsed like a config-file entry. This
            # avoids splicing raw setting text (which may contain spaces,
            # quotes or '%') into the SQL.
            restore_sql.append("SELECT set_config(%s, %s, true)")
            restore_params.extend([name, original_value])

    with connection.cursor() as cursor:
        cursor.execute("; ".join(restore_sql), restore_params)
@contextmanager
def postgres_set_local_search_path(
    search_path: List[str], *, using: str = DEFAULT_DB_ALIAS
) -> Generator[None, None, None]:
    """Sets the search path to the specified schemas.

    Delegates to ``postgres_set_local`` with the ``search_path`` setting
    for the duration of the ``with`` block.
    """
    settings = {"search_path": search_path}
    with postgres_set_local(using=using, **settings):
        yield
@contextmanager
def postgres_prepend_local_search_path(
    search_path: List[str], *, using: str = DEFAULT_DB_ALIAS
) -> Generator[None, None, None]:
    """Prepends the current local search path with the specified schemas.

    The search path that was in effect on entry is put back with
    SET LOCAL when the ``with`` block exits normally.

    Arguments:
        search_path:
            Schemas to put in front of the current search path, in order.

        using:
            Alias of the database connection to run the statements on.

    Raises:
        SuspiciousOperation:
            When the connection is not inside an atomic block (SET LOCAL
            has no effect outside a transaction).
    """
    connection = connections[using]

    if not connection.in_atomic_block:
        raise SuspiciousOperation(
            "SET LOCAL makes no sense outside a transaction. Start a transaction first."
        )

    # Prepending nothing is a no-op; building the statement with an empty
    # list would produce invalid SQL ("SET LOCAL search_path = , ...").
    if not search_path:
        yield
        return

    with connection.cursor() as cursor:
        cursor.execute("SHOW search_path")
        (original_search_path,) = cursor.fetchone()

        placeholders = ", ".join(["%s" for _ in search_path])
        # The SHOW output is spliced into a parameterized statement, so any
        # literal '%' in it must be escaped for the driver's interpolation.
        escaped_original = original_search_path.replace("%", "%%")
        cursor.execute(
            f"SET LOCAL search_path = {placeholders}, {escaped_original}",
            tuple(search_path),
        )

    # NOTE(review): restoration only runs when the body exits normally.
    yield

    with connection.cursor() as cursor:
        # Executed without params, so no '%' escaping is needed here.
        cursor.execute(f"SET LOCAL search_path = {original_search_path}")
@contextmanager
def postgres_reset_local_search_path(
    *, using: str = DEFAULT_DB_ALIAS
) -> Generator[None, None, None]:
    """Resets the local search path to the default.

    Passes ``None`` for ``search_path`` to ``postgres_set_local``, which
    emits ``SET LOCAL search_path TO DEFAULT`` for the ``with`` block.
    """
    settings = {"search_path": None}
    with postgres_set_local(using=using, **settings):
        yield
|