File: utils.py

package info (click to toggle)
python-django-modelcluster 6.4-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 508 kB
  • sloc: python: 5,026; sh: 6; makefile: 5
file content (216 lines) | stat: -rw-r--r-- 8,125 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import copy
import datetime
import random
from functools import lru_cache

from django.core.exceptions import FieldDoesNotExist
from django.db.models import (
    DateField,
    DateTimeField,
    ManyToManyField,
    ManyToManyRel,
    Model,
    TimeField,
)

from modelcluster import datetime_utils


# Separator between the segments of a lookup string (e.g. ``"author__name"``),
# used to split lookups into relationship / field / transform names.
# NOTE(review): "DELIMETER" is a misspelling of "DELIMITER"; the name is left
# unchanged because other modules may import it.
REL_DELIMETER = "__"


class ManyToManyTraversalError(ValueError):
    """
    Raised by ``get_model_field`` when a lookup would have to traverse a
    many-to-many relationship; only one-to-one and many-to-one traversals
    can be replicated by modelcluster.
    """


class NullRelationshipValueEncountered(Exception):
    """
    Raised by ``extract_field_value`` when a ``None`` value is reached before
    the last segment of a lookup, so the remaining segments cannot be followed.
    """


class TraversedRelationship:
    """
    One relationship hop followed while resolving a lookup: the model the hop
    starts from (``from_model``) and the relation field that was followed
    (``field``).
    """

    __slots__ = ['from_model', 'field']

    def __init__(self, from_model, field):
        self.from_model = from_model
        self.field = field

    def __repr__(self):
        return "{}(from_model={!r}, field={!r})".format(
            type(self).__name__, self.from_model, self.field
        )

    @property
    def field_name(self) -> str:
        """The name of the relation field that was followed."""
        return self.field.name

    @property
    def to_model(self):
        """
        The model reached by following ``field``.

        This reads ``field.related_model`` - the same attribute
        ``get_model_field`` uses to move on to the next model - rather than
        ``target_model``, which relation fields do not provide.
        """
        return self.field.related_model


def _get_id_target_field(model, field_name):
    """
    Resolve a ``'<relation>_id'`` name on ``model`` to the ``target_field`` of
    the ``<relation>`` field.

    Returns ``None`` when ``field_name`` does not end in ``'_id'``, when
    ``<relation>`` is not a field of ``model``, or when that field has no
    ``target_field`` attribute.
    """
    if not field_name.endswith("_id"):
        return None
    try:
        relation = model._meta.get_field(field_name[:-3])
    except FieldDoesNotExist:
        return None
    return getattr(relation, "target_field", None)


@lru_cache(maxsize=None)
def get_model_field(model, name):
    """
    Returns a model field matching the supplied ``name``, which can include
    double-underscores (`'__'`) to indicate relationship traversal - in which
    case, the model field will be looked up from the related model.

    Multiple traversals for the same field are supported, but at this
    moment in time, only traversal of many-to-one and one-to-one relationships
    is supported. A trailing date/datetime/time transform name (as listed in
    ``datetime_utils``) resolves to a fresh field of the transform's type.

    Details of any relationships traversed in order to reach the returned
    field are made available as `field.traversals`. The value is a tuple of
    ``TraversedRelationship`` instances.

    The returned object is a shallow copy of the model's field. Results are
    cached, and the same model field can be reached by different lookups
    (e.g. ``(Book, "author__name")`` and ``(Author, "name")``), so setting
    ``traversals`` on the shared field object would let one lookup overwrite
    another's cached result.

    Raises ``ManyToManyTraversalError`` if the lookup would traverse a
    many-to-many relationship, and ``FieldDoesNotExist`` if the name cannot
    be mapped to a model field.
    """
    subject_model = model
    traversals = []
    field = None
    for field_name in name.split(REL_DELIMETER):

        if field is not None:
            if isinstance(field, (ManyToManyField, ManyToManyRel)):
                raise ManyToManyTraversalError(
                    "The lookup '{name}' from {model} cannot be replicated "
                    "by modelcluster, because the '{field_name}' "
                    "relationship from {subject_model} is a many-to-many, "
                    "and traversal is only supported for one-to-one or "
                    "many-to-one relationships."
                    .format(
                        name=name,
                        model=model,
                        # Name the many-to-many relation itself, not the
                        # segment we were trying to reach through it.
                        field_name=field.name,
                        subject_model=subject_model,
                    )
                )
            elif getattr(field, "related_model", None):
                # A relation: record the hop and continue from the related model.
                traversals.append(TraversedRelationship(subject_model, field))
                subject_model = field.related_model
            elif (
                (
                    isinstance(field, DateTimeField)
                    and field_name in datetime_utils.DATETIMEFIELD_TRANSFORM_EXPRESSIONS
                ) or (
                    isinstance(field, DateField)
                    and field_name in datetime_utils.DATEFIELD_TRANSFORM_EXPRESSIONS
                ) or (
                    isinstance(field, TimeField)
                    and field_name in datetime_utils.TIMEFIELD_TRANSFORM_EXPRESSIONS
                )
            ):
                # A transform (e.g. a date part) ends the lookup; represent
                # its result with a new field of the transform's output type.
                transform_field_type = datetime_utils.TRANSFORM_FIELD_TYPES[field_name]
                field = transform_field_type()
                break
            else:
                raise FieldDoesNotExist(
                    "Failed attempting to traverse from {from_field} (a {from_field_type}) to '{to_field}'."
                    .format(
                        from_field=subject_model._meta.label + '.' + field.name,
                        from_field_type=type(field),
                        to_field=field_name,
                    )
                )
        try:
            field = subject_model._meta.get_field(field_name)
        except FieldDoesNotExist:
            # Allow '<relation>_id' to resolve to the field the relation
            # targets; if that is not possible, re-raise the original error.
            id_target_field = _get_id_target_field(subject_model, field_name)
            if id_target_field is None:
                raise
            field = id_target_field

    # Copy so that ``traversals`` belongs to this lookup only (see docstring).
    field = copy.copy(field)
    field.traversals = tuple(traversals)
    return field


