1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246
|
from __future__ import unicode_literals
import django
from django.db import models
from django.db.models.fields.related import OneToOneField, OneToOneRel
from django.db.models.query import QuerySet
from django.core.exceptions import ObjectDoesNotExist
try:
from django.db.models.constants import LOOKUP_SEP
from django.utils.six import string_types
except ImportError: # Django < 1.5
from django.db.models.sql.constants import LOOKUP_SEP
string_types = (basestring,)
class InheritanceQuerySetMixin(object):
    """
    QuerySet mixin that can hand back each row as an instance of one of the
    model's multi-table-inheritance subclasses.

    ``select_subclasses()`` records which child relations should be joined
    (through ``select_related``); ``iterator()`` then replaces every parent
    instance it yields with the deepest matching child instance found on it.
    """

    def select_subclasses(self, *subclasses):
        """
        Return a new queryset whose results are downcast to subclasses.

        ``subclasses`` may hold relation-path strings (``'child'``,
        ``'child__grandchild'``) and/or model classes. With no arguments,
        every discovered subclass relation is selected. Passing the model
        the queryset is bound to is accepted and ignored.

        Raises ``ValueError`` when an entry is not among the discovered
        subclass relations (or, for a class, is not a subclass of the
        queryset's model).
        """
        levels = self._get_maximum_depth()
        calculated_subclasses = self._get_subclasses_recurse(
            self.model, levels=levels)

        # if none were passed in, we can just short circuit and select all
        if not subclasses:
            subclasses = calculated_subclasses
        else:
            verified_subclasses = []
            for subclass in subclasses:
                # special case for passing in the same model as the queryset
                # is bound against. Rather than raise an error later, we know
                # we can allow this through.
                if subclass is self.model:
                    continue

                # Model classes are translated to their relation path first.
                if not isinstance(subclass, string_types):
                    subclass = self._get_ancestors_path(
                        subclass, levels=levels)

                if subclass not in calculated_subclasses:
                    raise ValueError(
                        '%r is not in the discovered subclasses, tried: %s' % (
                            subclass, ', '.join(calculated_subclasses))
                    )
                verified_subclasses.append(subclass)
            subclasses = verified_subclasses

        # workaround https://code.djangoproject.com/ticket/16855
        previous_select_related = self.query.select_related
        if subclasses:
            new_qs = self.select_related(*subclasses)
        else:
            # Nothing to join. select_related() called with no field names
            # would instead follow every non-null foreign key, so just clone.
            new_qs = self._clone()
        previous_is_dict = isinstance(previous_select_related, dict)
        new_is_dict = isinstance(new_qs.query.select_related, dict)
        if previous_is_dict and new_is_dict:
            new_qs.query.select_related.update(previous_select_related)

        new_qs.subclasses = subclasses
        return new_qs

    def _clone(self, klass=None, setup=False, **kwargs):
        """
        Clone the queryset, carrying ``subclasses`` and ``_annotated`` over.

        ``klass`` and ``setup`` are only forwarded to the parent ``_clone``
        on Django < 1.9.
        """
        for name in ('subclasses', '_annotated'):
            if hasattr(self, name):
                kwargs[name] = getattr(self, name)
        if django.VERSION < (1, 9):
            kwargs['klass'] = klass
            kwargs['setup'] = setup
        return super(InheritanceQuerySetMixin, self)._clone(**kwargs)

    def annotate(self, *args, **kwargs):
        """
        Annotate as usual and remember the annotation names.

        The names are stored on ``_annotated`` so ``iterator()`` can copy the
        values onto subclass instances. Names recorded by earlier
        ``annotate()`` calls in the chain are kept, not overwritten.
        """
        qset = super(InheritanceQuerySetMixin, self).annotate(*args, **kwargs)
        # Build a fresh list: the previous one may be shared with other clones.
        names = list(getattr(self, '_annotated', None) or ())
        names.extend(a.default_alias for a in args)
        names.extend(kwargs)
        qset._annotated = names
        return qset

    def iterator(self):
        """
        Yield results, swapping in subclass instances when requested.

        Without a non-empty ``subclasses`` attribute this simply re-yields
        the parent iterator. Otherwise each object is replaced by the first
        truthy child found (deepest paths tried first), and annotation and
        ``extra()`` values are copied from the parent object onto it.
        """
        base_iter = super(InheritanceQuerySetMixin, self).iterator()

        if not getattr(self, 'subclasses', False):
            for obj in base_iter:
                yield obj
            return

        extras = tuple(self.query.extra.keys())
        annotated = getattr(self, '_annotated', None) or ()
        # sort the subclass names longest first,
        # so with 'a' and 'a__b' it goes as deep as possible
        subclasses = sorted(self.subclasses, key=len, reverse=True)

        for obj in base_iter:
            sub_obj = None
            for s in subclasses:
                sub_obj = self._get_sub_obj_recurse(obj, s)
                if sub_obj:
                    break
            # No child row matched: fall back to the parent instance.
            if not sub_obj:
                sub_obj = obj

            for k in annotated:
                setattr(sub_obj, k, getattr(obj, k))
            for k in extras:
                setattr(sub_obj, k, getattr(obj, k))

            yield sub_obj

    def _get_subclasses_recurse(self, model, levels=None):
        """
        Given a Model class, find all related objects, exploring children
        recursively, returning a `list` of strings representing the
        relations for select_related

        ``levels`` limits the recursion depth; ``None`` means unlimited.
        """
        if django.VERSION < (1, 8):
            related_objects = model._meta.get_all_related_objects()
        else:
            related_objects = [
                f for f in model._meta.get_fields()
                if isinstance(f, OneToOneRel)]

        # Keep only one-to-one relations declared on a strict subclass of
        # ``model``.
        rels = [
            rel for rel in related_objects
            if isinstance(rel.field, OneToOneField)
            and issubclass(rel.field.model, model)
            and model is not rel.field.model
        ]

        subclasses = []
        if levels:
            levels -= 1
        for rel in rels:
            accessor = rel.get_accessor_name()
            if levels or levels is None:
                for subclass in self._get_subclasses_recurse(
                        rel.field.model, levels=levels):
                    subclasses.append(accessor + LOOKUP_SEP + subclass)
            subclasses.append(accessor)
        return subclasses

    def _get_ancestors_path(self, model, levels=None):
        """
        Serves as an opposite to _get_subclasses_recurse, instead walking from
        the Model class up the Model's ancestry and constructing the desired
        select_related string backwards.

        Raises ``ValueError`` if ``model`` is not a subclass of the
        queryset's model.
        """
        if not issubclass(model, self.model):
            raise ValueError(
                "%r is not a subclass of %r" % (model, self.model))

        ancestry = []
        # should be a OneToOneField or None
        parent_link = model._meta.get_ancestor_link(self.model)
        if levels:
            levels -= 1
        while parent_link is not None:
            if django.VERSION < (1, 8):
                related = parent_link.related
            else:
                # Django 1.9 renamed ``Field.rel`` to ``remote_field`` and
                # deprecated the old name; prefer the new one when present.
                related = getattr(parent_link, 'remote_field', None)
                if related is None:
                    related = parent_link.rel
            ancestry.insert(0, related.get_accessor_name())

            if levels or levels is None:
                if django.VERSION < (1, 8):
                    parent_model = related.parent_model
                else:
                    parent_model = related.model
                parent_link = parent_model._meta.get_ancestor_link(
                    self.model)
            else:
                parent_link = None
        return LOOKUP_SEP.join(ancestry)

    def _get_sub_obj_recurse(self, obj, s):
        """
        Follow the relation path ``s`` starting at ``obj``.

        Returns the object at the end of the path, or ``None`` when ``obj``
        is not a model instance or a step raises ``ObjectDoesNotExist``.
        """
        rel, _, s = s.partition(LOOKUP_SEP)

        # Django 1.9: If a primitive type gets passed to this recursive
        # function, return None as non-models are not part of inheritance.
        if not isinstance(obj, models.Model):
            return None

        try:
            node = getattr(obj, rel)
        except ObjectDoesNotExist:
            return None

        if s:
            return self._get_sub_obj_recurse(node, s)
        return node

    def get_subclass(self, *args, **kwargs):
        """Shortcut for ``select_subclasses().get(*args, **kwargs)``."""
        return self.select_subclasses().get(*args, **kwargs)

    def _get_maximum_depth(self):
        """
        Under Django versions < 1.6, to avoid triggering
        https://code.djangoproject.com/ticket/16572 we can only look
        as far as children.

        Returns ``1`` on those versions and ``None`` (no limit) otherwise.
        """
        if django.VERSION < (1, 6, 0):
            return 1
        return None
