File: query.py

package info (click to toggle)
python-django-tree-queries 0.20-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 388 kB
  • sloc: python: 2,269; makefile: 26; sh: 6
file content (137 lines) | stat: -rw-r--r-- 4,633 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from django.db import connections, models
from django.db.models.sql.query import Query

from tree_queries.compiler import SEPARATOR, TreeQuery


def pk(of):
    """
    Return ``of.pk`` when the argument has a ``pk`` attribute (for example a
    model instance); otherwise hand the argument back unchanged (for example a
    raw primary key value).
    """
    # getattr() with a default falls back to ``of`` only when looking up
    # ``pk`` raises AttributeError -- the same condition hasattr() tests.
    return getattr(of, "pk", of)


class TreeManager(models.Manager):
    """
    Manager that can enable tree fields on every queryset it returns.

    ``TreeQuerySet.as_manager(with_tree_fields=...)`` creates a subclass of
    this manager and sets ``_with_tree_fields`` on that subclass.
    """

    # Default for managers instantiated directly instead of through
    # TreeQuerySet.as_manager(). Without it, get_queryset() would raise
    # AttributeError. Subclasses built by as_manager() override this value.
    _with_tree_fields = False

    def get_queryset(self):
        """
        Return the base queryset, with tree fields requested when
        ``_with_tree_fields`` is true.
        """
        queryset = super().get_queryset()
        # with_tree_fields() is expected to come from TreeQuerySet (the
        # queryset class wired up by as_manager()).
        return queryset.with_tree_fields() if self._with_tree_fields else queryset


class TreeQuerySet(models.QuerySet):
    """
    QuerySet with helpers for querying tree-structured models.

    NOTE: the tree-related methods below change ``self.query`` in place and
    return ``self`` instead of a clone. The queryset they are called on is
    therefore modified as well.
    """

    def _setup_tree_query(self):
        # Switch this queryset's query to a TreeQuery (in place) and let it
        # initialize its tree-specific state.
        self.query.__class__ = TreeQuery
        self.query._setup_query()

    def with_tree_fields(self, tree_fields=True):  # noqa: FBT002
        """
        Requests tree fields on this queryset

        Pass ``False`` to revert to a queryset without tree fields.

        The query class is swapped in place and ``self`` is returned.
        """
        if tree_fields:
            self._setup_tree_query()
        else:
            self.query.__class__ = Query
        return self

    def without_tree_fields(self):
        """
        Requests no tree fields on this queryset
        """
        return self.with_tree_fields(tree_fields=False)

    def order_siblings_by(self, *order_by):
        """
        Sets TreeQuery sibling_order attribute

        Pass the names of model fields as separate string arguments, e.g.
        ``order_siblings_by("name", "pk")``, to order tree siblings by those
        model fields. They are stored as a tuple on the TreeQuery.
        """
        self._setup_tree_query()
        self.query.sibling_order = order_by
        return self

    def tree_filter(self, *args, **kwargs):
        """
        Adds a filter to the TreeQuery rank_table_query

        Takes the same arguments as a Django QuerySet ``.filter()``.
        """
        self._setup_tree_query()
        self.query.rank_table_query = self.query.rank_table_query.filter(
            *args, **kwargs
        )
        return self

    def tree_exclude(self, *args, **kwargs):
        """
        Adds an exclusion to the TreeQuery rank_table_query

        Takes the same arguments as a Django QuerySet ``.exclude()``.
        """
        self._setup_tree_query()
        self.query.rank_table_query = self.query.rank_table_query.exclude(
            *args, **kwargs
        )
        return self

    def tree_fields(self, **tree_fields):
        """
        Sets the TreeQuery tree_fields attribute to the given keyword arguments

        Any previously set tree_fields are replaced, not merged. (Presumably
        the keys name extra per-node tree columns and the values name model
        fields -- see tree_queries.compiler for the exact contract.)
        """
        self._setup_tree_query()
        self.query.tree_fields = tree_fields
        return self

    @classmethod
    def as_manager(cls, *, with_tree_fields=False):
        """
        Return an instance of a TreeManager subclass that uses this queryset

        If ``with_tree_fields`` is true, the manager's ``get_queryset()``
        requests tree fields on every queryset it returns.
        """
        manager_class = TreeManager.from_queryset(cls)
        # Only used in deconstruct:
        manager_class._built_with_as_manager = True
        # Set attribute on class, not on the instance so that the automatic
        # subclass generation used e.g. for relations also finds this
        # attribute.
        manager_class._with_tree_fields = with_tree_fields
        return manager_class()

    as_manager.queryset_only = True

    def ancestors(self, of, *, include_self=False):
        """
        Returns ancestors of the given node ordered from the root of the tree
        towards deeper levels, optionally including the node itself

        ``of`` may be a model instance or a primary key value. If it does not
        already carry a ``tree_path`` attribute, the node is fetched first.
        That fetch uses ``.get()``, which raises the model's ``DoesNotExist``
        if no such node exists.
        """
        if not hasattr(of, "tree_path"):
            of = self.with_tree_fields().get(pk=pk(of))

        # tree_path ends with the node's own key; drop it unless requested.
        ids = of.tree_path if include_self else of.tree_path[:-1]
        return (
            self.with_tree_fields()  # TODO tree fields not strictly required
            .filter(pk__in=ids)
            .extra(order_by=["__tree.tree_depth"])
        )

    def descendants(self, of, *, include_self=False):
        """
        Returns descendants of the given node in depth-first order, optionally
        including and starting with the node itself

        ``of`` may be a model instance or a primary key value.
        """
        connection = connections[self.db]
        of_pk = pk(of)
        # Convert the key into the value the database compares against.
        db_pk = self.model._meta.pk.get_db_prep_value(of_pk, connection)

        if connection.vendor == "postgresql":
            # ``= ANY(...)`` tests membership of the key in tree_path.
            queryset = self.with_tree_fields().extra(
                where=["%s = ANY(__tree.tree_path)"],
                params=[db_pk],
            )

        else:
            # NOTE! The representation of tree_path is NOT part of the API.
            # On these backends tree_path is searched as text for the key
            # delimited by SEPARATOR on both sides. The needle is passed as a
            # bound parameter rather than interpolated into the SQL. This way
            # no primary key value can break out of the statement, and no
            # double-quoted string literal (a non-standard legacy quirk in
            # SQLite) is required.
            needle = f"{SEPARATOR}{db_pk}{SEPARATOR}"
            queryset = self.with_tree_fields().extra(
                where=["instr(__tree.tree_path, %s) <> 0"],
                params=[needle],
            )

        if not include_self:
            return queryset.exclude(pk=of_pk)
        return queryset