1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491
|
import inspect
import os
import sys
from collections.abc import Iterable
from typing import TYPE_CHECKING, Tuple, Union, cast
import django
from django.conf import settings
from django.core.exceptions import SuspiciousOperation
from django.db.models import Expression, Model, Q
from django.db.models.fields.related import RelatedField
from django.db.models.sql import compiler as django_compiler
from .expressions import HStoreValue
from .types import ConflictAction
if TYPE_CHECKING:
from .sql import PostgresInsertQuery
def append_caller_to_sql(sql):
    """Append the caller to SQL queries.

    Adds the calling file and function as an SQL comment to each query.

    Examples:
        INSERT INTO "tests_47ee19d1" ("id", "title")
        VALUES (1, 'Test')
        RETURNING "tests_47ee19d1"."id"
        /* 998020 test_append_caller_to_sql_crud .../django-postgres-extra/tests/test_append_caller_to_sql.py 55 */

        SELECT "tests_47ee19d1"."id", "tests_47ee19d1"."title"
        FROM "tests_47ee19d1"
        WHERE "tests_47ee19d1"."id" = 1
        LIMIT 1
        /* 998020 test_append_caller_to_sql_crud .../django-postgres-extra/tests/test_append_caller_to_sql.py 69 */

        UPDATE "tests_47ee19d1"
        SET "title" = 'success'
        WHERE "tests_47ee19d1"."id" = 1
        /* 998020 test_append_caller_to_sql_crud .../django-postgres-extra/tests/test_append_caller_to_sql.py 64 */

        DELETE FROM "tests_47ee19d1"
        WHERE "tests_47ee19d1"."id" IN (1)
        /* 998020 test_append_caller_to_sql_crud .../django-postgres-extra/tests/test_append_caller_to_sql.py 74 */

    Slow and blocking queries could be easily tracked down to their originator
    within the source code using the "pg_stat_activity" table.

    Set "POSTGRES_EXTRA_ANNOTATE_SQL" to a truthy value in the Django
    settings to enable this feature.

    Arguments:
        sql:
            The SQL statement to annotate.

    Returns:
        The statement followed by a comment of the form
        `/* <pid> <function> <file> <line> */`, or `/* <pid> <argv[0]> */`
        when every frame on the stack belongs to Django or psqlextra.
        The statement is returned unchanged when it is empty, when the
        feature is disabled or when inspecting the stack fails.
    """
    # An empty statement must stay empty: Django's compilers use it to signal
    # that there is nothing to execute, and appending a comment would turn it
    # into a truthy, comment-only statement.
    if not sql:
        return sql

    if not getattr(settings, "POSTGRES_EXTRA_ANNOTATE_SQL", None):
        return sql

    try:
        # Search for the first non-Django caller. No source context is
        # requested (context=0) because only the file name, line number and
        # function name are used; this avoids reading source files from disk.
        for frame_info in inspect.stack(0)[1:]:
            # Normalise Windows path separators so the markers below also
            # match paths such as "...\\site-packages\\django\\...".
            normalized_filename = frame_info.filename.replace("\\", "/")
            if (
                "/django/" in normalized_filename
                or "/psqlextra/" in normalized_filename
            ):
                continue

            return (
                f"{sql} /* {os.getpid()} {frame_info.function} "
                f"{frame_info.filename} {frame_info.lineno} */"
            )

        # Django internal commands (like migrations) end up here
        return f"{sql} /* {os.getpid()} {sys.argv[0]} */"
    except Exception:
        # Don't break anything because this convenience function runs into
        # an unexpected situation
        return sql
class SQLCompiler(django_compiler.SQLCompiler):  # type: ignore [attr-defined]
    """Django's base SQL compiler with the caller annotation applied."""

    def as_sql(self, *args, **kwargs):
        """Builds the statement with Django's compiler and passes the SQL
        through :see:append_caller_to_sql.

        Returns:
            A tuple of the (possibly annotated) SQL and the untouched
            parameters.
        """
        statement, parameters = super().as_sql(*args, **kwargs)
        annotated = append_caller_to_sql(statement)
        return annotated, parameters
class SQLDeleteCompiler(django_compiler.SQLDeleteCompiler):  # type: ignore [name-defined]
    """Django's DELETE compiler with the caller annotation applied."""

    def as_sql(self, *args, **kwargs):
        """Builds the statement with Django's compiler and passes the SQL
        through :see:append_caller_to_sql.

        Returns:
            A tuple of the (possibly annotated) SQL and the untouched
            parameters.
        """
        statement, parameters = super().as_sql(*args, **kwargs)
        annotated = append_caller_to_sql(statement)
        return annotated, parameters