class InheritanceManagerMixin(object):
    """
    Manager mixin whose querysets are ``InheritanceQuerySet`` instances.

    Exposes ``select_subclasses()`` and ``get_subclass()`` directly on the
    manager by delegating to the queryset.
    """
    use_for_related_fields = True

    def get_queryset(self):
        """Return an ``InheritanceQuerySet`` bound to this manager's model."""
        # Pass the manager's database alias along so that
        # ``db_manager('alias')`` is honoured instead of silently falling
        # back to the default routing.
        return InheritanceQuerySet(self.model, using=self._db)

    # Same method under the older ``get_query_set`` spelling.
    get_query_set = get_queryset

    def select_subclasses(self, *subclasses):
        """Delegate to ``get_queryset().select_subclasses(*subclasses)``."""
        return self.get_queryset().select_subclasses(*subclasses)

    def get_subclass(self, *args, **kwargs):
        """Delegate to ``get_queryset().get_subclass(*args, **kwargs)``."""
        return self.get_queryset().get_subclass(*args, **kwargs)
class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet):
    """Concrete ``QuerySet`` combined with ``InheritanceQuerySetMixin``."""
class InheritanceManager(InheritanceManagerMixin, models.Manager):
    """Concrete ``models.Manager`` combined with ``InheritanceManagerMixin``."""
