1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
|
from functools import partial
from operator import methodcaller
from django.core.exceptions import ImproperlyConfigured
from django.db import models
from django.db.models.query import ModelIterable
# Lazily call ``type_cast()`` on every object of an iterable:
# ``type_cast_iterator(objs)`` behaves like ``map(lambda o: o.type_cast(), objs)``.
type_cast_iterator = partial(map, methodcaller("type_cast"))
# Same as ``type_cast_iterator`` but calls
# ``type_cast(with_prefetched_objects=True)`` on each object; used by
# ``PolymorphicQuerySet._fetch_all`` once prefetching has been performed.
type_cast_prefetch_iterator = partial(
    map, methodcaller("type_cast", with_prefetched_objects=True)
)
class PolymorphicModelIterable(ModelIterable):
    """Model iterable that can pass each yielded instance through ``type_cast()``.

    :param queryset: the queryset being iterated, forwarded to ``ModelIterable``.
    :param type_cast: when true (the default), each instance produced by
        ``ModelIterable`` is replaced by the result of its ``type_cast()``.
    :param kwargs: forwarded unchanged to ``ModelIterable``.
    """

    def __init__(self, queryset, type_cast=True, **kwargs):
        super().__init__(queryset, **kwargs)
        self.type_cast = type_cast

    def __iter__(self):
        instances = super().__iter__()
        if not self.type_cast:
            return instances
        # Lazy: each instance is cast only as it is consumed.
        return type_cast_iterator(instances)
class PolymorphicQuerySet(models.query.QuerySet):
    """QuerySet with helpers to select or exclude subclasses of ``self.model``.

    Querysets iterating through :class:`PolymorphicModelIterable` pass each
    fetched instance through its ``type_cast()`` method.
    """

    def select_subclasses(self, *models):
        """Return a copy of this queryset whose instances are type cast.

        :param models: optional subclasses of ``self.model``. When given, the
            queryset is restricted (via ``self.model.content_type_lookup``)
            to these models and their own subclasses; otherwise no content
            type filter is added.
        :returns: a new queryset; ``self`` is never modified. The
            ``select_related`` lookups of the selected subclasses are added
            so they are loaded in the same query.
        :raises TypeError: if one of ``models`` is not a subclass of
            ``self.model``.
        """
        # Work on a clone: mutating ``self._iterable_class`` in place would
        # silently turn on type casting for the caller's queryset as well, and
        # returning ``self`` could hand back an already-evaluated result cache
        # that was never type cast.
        queryset = self.all()
        # Only model-instance iteration is swapped; values()/values_list()
        # querysets use other iterable classes and are left untouched.
        if issubclass(queryset._iterable_class, ModelIterable):
            queryset._iterable_class = PolymorphicModelIterable
        accessors = self.model.subclass_accessors
        if models:
            subclasses = set()
            for model in models:
                if not issubclass(model, self.model):
                    raise TypeError("%r is not a subclass of %r" % (model, self.model))
                subclasses.update(model.subclass_accessors)
            selected = [accessors[subclass] for subclass in subclasses]
            queryset = queryset.filter(
                **self.model.content_type_lookup(*tuple(subclasses))
            )
        else:
            selected = accessors.values()
        # Collect the `select_related` lookups; ourself and proxy subclasses
        # have an empty related_lookup and are skipped.
        related_lookups = {
            accessor.related_lookup
            for accessor in selected
            if accessor.related_lookup
        }
        if related_lookups:
            queryset = queryset.select_related(*related_lookups)
        return queryset

    def exclude_subclasses(self):
        """Return a queryset filtered by ``self.model.content_type_lookup()``.

        The lookup is built with no subclass arguments; presumably this keeps
        only rows whose concrete type is ``self.model`` itself — confirm
        against ``content_type_lookup``.
        """
        return self.filter(**self.model.content_type_lookup())

    def _fetch_all(self):
        """Populate the result cache, deferring type casting past prefetching.

        Overrides ``QuerySet._fetch_all`` in order to disable
        :class:`PolymorphicModelIterable`'s per-row type casting when
        ``prefetch_related`` lookups are pending, because prefetching might
        crash or malfunction when dealing with a mixed set of objects. In that
        case the cached instances are type cast only after prefetching, with
        ``type_cast(with_prefetched_objects=True)``.
        """
        needs_prefetch = self._prefetch_related_lookups and not self._prefetch_done
        cast_after_prefetch = False
        if self._result_cache is None:
            iterable_class = self._iterable_class
            if issubclass(iterable_class, PolymorphicModelIterable):
                cast_after_prefetch = bool(needs_prefetch)
                # Cast while iterating only when nothing is going to be
                # prefetched; otherwise postpone it until after prefetching.
                iterable_class = partial(
                    iterable_class, type_cast=not cast_after_prefetch
                )
            self._result_cache = list(iterable_class(self))
        if needs_prefetch:
            self._prefetch_related_objects()
            if cast_after_prefetch:
                self._result_cache = list(
                    type_cast_prefetch_iterator(self._result_cache)
                )
class PolymorphicManager(models.Manager.from_queryset(PolymorphicQuerySet)):
    """Manager exposing the :class:`PolymorphicQuerySet` methods.

    It may only be attached to ``BasePolymorphicModel`` subclasses. On a
    proxy model its querysets are limited to that model and its subclasses.
    """

    def contribute_to_class(self, model, name):
        """Attach this manager to *model* under *name*.

        :raises ImproperlyConfigured: if *model* is not a
            ``BasePolymorphicModel`` subclass.
        """
        # Imported here rather than at module level to avoid a circular import.
        from .models import BasePolymorphicModel

        if issubclass(model, BasePolymorphicModel):
            return super().contribute_to_class(model, name)
        raise ImproperlyConfigured(
            "`%s` can only be used on "
            "`BasePolymorphicModel` subclasses." % self.__class__.__name__
        )

    def get_queryset(self):
        """Return the base queryset, narrowed when the model is a proxy.

        For proxy models the queryset is filtered with
        ``self.model.subclasses_lookup()`` so that only the associated model
        and its subclasses are selected.
        """
        base = super().get_queryset()
        if not self.model._meta.proxy:
            return base
        return base.filter(**self.model.subclasses_lookup())
|