class SQLAggregateCompiler(django_compiler.SQLAggregateCompiler):  # type: ignore [name-defined]
    """Django's aggregate compiler with the caller annotation applied."""

    def as_sql(self, *args, **kwargs):
        """Builds the statement with Django's compiler and passes the SQL
        through :see:append_caller_to_sql.

        Returns:
            A tuple of the (possibly annotated) SQL and the untouched
            parameters.
        """
        statement, parameters = super().as_sql(*args, **kwargs)
        annotated = append_caller_to_sql(statement)
        return annotated, parameters
class SQLUpdateCompiler(django_compiler.SQLUpdateCompiler):  # type: ignore [name-defined]
    """Compiler for SQL UPDATE statements that allows us to use expressions
    inside HStore values.

    Like:

        .update(name=dict(en=F('test')))
    """

    def as_sql(self, *args, **kwargs):
        """Builds the SQL UPDATE statement.

        Dictionary values are converted first, then Django's compiler
        builds the statement and the SQL is passed through
        :see:append_caller_to_sql.
        """
        # Must run before the parent compiles the values.
        self._prepare_query_values()

        statement, parameters = super().as_sql(*args, **kwargs)
        return append_caller_to_sql(statement), parameters

    def _prepare_query_values(self):
        """Replaces dictionary values that hold expressions with
        :see:HStoreValue expressions.

        The :see:HStoreValue takes care of resolving the expressions
        inside the dictionary. Values that are not dictionaries, and
        dictionaries without any expression, are kept as they are.
        """
        current_values = self.query.values
        if not current_values:
            return

        prepared_values = []
        for field, model, value in current_values:
            needs_wrapping = isinstance(
                value, dict
            ) and self._does_dict_contain_expression(value)

            prepared_values.append(
                (field, model, HStoreValue(dict(value)) if needs_wrapping else value)
            )

        self.query.values = prepared_values

    @staticmethod
    def _does_dict_contain_expression(data: dict) -> bool:
        """Gets whether any value in the specified dictionary looks like an
        expression, i.e. has a `resolve_expression` or `as_sql` attribute."""
        return any(
            hasattr(item, "resolve_expression") or hasattr(item, "as_sql")
            for item in data.values()
        )
class SQLInsertCompiler(django_compiler.SQLInsertCompiler):  # type: ignore [name-defined]
    """Compiler for SQL INSERT statements."""

    def as_sql(self, *args, **kwargs):
        """Builds the SQL INSERT statement(s).

        Returns:
            A list of (sql, params) tuples, one per statement produced by
            Django's compiler, with each SQL string passed through
            :see:append_caller_to_sql.
        """
        annotated_queries = []
        for statement, parameters in super().as_sql(*args, **kwargs):
            annotated_queries.append(
                (append_caller_to_sql(statement), parameters)
            )

        return annotated_queries