class QueryManagerMixin(object):
    """
    Manager mixin that always applies a fixed filter, and optionally a fixed
    ordering, to the queryset produced by the next class in the MRO.
    """
    use_for_related_fields = True

    def __init__(self, *args, **kwargs):
        """
        Store the filter to apply.

        If positional arguments are given, the first one is used as the
        filter and keyword arguments are ignored; otherwise the keyword
        arguments are wrapped in a ``models.Q``.
        """
        if args:
            self._q = args[0]
        else:
            self._q = models.Q(**kwargs)
        self._order_by = None
        super(QueryManagerMixin, self).__init__()

    def order_by(self, *args):
        """Remember the ordering to apply and return this manager."""
        self._order_by = args
        return self

    def get_queryset(self):
        """
        Return the parent queryset filtered by the stored condition and, if
        ``order_by()`` was called, ordered accordingly.
        """
        parent = super(QueryManagerMixin, self)
        # Only the method *name* may differ between parents. Resolve it
        # without wrapping the call itself, so an AttributeError raised while
        # building or filtering the queryset is not mistaken for a missing
        # ``get_queryset`` and replaced by a misleading error.
        fetch = getattr(parent, 'get_queryset', None)
        if fetch is None:
            fetch = parent.get_query_set
        qs = fetch().filter(self._q)
        if self._order_by is not None:
            return qs.order_by(*self._order_by)
        return qs

    # Same method under the older ``get_query_set`` spelling.
    get_query_set = get_queryset
class QueryManager(QueryManagerMixin, models.Manager):
    """Concrete ``models.Manager`` combined with ``QueryManagerMixin``."""
|