1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
|
from django.db.models.lookups import (
Exact, GreaterThan, GreaterThanOrEqual, In, IsNull, LessThan,
LessThanOrEqual,
)
class MultiColSource:
    """Left-hand side placeholder for a relation spanning several columns.

    Bundles the table alias, the target and source field sequences, and the
    relation field itself so multi-column lookups can build one comparison
    per column pair.
    """

    contains_aggregate = False

    def __init__(self, alias, targets, sources, field):
        self.alias = alias
        self.targets = targets
        self.sources = sources
        self.field = field
        # The relation field doubles as the output field of this source.
        self.output_field = self.field

    def __repr__(self):
        return f"{self.__class__.__name__}({self.alias}, {self.field})"

    def relabeled_clone(self, relabels):
        """Return a copy whose alias is remapped through ``relabels``.

        An alias missing from ``relabels`` is kept unchanged.
        """
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.targets, self.sources, self.field)

    def get_lookup(self, lookup):
        """Resolve ``lookup`` by delegating to the output field."""
        return self.output_field.get_lookup(lookup)
def get_normalized_value(value, lhs):
    """Return ``value`` as a tuple suitable for column-wise comparison.

    * A model instance is converted to a tuple with one entry per target
      field of the last hop of ``lhs.output_field.get_path_info()``.
    * A tuple is returned unchanged.
    * Any other value is wrapped in a one-element tuple.
    """
    from django.db.models import Model

    if not isinstance(value, Model):
        return value if isinstance(value, tuple) else (value,)

    collected = []
    for field in lhs.output_field.get_path_info()[-1].target_fields:
        # Follow the chain of remote fields until reaching a field whose
        # model the instance actually belongs to (or the chain ends).
        while not isinstance(value, field.model) and field.remote_field:
            remote = field.remote_field
            field = remote.model._meta.get_field(remote.field_name)
        try:
            collected.append(getattr(value, field.attname))
        except AttributeError:
            # A case like Restaurant.objects.filter(place=restaurant_instance),
            # where place is a OneToOneField and the primary key of Restaurant.
            return (value.pk,)
    return tuple(collected)
class RelatedIn(In):
    """``in`` lookup for relation fields.

    When ``lhs`` is a ``MultiColSource`` the WHERE clause is assembled by
    hand in ``as_sql()``. Otherwise the right-hand side is normalized in
    ``get_prep_lookup()`` and the parent implementation does the rest.
    """

    def get_prep_lookup(self):
        """Normalize ``self.rhs`` in place, then defer to the parent class.

        Only single-column relations are touched here. Direct values are
        reduced to raw column values and passed through the target field's
        ``get_prep_value()``. A right-hand side that is not a direct value,
        reports no select fields, and whose target field is not a primary
        key gets its select clause replaced with the single target column.
        """
        if not isinstance(self.lhs, MultiColSource):
            if self.rhs_is_direct_value():
                # If we get here, we are dealing with single-column relations.
                self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
                # We need to run the related field's get_prep_value(). Consider
                # case ForeignKey to IntegerField given value 'abc'. The
                # ForeignKey itself doesn't have validation for non-integers,
                # so we must run validation using the target field.
                if hasattr(self.lhs.output_field, 'get_path_info'):
                    # Run the target field's get_prep_value. We can safely
                    # assume there is only one as we don't get to the direct
                    # value branch otherwise.
                    target_field = self.lhs.output_field.get_path_info()[-1].target_fields[-1]
                    self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
            elif (
                # Objects without a ``has_select_fields`` attribute default to
                # True here and are therefore left untouched.
                not getattr(self.rhs, 'has_select_fields', True) and
                not getattr(self.lhs.field.target_field, 'primary_key', False)
            ):
                self.rhs.clear_select_clause()
                if (
                    getattr(self.lhs.output_field, 'primary_key', False) and
                    self.lhs.output_field.model == self.rhs.model
                ):
                    # A case like
                    # Restaurant.objects.filter(place__in=restaurant_qs), where
                    # place is a OneToOneField and the primary key of
                    # Restaurant.
                    target_field = self.lhs.field.name
                else:
                    target_field = self.lhs.field.target_field.name
                self.rhs.add_fields([target_field], True)
        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        """Compile the lookup and return the result of the underlying ``as_sql()``.

        Multi-column sources are compiled through a hand-built ``WhereNode``;
        everything else is delegated to the parent class.
        """
        if isinstance(self.lhs, MultiColSource):
            # For multicolumn lookups we need to build a multicolumn where clause.
            # This clause is either a SubqueryConstraint (for values that need to be compiled to
            # SQL) or an OR-combined list of (col1 = val1 AND col2 = val2 AND ...) clauses.
            from django.db.models.sql.where import (
                AND, OR, SubqueryConstraint, WhereNode,
            )

            root_constraint = WhereNode(connector=OR)
            if self.rhs_is_direct_value():
                values = [get_normalized_value(value, self.lhs) for value in self.rhs]
                for value in values:
                    # One AND-group per candidate value, with one exact
                    # comparison per (source, target) column pair.
                    value_constraint = WhereNode()
                    for source, target, val in zip(self.lhs.sources, self.lhs.targets, value):
                        lookup_class = target.get_lookup('exact')
                        lookup = lookup_class(target.get_col(self.lhs.alias, source), val)
                        value_constraint.add(lookup, AND)
                    root_constraint.add(value_constraint, OR)
            else:
                root_constraint.add(
                    SubqueryConstraint(
                        self.lhs.alias, [target.column for target in self.lhs.targets],
                        [source.name for source in self.lhs.sources], self.rhs),
                    AND)
            return root_constraint.as_sql(compiler, connection)
        return super().as_sql(compiler, connection)
