1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210
|
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union
import django
from django.core.exceptions import SuspiciousOperation
from django.db import connections, models
from django.db.models import Expression, sql
from django.db.models.constants import LOOKUP_SEP
from .compiler import PostgresInsertOnConflictCompiler
from .compiler import SQLUpdateCompiler as PostgresUpdateCompiler
from .expressions import HStoreColumn
from .fields import HStoreField
from .types import ConflictAction
class PostgresQuery(sql.Query):
    """Query class that keeps the PostgreSQL-specific query subclasses in
    use when chaining, and adds annotation renaming plus hstore key
    selection for Django versions older than 2.1."""

    select: Tuple[Expression, ...]

    def chain(self, klass=None):
        """Chains this query to another.

        We override this so that we can make sure our subclassed query
        classes are used.

        Arguments:
            klass:
                The query class to chain to. `sql.UpdateQuery` and
                `sql.InsertQuery` are swapped for their PostgreSQL
                counterparts; anything else is passed through as-is.
        """

        if klass == sql.UpdateQuery:
            return super().chain(PostgresUpdateQuery)

        if klass == sql.InsertQuery:
            return super().chain(PostgresInsertQuery)

        return super().chain(klass)

    def rename_annotations(self, annotations) -> None:
        """Renames the aliases for the specified annotations:

            .annotate(myfield=F('somestuf__myfield'))
            .rename_annotations(myfield='field')

        Arguments:
            annotations:
                The annotations to rename. Mapping the
                old name to the new name.

        Raises:
            SuspiciousOperation:
                When one of the old names does not refer
                to an existing annotation.
        """

        # safety check only, make sure there are no renames
        # left that cannot be mapped back to the original name
        for old_name, new_name in annotations.items():
            # test for presence rather than truthiness, so that an
            # existing annotation object that happens to evaluate as
            # falsy is not reported as missing
            if old_name not in self.annotations:
                raise SuspiciousOperation(
                    (
                        'Cannot rename annotation "{old_name}" to "{new_name}", because there'
                        ' is no annotation named "{old_name}".'
                    ).format(old_name=old_name, new_name=new_name)
                )

        # rebuild the annotations according to the original order
        new_annotations = OrderedDict()
        for old_name, annotation in self.annotations.items():
            new_name = annotations.get(old_name)
            new_annotations[new_name or old_name] = annotation

            mask = self.annotation_select_mask
            if new_name and mask:
                # It's a set in all versions prior to Django 5.x
                # and a list in Django 5.x and newer.
                # https://github.com/django/django/commit/d6b6e5d0fd4e6b6d0183b4cf6e4bd4f9afc7bf67
                if isinstance(mask, set):
                    mask.discard(old_name)
                    mask.add(new_name)
                elif isinstance(mask, list):
                    # list.remove() raises ValueError for a missing item
                    # whereas set.discard() does not; guard the removal
                    # so both mask types behave the same when the renamed
                    # annotation is not part of the mask
                    if old_name in mask:
                        mask.remove(old_name)
                    mask.append(new_name)

        self.annotations.clear()
        self.annotations.update(new_annotations)

    def add_fields(self, field_names, *args, **kwargs) -> None:
        """Adds the given (model) fields to the select set.

        The field names are added in the order specified. This overrides
        the base class's add_fields method. This is called by the
        .values() or .values_list() method of the query set. It
        instructs the ORM to only select certain values. A lot of
        processing is necessary because it can be used to easily do
        joins. For example, `my_fk__name` pulls in the `name` field in
        foreign key `my_fk`. In our case, we want to be able to do
        `title__en`, where `title` is a HStoreField and `en` a key. This
        doesn't really involve a join. We iterate over the specified
        field names and filter out the ones that refer to HStoreField
        and compile it into an expression which is added to the list of
        to be selected fields using `self.set_select`.
        """

        # django knows how to do all of this natively from v2.1
        # see: https://github.com/django/django/commit/20bab2cf9d02a5c6477d8aac066a635986e0d3f3
        if django.VERSION >= (2, 1):
            return super().add_fields(field_names, *args, **kwargs)

        select = []
        field_names_without_hstore = []

        for name in field_names:
            parts = name.split(LOOKUP_SEP)

            # it cannot be a special hstore thing if there's no __ in it
            if len(parts) > 1:
                column_name, hstore_key = parts[:2]
                is_hstore, field = self._is_hstore_field(column_name)
                if self.model and is_hstore:
                    select.append(
                        HStoreColumn(
                            self.model._meta.db_table
                            or self.model.__class__.__name__,
                            field,
                            hstore_key,
                        )
                    )
                    continue

            field_names_without_hstore.append(name)

        # let the base class handle everything that is not a hstore key
        super().add_fields(field_names_without_hstore, *args, **kwargs)

        # hstore key expressions go after whatever the base class selected
        if select:
            self.set_select(list(self.select + tuple(select)))

    def _is_hstore_field(
        self, field_name: str
    ) -> Tuple[bool, Optional[models.Field]]:
        """Gets whether the field with the specified name is a HStoreField.

        The name is compared against both the name and the column of
        each local concrete field of the model; the first match wins.

        Returns
            A tuple of a boolean indicating whether the field
            with the specified name is a HStoreField, and the field
            instance. The instance is None when the query has no
            model or no field matched.
        """

        if not self.model:
            return (False, None)

        field_instance = None
        for field in self.model._meta.local_concrete_fields:  # type: ignore[attr-defined]
            if field.name == field_name or field.column == field_name:
                field_instance = field
                break

        return isinstance(field_instance, HStoreField), field_instance
