1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
|
from django.db import connections, models
from django.db.models.sql.query import Query
from tree_queries.compiler import SEPARATOR, TreeQuery
def pk(of):
    """
    Normalize ``of`` to a primary key value.

    Model instances (anything exposing a ``pk`` attribute) are reduced to
    their ``pk``; any other value is assumed to already be a primary key and
    is handed back unchanged.
    """
    try:
        return of.pk
    except AttributeError:
        # Not instance-like: treat the argument itself as the key.
        return of
class TreeManager(models.Manager):
    """
    Manager that can optionally request tree fields on every queryset.

    ``_with_tree_fields`` is normally set on the manager subclass generated by
    ``TreeQuerySet.as_manager()``. It defaults to ``False`` here so that a
    manager created any other way (e.g. ``TreeManager()`` or
    ``TreeManager.from_queryset(...)()``) returns plain querysets instead of
    raising ``AttributeError``.
    """

    # Overridden on the generated subclass by TreeQuerySet.as_manager().
    _with_tree_fields = False

    def get_queryset(self):
        """
        Return the base queryset, with tree fields requested when the manager
        class was configured with ``_with_tree_fields`` set to a true value.
        """
        queryset = super().get_queryset()
        return queryset.with_tree_fields() if self._with_tree_fields else queryset
class TreeQuerySet(models.QuerySet):
    """
    QuerySet with tree-aware helpers backed by ``TreeQuery``.

    The tree-related methods switch the class of ``self.query`` in place and
    return ``self``; they do not clone the queryset first.
    """

    def _setup_tree_query(self):
        """
        Turn ``self.query`` into a TreeQuery in place and return it.

        Shared by every method that needs tree fields; ``_setup_query()`` is
        re-run on each call, exactly as each method previously did inline.
        """
        self.query.__class__ = TreeQuery
        self.query._setup_query()
        return self.query

    def with_tree_fields(self, tree_fields=True):  # noqa: FBT002
        """
        Requests tree fields on this queryset

        Pass ``False`` to revert to a queryset without tree fields. The
        queryset is modified in place and returned.
        """
        if tree_fields:
            self._setup_tree_query()
        else:
            self.query.__class__ = Query
        return self

    def without_tree_fields(self):
        """
        Requests no tree fields on this queryset

        Shortcut for ``with_tree_fields(tree_fields=False)``.
        """
        return self.with_tree_fields(tree_fields=False)

    def order_siblings_by(self, *order_by):
        """
        Sets the TreeQuery ``sibling_order`` attribute

        Pass the names of model fields as strings to order tree siblings by
        those model fields. Tree fields are enabled as a side effect.
        """
        query = self._setup_tree_query()
        query.sibling_order = order_by
        return self

    def tree_filter(self, *args, **kwargs):
        """
        Adds a filter to the TreeQuery ``rank_table_query``

        Takes the same arguments as a Django QuerySet ``.filter()``.
        """
        query = self._setup_tree_query()
        query.rank_table_query = query.rank_table_query.filter(*args, **kwargs)
        return self

    def tree_exclude(self, *args, **kwargs):
        """
        Adds an exclusion to the TreeQuery ``rank_table_query``

        Takes the same arguments as a Django QuerySet ``.exclude()``.
        """
        query = self._setup_tree_query()
        query.rank_table_query = query.rank_table_query.exclude(*args, **kwargs)
        return self

    def tree_fields(self, **tree_fields):
        """
        Sets the TreeQuery ``tree_fields`` attribute

        The keyword arguments replace any previously stored mapping; their
        interpretation is defined by the TreeQuery compiler.
        """
        query = self._setup_tree_query()
        query.tree_fields = tree_fields
        return self

    @classmethod
    def as_manager(cls, *, with_tree_fields=False):
        """
        Return a TreeManager instance whose querysets are built from ``cls``.

        Pass ``with_tree_fields=True`` to request tree fields on every
        queryset the manager returns.
        """
        manager_class = TreeManager.from_queryset(cls)
        # Only used in deconstruct:
        manager_class._built_with_as_manager = True
        # Set attribute on class, not on the instance so that the automatic
        # subclass generation used e.g. for relations also finds this
        # attribute.
        manager_class._with_tree_fields = with_tree_fields
        return manager_class()

    # Same marker Django sets on QuerySet.as_manager.
    as_manager.queryset_only = True

    def ancestors(self, of, *, include_self=False):
        """
        Returns ancestors of the given node ordered from the root of the tree
        towards deeper levels, optionally including the node itself

        ``of`` may be a node instance or a primary key; when it has no
        ``tree_path`` attribute it is re-fetched with tree fields.
        """
        if not hasattr(of, "tree_path"):
            of = self.with_tree_fields().get(pk=pk(of))
        # The last tree_path entry is the node itself; drop it unless wanted.
        ids = of.tree_path if include_self else of.tree_path[:-1]
        return (
            self.with_tree_fields()  # TODO tree fields not strictly required
            .filter(pk__in=ids)
            .extra(order_by=["__tree.tree_depth"])
        )

    def descendants(self, of, *, include_self=False):
        """
        Returns descendants of the given node in depth-first order, optionally
        including and starting with the node itself

        ``of`` may be a node instance or a primary key.
        """
        connection = connections[self.db]
        of_pk = pk(of)
        db_pk = self.model._meta.pk.get_db_prep_value(of_pk, connection)
        if connection.vendor == "postgresql":
            queryset = self.with_tree_fields().extra(
                where=["%s = ANY(__tree.tree_path)"],
                params=[db_pk],
            )
        else:
            # NOTE! The representation of tree_path is NOT part of the API.
            # Here tree_path is searched for the primary key delimited by
            # SEPARATOR on both sides. The needle is sent as a bound parameter
            # instead of being spliced into the SQL, so quotes or ``%`` in a
            # non-integer primary key can neither break nor inject into the
            # statement, and no double-quoted literal (an identifier in ANSI
            # SQL) is required.
            queryset = self.with_tree_fields().extra(
                where=["instr(__tree.tree_path, %s) <> 0"],
                params=[f"{SEPARATOR}{db_pk}{SEPARATOR}"],
            )
        if not include_self:
            return queryset.exclude(pk=of_pk)
        return queryset
|