class RelatedLookupMixin:
    """Shared behaviour for single-value lookups against relation fields.

    Meant to be combined with a concrete lookup class, which supplies
    ``lookup_name``, ``prepare_rhs``, ``rhs_is_direct_value()`` and the
    parent ``get_prep_lookup()`` / ``as_sql()`` implementations.
    """

    def get_prep_lookup(self):
        """Normalize ``self.rhs`` for single-column relations, then defer upward.

        Multi-column sources and right-hand sides exposing
        ``resolve_expression`` are passed through untouched.
        """
        if isinstance(self.lhs, MultiColSource) or hasattr(self.rhs, 'resolve_expression'):
            return super().get_prep_lookup()

        # If we get here, we are dealing with single-column relations.
        self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
        # We need to run the related field's get_prep_value(). Consider case
        # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
        # doesn't have validation for non-integers, so we must run validation
        # using the target field.
        if self.prepare_rhs and hasattr(self.lhs.output_field, 'get_path_info'):
            # Get the target field. We can safely assume there is only one
            # as we don't get to the direct value branch otherwise.
            last_hop = self.lhs.output_field.get_path_info()[-1]
            self.rhs = last_hop.target_fields[-1].get_prep_value(self.rhs)
        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        """Compile the lookup and return the result of the underlying ``as_sql()``.

        For a multi-column source, one ``self.lookup_name`` comparison is
        built per (target, source) column pair and the comparisons are
        AND-ed together. Anything else is delegated to the parent class.
        """
        if not isinstance(self.lhs, MultiColSource):
            return super().as_sql(compiler, connection)

        assert self.rhs_is_direct_value()
        self.rhs = get_normalized_value(self.rhs, self.lhs)
        from django.db.models.sql.where import AND, WhereNode

        conjunction = WhereNode()
        column_triples = zip(self.lhs.targets, self.lhs.sources, self.rhs)
        for target_field, source_field, component in column_triples:
            component_lookup = target_field.get_lookup(self.lookup_name)
            column = target_field.get_col(self.lhs.alias, source_field)
            conjunction.add(component_lookup(column, component), AND)
        return conjunction.as_sql(compiler, connection)
class RelatedExact(RelatedLookupMixin, Exact):
    """``Exact`` lookup with ``RelatedLookupMixin`` behaviour layered on top."""
class RelatedLessThan(RelatedLookupMixin, LessThan):
    """``LessThan`` lookup with ``RelatedLookupMixin`` behaviour layered on top."""
class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
    """``GreaterThan`` lookup with ``RelatedLookupMixin`` behaviour layered on top."""
class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
    """``GreaterThanOrEqual`` lookup with ``RelatedLookupMixin`` behaviour layered on top."""
class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
    """``LessThanOrEqual`` lookup with ``RelatedLookupMixin`` behaviour layered on top."""
class RelatedIsNull(RelatedLookupMixin, IsNull):
    """``IsNull`` lookup with ``RelatedLookupMixin`` behaviour layered on top."""
|