def extract_field_value(obj, key, pk_only=False, suppress_fielddoesnotexist=False, suppress_nullrelationshipvalueencountered=False):
    """
    Return the value reached by following ``key`` from ``obj``.

    ``key`` may contain double-underscores (`'__'`) to traverse relationships
    to related objects. Each segment is looked up as an attribute of the
    current value. The exception is when the current value is a
    ``datetime``/``date``/``time`` and the segment is one of the matching
    transform names in ``datetime_utils``; the transform is then applied via
    ``datetime_utils.derive_from_value``.

    For keys that specify ``ForeignKey`` or ``OneToOneField`` field values,
    full related objects are returned by default. If only the primary key
    values are required (e.g. when ordering, or using ``values()`` or
    ``values_list()``), call the function with ``pk_only=True``.

    By default, ``FieldDoesNotExist`` is raised if a segment cannot be found
    on the current value. Call the function with
    ``suppress_fielddoesnotexist=True`` to receive ``None`` instead.

    By default, ``NullRelationshipValueEncountered`` is raised if a ``None``
    value is met before the last segment has been reached. Call the function
    with ``suppress_nullrelationshipvalueencountered=True`` to receive
    ``None`` instead.
    """
    # (value type, transform names applicable to values of that type)
    transform_rules = (
        (datetime.datetime, datetime_utils.DATETIMEFIELD_TRANSFORM_EXPRESSIONS),
        (datetime.date, datetime_utils.DATEFIELD_TRANSFORM_EXPRESSIONS),
        (datetime.time, datetime_utils.TIMEFIELD_TRANSFORM_EXPRESSIONS),
    )
    segments = key.split(REL_DELIMETER)
    final_position = len(segments) - 1
    current = obj
    # Most recent model instance seen; named in the null-relationship error.
    owner = obj

    for position, segment in enumerate(segments):
        applies_transform = any(
            isinstance(current, value_type) and segment in transform_names
            for value_type, transform_names in transform_rules
        )
        if applies_transform:
            current = datetime_utils.derive_from_value(current, segment)
            continue

        if not hasattr(current, segment):
            if suppress_fielddoesnotexist:
                return None
            raise FieldDoesNotExist(
                "'{name}' is not a valid field name for {model}".format(
                    name=segment, model=type(current)
                )
            )

        current = getattr(current, segment)
        if isinstance(current, Model):
            owner = current
        elif current is None and position != final_position:
            # A null relation part-way through: nothing further to follow.
            if suppress_nullrelationshipvalueencountered:
                return None
            raise NullRelationshipValueEncountered(
                "'{key}' cannot be reached for {obj} because {model_class}.{field_name} "
                "is null.".format(
                    key=key,
                    obj=repr(obj),
                    model_class=owner._meta.label,
                    field_name=segment,
                )
            )

    if pk_only and hasattr(current, 'pk'):
        return current.pk
    return current


def _sort_key_for(item, lookup):
    """
    Build the key used by ``sort_by_fields`` to order ``item`` on ``lookup``.

    The key is ``(value is not None, value)``. The flag puts ``None`` before
    every other value (after them when sorting in reverse). It also means
    ``None`` is never ordered against a non-``None`` value, which would raise
    ``TypeError`` on Python 3.
    """
    value = extract_field_value(
        item,
        lookup,
        pk_only=True,
        suppress_fielddoesnotexist=True,
        suppress_nullrelationshipvalueencountered=True,
    )
    return (value is not None, value)


def sort_by_fields(items, fields):
    """
    Sort the list ``items`` in place on ``fields``.

    ``fields`` works like the arguments to ``queryset.order_by(*fields)``.
    Each entry is an attribute name, which may use double-underscores to
    follow relations, optionally prefixed with ``'-'`` for descending order.
    The entry ``'?'`` shuffles the items randomly.
    """
    # list.sort is stable, so sorting once per field - from the last (least
    # significant) field to the first - produces the combined ordering.
    # See: https://docs.python.org/3/howto/sorting.html#sort-stability-and-complex-sorts
    for field in reversed(fields):
        if field == '?':
            random.shuffle(items)
            continue

        descending = field[0] == '-'
        lookup = field[1:] if descending else field
        items.sort(
            key=lambda item, lookup=lookup: _sort_key_for(item, lookup),
            reverse=descending,
        )