class PostgresInsertQuery(sql.InsertQuery):
    """Insert query using PostgreSQL."""

    def __init__(self, *args, **kwargs):
        """Initializes a new instance :see:PostgresInsertQuery."""

        super().__init__(*args, **kwargs)

        self.conflict_target = []
        self.conflict_action = ConflictAction.UPDATE
        self.conflict_update_condition = None
        self.index_predicate = None

        self.update_values = {}

    def insert_on_conflict_values(
        self,
        objs: List,
        insert_fields: List,
        update_values: Optional[Dict[str, Union[Any, Expression]]] = None,
    ):
        """Sets the values to be used in this query.

        Insert fields are fields that are definitely
        going to be inserted, and if an existing row
        is found, are going to be overwritten with the
        specified value.

        Update fields are fields that should be overwritten
        in case an update takes place rather than an insert.
        If we're dealing with a INSERT, these will not be used.

        Arguments:
            objs:
                The objects to apply this query to.

            insert_fields:
                The fields to use in the INSERT statement

            update_values:
                Expressions/values to use when a conflict
                occurs and an UPDATE is performed. When
                omitted (or None), an empty dict is used.
        """

        self.insert_values(insert_fields, objs, raw=False)

        # a fresh dict per call; a `{}` default in the signature would be
        # a single object shared by every query that relies on the default
        self.update_values = {} if update_values is None else update_values

    def get_compiler(self, using=None, connection=None):
        """Gets the compiler to use for this query.

        Arguments:
            using:
                Alias of the database connection to use. When given,
                the connection is looked up by this alias and takes
                precedence over `connection`.

            connection:
                The connection to use when `using` is not given.

        Returns:
            A :see:PostgresInsertOnConflictCompiler for this query.
        """

        if using:
            connection = connections[using]

        return PostgresInsertOnConflictCompiler(self, connection, using)
class PostgresUpdateQuery(sql.UpdateQuery):
    """Update query using PostgreSQL."""

    def get_compiler(self, using=None, connection=None):
        """Builds the PostgreSQL update compiler for this query.

        A truthy `using` alias is resolved through `connections` and
        overrides whatever was passed as `connection`; otherwise the
        given `connection` is handed to the compiler unchanged.
        """

        resolved_connection = connections[using] if using else connection
        return PostgresUpdateCompiler(self, resolved_connection, using)
|