class PostgresInsertOnConflictCompiler(django_compiler.SQLInsertCompiler):  # type: ignore [name-defined]
    """Compiler for SQL INSERT statements that adds an ON CONFLICT clause
    and a RETURNING clause to the statements built by Django."""

    query: "PostgresInsertQuery"

    def __init__(self, *args, **kwargs):
        """Initializes a new instance of
        :see:PostgresInsertOnConflictCompiler."""
        super().__init__(*args, **kwargs)

        # Shortcut for quoting identifiers through the connection.
        self.qn = self.connection.ops.quote_name

    def as_sql(self, return_id=False, *args, **kwargs):
        """Builds the SQL INSERT statement.

        Arguments:
            return_id:
                Whether to only return the ID or all
                columns. It is consumed here and not
                forwarded to Django's compiler.

        Returns:
            A list of (sql, params) tuples, one per statement
            produced by Django's compiler, each rewritten by
            :see:_rewrite_insert.
        """
        queries = [
            self._rewrite_insert(sql, params, return_id)
            for sql, params in super().as_sql(*args, **kwargs)
        ]

        return queries

    def _rewrite_insert(self, sql, params, return_id=False):
        """Rewrites a formed SQL INSERT query to include the ON CONFLICT
        clause.

        Arguments:
            sql:
                The SQL INSERT query to rewrite.

            params:
                The parameters passed to the query.

            return_id:
                Whether to only return the ID or all
                columns.

        Returns:
            A tuple of the rewritten SQL query and new params.
        """
        returning = (
            self.qn(self.query.model._meta.pk.attname) if return_id else "*"
        )

        (sql, params) = self._rewrite_insert_on_conflict(
            sql, params, self.query.conflict_action.value, returning
        )

        return append_caller_to_sql(sql), params

    def _rewrite_insert_on_conflict(
        self,
        sql,
        params,
        conflict_action: Union[ConflictAction, str],
        returning,
    ):
        """Rewrites a normal SQL INSERT query to add the 'ON CONFLICT'
        clause.

        Arguments:
            sql:
                The SQL INSERT query to rewrite.

            params:
                The parameters passed to the query.

            conflict_action:
                The action to take on a conflict, either as a
                :see:ConflictAction member or as its value.

            returning:
                The SQL to put after RETURNING.

        Returns:
            A tuple of the rewritten SQL query and new params.
        """
        # The rest of this method compares and formats the plain value, so
        # normalise an enum member to its value up front.
        if isinstance(conflict_action, ConflictAction):
            conflict_action = conflict_action.value

        # build the conflict target, the columns to watch
        # for conflicts
        on_conflict_clause = self._build_on_conflict_clause()
        index_predicate = self.query.index_predicate  # type: ignore[attr-defined]
        update_condition = self.query.conflict_update_condition  # type: ignore[attr-defined]

        rewritten_sql = f"{sql} {on_conflict_clause}"

        if index_predicate:
            expr_sql, expr_params = self._compile_expression(index_predicate)
            rewritten_sql += f" WHERE {expr_sql}"
            params += tuple(expr_params)

        # Fallback in case the user didn't specify any update values. We can still
        # make the query work if we switch to ConflictAction.NOTHING
        if (
            conflict_action == ConflictAction.UPDATE.value
            and not self.query.update_values
        ):
            # Use the value, not the enum member, so the SQL below does not
            # depend on how the enum formats itself.
            conflict_action = ConflictAction.NOTHING.value

        rewritten_sql += f" DO {conflict_action}"

        if conflict_action == ConflictAction.UPDATE.value:
            set_sql, sql_params = self._build_set_statement()
            rewritten_sql += f" SET {set_sql}"
            params += sql_params

            if update_condition:
                expr_sql, expr_params = self._compile_expression(
                    update_condition
                )
                rewritten_sql += f" WHERE {expr_sql}"
                params += tuple(expr_params)

        rewritten_sql += f" RETURNING {returning}"

        return (rewritten_sql, params)

    def _build_set_statement(self) -> Tuple[str, tuple]:
        """Builds the SET statement for the ON CONFLICT DO UPDATE clause.

        This uses the update compiler to provide full compatibility with
        the standard Django's `update(...)`.

        Returns:
            A tuple of the SQL found between the first "SET" and the
            first " WHERE" of the compiled UPDATE statement, and the
            statement's parameters.
        """
        # Local import to work around the circular dependency between
        # the compiler and the queries.
        from .sql import PostgresUpdateQuery

        query = cast(PostgresUpdateQuery, self.query.chain(PostgresUpdateQuery))
        query.add_update_values(self.query.update_values)

        sql, params = query.get_compiler(self.connection.alias).as_sql()

        # Split on the first "SET" only, so that a later occurrence (for
        # example a column named "OFFSET") does not truncate the clause.
        set_clause = sql.split("SET", 1)[1].split(" WHERE")[0]
        return set_clause, tuple(params)

    def _build_on_conflict_clause(self):
        """Builds the `ON CONFLICT` clause, either `ON CONSTRAINT <name>`
        when the conflict target is a constraint or index object, or a
        column based conflict target otherwise."""
        if django.VERSION >= (2, 2):
            from django.db.models.constraints import BaseConstraint
            from django.db.models.indexes import Index

            if isinstance(self.query.conflict_target, (BaseConstraint, Index)):
                return "ON CONFLICT ON CONSTRAINT %s" % self.qn(
                    self.query.conflict_target.name
                )

        conflict_target = self._build_conflict_target()
        return f"ON CONFLICT {conflict_target}"

    def _build_conflict_target(self):
        """Builds the `conflict_target` for the ON CONFLICT clause.

        Raises:
            SuspiciousOperation:
                When the conflict target is not iterable.
        """
        if not isinstance(self.query.conflict_target, Iterable):
            raise SuspiciousOperation(
                (
                    "%s is not a valid conflict target, specify "
                    "a list of column names, or tuples with column "
                    "names and hstore key."
                )
                % str(self.query.conflict_target)
            )

        # A matching index takes precedence over matching by field names.
        conflict_target = self._build_conflict_target_by_index()
        if conflict_target:
            return conflict_target

        return self._build_conflict_target_by_fields()

    def _build_conflict_target_by_fields(self):
        """Builds the `conflict_target` for the ON CONFLICT clauses by matching
        the fields specified in the specified conflict target against the
        model's fields.

        This requires some special handling because the fields names
        might not be same as the column names.

        Raises:
            SuspiciousOperation:
                When an entry does not match a field on the model.
        """
        conflict_target = []

        for field_name in self.query.conflict_target:
            self._assert_valid_field(field_name)

            # special handling for hstore keys
            if isinstance(field_name, tuple):
                # NOTE: the hstore key is interpolated into the SQL verbatim.
                conflict_target.append(
                    "(%s->'%s')"
                    % (self._format_field_name(field_name), field_name[1])
                )
            else:
                conflict_target.append(self._format_field_name(field_name))

        return "(%s)" % ",".join(conflict_target)

    def _build_conflict_target_by_index(self):
        """Builds the `conflict_target` for the ON CONFLICT clause by trying to
        find an index that matches the specified conflict target on the query.

        Conflict targets must match some unique constraint, usually this
        is a `UNIQUE INDEX`.

        Returns:
            The index' columns wrapped in parentheses, or None when no
            index on the model has exactly the fields of the conflict
            target.
        """
        matching_index = next(
            (
                index
                for index in self.query.model._meta.indexes
                if list(index.fields) == list(self.query.conflict_target)
            ),
            None,
        )

        if not matching_index:
            return None

        # The index' CREATE statement is built only to reuse its column part.
        with self.connection.schema_editor() as schema_editor:
            stmt = matching_index.create_sql(self.query.model, schema_editor)
            return "(%s)" % stmt.parts["columns"]

    def _get_model_field(self, name: str):
        """Gets the field on a model with the specified name.

        Arguments:
            name:
                The name of the field to look for.

                This can be both the actual field name, or
                the name of the column, both will work :)

        Returns:
            The field with the specified name or None if
            no such field exists.
        """
        field_name = self._normalize_field_name(name)

        if not self.query.model:
            return None

        # 'pk' has special meaning and always refers to the primary
        # key of a model, we have to respect this de-facto standard behaviour
        if field_name == "pk" and self.query.model._meta.pk:
            return self.query.model._meta.pk

        for field in self.query.model._meta.local_concrete_fields:  # type: ignore[attr-defined]
            if field.name == field_name or field.column == field_name:
                return field

        return None

    def _format_field_name(self, field_name) -> str:
        """Formats a field's name for usage in SQL.

        Arguments:
            field_name:
                The field name to format. The field is expected
                to exist on the model.

        Returns:
            The specified field name formatted for
            usage in SQL.
        """
        field = self._get_model_field(field_name)
        return self.qn(field.column)

    def _format_field_value(self, field_name) -> str:
        """Formats a field's value for usage in SQL.

        The value is read from the first object of the query.

        Arguments:
            field_name:
                The name of the field to format
                the value of.

        Returns:
            The field's value formatted for usage
            in SQL.
        """
        field_name = self._normalize_field_name(field_name)
        field = self._get_model_field(field_name)

        value = getattr(self.query.objs[0], field.attname)

        # A related field may hold a model instance; use its primary key.
        if isinstance(field, RelatedField) and isinstance(value, Model):
            value = value.pk

        return django_compiler.SQLInsertCompiler.prepare_value(  # type: ignore[attr-defined]
            self,
            field,
            # Note: this deliberately doesn't use `pre_save_val` as we don't
            # want things like auto_now on DateTimeField (etc.) to change the
            # value. We rely on pre_save having already been done by the
            # underlying compiler so that things like FileField have already had
            # the opportunity to save out their data.
            value,
        )

    def _compile_expression(
        self, expression: Union[Expression, Q, str]
    ) -> Tuple[str, Union[tuple, list]]:
        """Compiles an expression, Q object or raw SQL string into SQL and
        tuple of parameters.

        Raises:
            SuspiciousOperation:
                When a Q object is used with a Django
                version older than 3.1.
        """
        if isinstance(expression, Q):
            if django.VERSION < (3, 1):
                raise SuspiciousOperation(
                    "Q objects in psqlextra can only be used with Django 3.1 and newer"
                )

            return self.query.build_where(expression).as_sql(
                self, self.connection
            )

        elif isinstance(expression, Expression):
            return self.compile(expression)

        # Anything else is used as-is, without parameters.
        return expression, tuple()

    def _assert_valid_field(self, field_name: str):
        """Asserts that a field with the specified name exists on the model and
        raises :see:SuspiciousOperation if it does not."""
        field_name = self._normalize_field_name(field_name)
        if self._get_model_field(field_name):
            return

        raise SuspiciousOperation(
            (
                "%s is not a valid conflict target, specify "
                "a list of column names, or tuples with column "
                "names and hstore key."
            )
            % str(field_name)
        )

    @staticmethod
    def _normalize_field_name(field_name: str) -> str:
        """Normalizes a field name into a string by extracting the field name
        if it was specified as a reference to a HStore key (as a tuple).

        Arguments:
            field_name:
                The field name to normalize. A tuple must
                hold exactly two items, the field name and
                the HStore key.

        Returns:
            The normalized field name.
        """
        if isinstance(field_name, tuple):
            field_name, _ = field_name

        return field_name
|