File: queryset.py

package info (click to toggle)
python-django-modelcluster 6.4-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 508 kB
  • sloc: python: 5,026; sh: 6; makefile: 5
file content (558 lines) | stat: -rw-r--r-- 17,531 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
from __future__ import unicode_literals

import re

from django.core.exceptions import FieldDoesNotExist
from django.db.models import Model, Q, prefetch_related_objects

from modelcluster.utils import NullRelationshipValueEncountered, extract_field_value, get_model_field, sort_by_fields


# Constructor for test functions that determine whether an object passes some boolean condition
def test_exact(model, attribute_name, value):
    """
    Build a test function for an ``exact`` match against ``value``.

    The returned function takes an object and reports whether the value found at
    ``attribute_name`` on that object matches ``value``. How the match is made
    depends on ``value``:

    * an unsaved model instance (``pk`` is None) is matched by identity;
    * a saved model instance is matched by primary key, provided one object's
      class is a subclass of the other's;
    * any other value is converted with the field's ``to_python()`` and compared
      using ``==``.

    In every case the test fails (returns False) when ``extract_field_value``
    raises ``NullRelationshipValueEncountered``.
    """
    if isinstance(value, Model):
        if value.pk is None:
            # comparing against an unsaved model, so objects need to match by reference
            def _test(obj):
                try:
                    other_value = extract_field_value(obj, attribute_name)
                except NullRelationshipValueEncountered:
                    return False
                return other_value is value

            return _test

        else:
            # comparing against a saved model; objects need to match by type and ID.
            # Additionally, where model inheritance is involved, we need to treat it as a
            # positive match if one is a subclass of the other
            def _test(obj):
                try:
                    other_value = extract_field_value(obj, attribute_name)
                except NullRelationshipValueEncountered:
                    return False
                if other_value is None:
                    # nothing is stored on this object, so it cannot match a saved
                    # instance; bail out here rather than failing on ``None.pk`` below
                    return False
                return value.pk == other_value.pk and (
                    isinstance(value, other_value.__class__)
                    or isinstance(other_value, value.__class__)
                )

            return _test
    else:
        field = get_model_field(model, attribute_name)
        # convert value to the correct python type for this field
        typed_value = field.to_python(value)

        # just a plain Python value = do a normal equality check
        def _test(obj):
            try:
                other_value = extract_field_value(obj, attribute_name)
            except NullRelationshipValueEncountered:
                return False
            return other_value == typed_value

        return _test


def test_iexact(model, attribute_name, match_value):
    """
    Build a test function for a case-insensitive exact match.

    ``match_value`` is first converted with the field's ``to_python()``. If the
    converted value is None, the returned test passes only for objects whose
    value is None; otherwise both sides are upper-cased before being compared.
    The test fails when ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``.
    """
    target = get_model_field(model, attribute_name).to_python(match_value)

    if target is not None:
        upper_target = target.upper()

        def _equals_ignoring_case(obj):
            try:
                candidate = extract_field_value(obj, attribute_name)
            except NullRelationshipValueEncountered:
                return False
            if candidate is None:
                return False
            return candidate.upper() == upper_target

        return _equals_ignoring_case

    def _is_none(obj):
        try:
            candidate = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        return candidate is None

    return _is_none


def test_contains(model, attribute_name, value):
    """
    Build a test function that passes when the object's value contains ``value``.

    ``value`` is converted with the field's ``to_python()`` and then looked for
    with the ``in`` operator. Objects whose value is None, or for which
    ``extract_field_value`` raises ``NullRelationshipValueEncountered``, fail.
    """
    needle = get_model_field(model, attribute_name).to_python(value)

    def _contains(obj):
        try:
            haystack = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if haystack is None:
            return False
        return needle in haystack

    return _contains


def test_icontains(model, attribute_name, value):
    """
    Build a test function for a case-insensitive containment check.

    ``value`` is converted with the field's ``to_python()`` and upper-cased once
    up front; each object's value is upper-cased before the ``in`` check.
    Objects whose value is None, or for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, fail.
    """
    upper_needle = get_model_field(model, attribute_name).to_python(value).upper()

    def _contains_ignoring_case(obj):
        try:
            haystack = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if haystack is None:
            return False
        return upper_needle in haystack.upper()

    return _contains_ignoring_case


def test_lt(model, attribute_name, value):
    """
    Build a test function that passes when the object's value is less than ``value``.

    ``value`` is converted with the field's ``to_python()`` first. Objects whose
    value is None, or for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, fail.
    """
    threshold = get_model_field(model, attribute_name).to_python(value)

    def _is_less(obj):
        try:
            candidate = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if candidate is None:
            return False
        return candidate < threshold

    return _is_less


def test_lte(model, attribute_name, value):
    """
    Build a test function that passes when the object's value is less than or
    equal to ``value``.

    ``value`` is converted with the field's ``to_python()`` first. Objects whose
    value is None, or for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, fail.
    """
    threshold = get_model_field(model, attribute_name).to_python(value)

    def _is_less_or_equal(obj):
        try:
            candidate = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if candidate is None:
            return False
        return candidate <= threshold

    return _is_less_or_equal


def test_gt(model, attribute_name, value):
    """
    Build a test function that passes when the object's value is greater than ``value``.

    ``value`` is converted with the field's ``to_python()`` first. Objects whose
    value is None, or for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, fail.
    """
    threshold = get_model_field(model, attribute_name).to_python(value)

    def _is_greater(obj):
        try:
            candidate = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if candidate is None:
            return False
        return candidate > threshold

    return _is_greater


def test_gte(model, attribute_name, value):
    """
    Build a test function that passes when the object's value is greater than or
    equal to ``value``.

    ``value`` is converted with the field's ``to_python()`` first. Objects whose
    value is None, or for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, fail.
    """
    threshold = get_model_field(model, attribute_name).to_python(value)

    def _is_greater_or_equal(obj):
        try:
            candidate = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if candidate is None:
            return False
        return candidate >= threshold

    return _is_greater_or_equal


def test_in(model, attribute_name, value_list):
    """
    Build a test function that passes when the object's value is one of ``value_list``.

    Every item of ``value_list`` is converted with the field's ``to_python()``
    and collected into a set, which each object's value is then checked against.
    The test fails when ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``.
    """
    field = get_model_field(model, attribute_name)
    allowed_values = {field.to_python(item) for item in value_list}

    def _is_member(obj):
        try:
            candidate = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        return candidate in allowed_values

    return _is_member


def test_startswith(model, attribute_name, value):
    """
    Build a test function that passes when the object's value starts with ``value``.

    ``value`` is converted with the field's ``to_python()`` first. Objects whose
    value is None, or for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, fail.
    """
    prefix = get_model_field(model, attribute_name).to_python(value)

    def _has_prefix(obj):
        try:
            candidate = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if candidate is None:
            return False
        return candidate.startswith(prefix)

    return _has_prefix


def test_istartswith(model, attribute_name, value):
    """
    Build a test function for a case-insensitive "starts with" check.

    ``value`` is converted with the field's ``to_python()`` and upper-cased once
    up front; each object's value is upper-cased before the prefix check.
    Objects whose value is None, or for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, fail.
    """
    upper_prefix = get_model_field(model, attribute_name).to_python(value).upper()

    def _has_prefix_ignoring_case(obj):
        try:
            candidate = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if candidate is None:
            return False
        return candidate.upper().startswith(upper_prefix)

    return _has_prefix_ignoring_case


def test_endswith(model, attribute_name, value):
    """
    Build a test function that passes when the object's value ends with ``value``.

    ``value`` is converted with the field's ``to_python()`` first. Objects whose
    value is None, or for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, fail.
    """
    suffix = get_model_field(model, attribute_name).to_python(value)

    def _has_suffix(obj):
        try:
            candidate = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if candidate is None:
            return False
        return candidate.endswith(suffix)

    return _has_suffix


def test_iendswith(model, attribute_name, value):
    """
    Build a test function for a case-insensitive "ends with" check.

    ``value`` is converted with the field's ``to_python()`` and upper-cased once
    up front; each object's value is upper-cased before the suffix check.
    Objects whose value is None, or for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, fail.
    """
    upper_suffix = get_model_field(model, attribute_name).to_python(value).upper()

    def _has_suffix_ignoring_case(obj):
        try:
            candidate = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if candidate is None:
            return False
        return candidate.upper().endswith(upper_suffix)

    return _has_suffix_ignoring_case


def test_range(model, attribute_name, range_val):
    """
    Build a test function that passes when the object's value lies within a range.

    The first two items of ``range_val`` are the lower and upper bounds; both
    are converted with the field's ``to_python()`` and both are inclusive.
    Objects whose value is None, or for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, fail.
    """
    field = get_model_field(model, attribute_name)
    lower_bound = field.to_python(range_val[0])
    upper_bound = field.to_python(range_val[1])

    def _is_within_bounds(obj):
        try:
            candidate = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if candidate is None:
            return False
        return candidate >= lower_bound and candidate <= upper_bound

    return _is_within_bounds


def test_isnull(model, attribute_name, sense):
    def _test(obj):
        try:
            val = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if sense:
            return val is None
        else:
            return val is not None

    return _test


def test_regex(model, attribute_name, regex_string):
    regex = re.compile(regex_string)

    def _test(obj):
        try:
            val = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        return val is not None and regex.search(val)

    return _test


def test_iregex(model, attribute_name, regex_string):
    regex = re.compile(regex_string, re.I)

    def _test(obj):
        try:
            val = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        return val is not None and regex.search(val)

    return _test


# Maps each supported lookup suffix (the final '__'-separated clause of a filter key,
# e.g. the 'icontains' in name__icontains) to the constructor that builds the matching
# test function. Every constructor is called as constructor(model, attribute_name, value)
# by _build_test_function_from_filter().
FILTER_EXPRESSION_TOKENS = {
    'exact': test_exact,
    'iexact': test_iexact,
    'contains': test_contains,
    'icontains': test_icontains,
    'lt': test_lt,
    'lte': test_lte,
    'gt': test_gt,
    'gte': test_gte,
    'in': test_in,
    'startswith': test_startswith,
    'istartswith': test_istartswith,
    'endswith': test_endswith,
    'iendswith': test_iendswith,
    'range': test_range,
    'isnull': test_isnull,
    'regex': test_regex,
    'iregex': test_iregex,
}


def _build_test_function_from_filter(model, key_clauses, val):
    """
    Translate a filter kwarg rule (e.g. foo__bar__exact=123) into a function which can
    take a model instance and return a boolean indicating whether it passes the rule.

    ``key_clauses`` is the filter key already split on '__' (e.g. ['foo', 'bar', 'exact'])
    and ``val`` is the value to test against. ``key_clauses`` is left unmodified.

    If the full key does not resolve to a model field and its last clause is a known
    lookup name (see FILTER_EXPRESSION_TOKENS), that lookup is applied to the field
    named by the preceding clauses. Otherwise the whole key is treated as a field
    name and tested for an exact match.
    """
    try:
        get_model_field(model, "__".join(key_clauses))
    except FieldDoesNotExist:
        # it is safe to assume the last clause indicates the type of test
        field_match_found = False
    else:
        field_match_found = True

    if not field_match_found and key_clauses[-1] in FILTER_EXPRESSION_TOKENS:
        constructor = FILTER_EXPRESSION_TOKENS[key_clauses[-1]]
        # work on a slice rather than popping the lookup name off, so that the
        # caller's list is not altered as a side effect
        field_clauses = key_clauses[:-1]
    else:
        constructor = test_exact
        field_clauses = key_clauses
    # recombine the remaining items to be interpreted
    # by get_model_field() and extract_field_value()
    attribute_name = "__".join(field_clauses)
    return constructor(model, attribute_name, val)


class FakeQuerySetIterable:
    """
    Base class for the iterables that turn a queryset's ``results`` into output items.

    Holds a reference to the queryset; subclasses supply ``__iter__``.
    """

    def __init__(self, queryset):
        # the queryset object whose ``results`` the subclass will iterate over
        self.queryset = queryset


class ModelIterable(FakeQuerySetIterable):
    """Yields each item of the queryset's ``results`` unchanged."""

    def __iter__(self):
        for obj in self.queryset.results:
            yield obj


class DictIterable(FakeQuerySetIterable):
    """
    Yields one dict per result, mapping field name to extracted value.

    The fields are the queryset's ``dict_fields`` or, when that is empty, the
    names of all fields in the model's ``_meta.fields``.
    """

    def __iter__(self):
        queryset = self.queryset
        field_names = queryset.dict_fields
        if not field_names:
            field_names = [field.name for field in queryset.model._meta.fields]

        for obj in queryset.results:
            row = {}
            for field_name in field_names:
                row[field_name] = extract_field_value(
                    obj,
                    field_name,
                    pk_only=True,
                    suppress_fielddoesnotexist=True,
                    suppress_nullrelationshipvalueencountered=True,
                )
            yield row


class ValuesListIterable(FakeQuerySetIterable):
    """
    Yields one tuple of extracted field values per result.

    The fields are the queryset's ``tuple_fields`` or, when that is empty, the
    names of all fields in the model's ``_meta.fields``.
    """

    def __iter__(self):
        queryset = self.queryset
        field_names = queryset.tuple_fields
        if not field_names:
            field_names = [field.name for field in queryset.model._meta.fields]

        for obj in queryset.results:
            row = []
            for field_name in field_names:
                row.append(extract_field_value(
                    obj,
                    field_name,
                    pk_only=True,
                    suppress_fielddoesnotexist=True,
                    suppress_nullrelationshipvalueencountered=True,
                ))
            yield tuple(row)


class FlatValuesListIterable(FakeQuerySetIterable):
    """
    Yields a single extracted field value (not wrapped in a tuple) per result.

    The field is the first entry of the queryset's ``tuple_fields``. When no
    fields were requested, the first field in the model's ``_meta.fields`` is
    used instead.
    """

    def __iter__(self):
        tuple_fields = self.queryset.tuple_fields
        if tuple_fields:
            field_name = tuple_fields[0]
        else:
            # No field was named: ValuesListIterable defaults to every model field
            # in this situation, so the flat form takes the first of those rather
            # than failing with an IndexError.
            field_name = self.queryset.model._meta.fields[0].name

        for obj in self.queryset.results:
            yield extract_field_value(
                obj,
                field_name,
                pk_only=True,
                suppress_fielddoesnotexist=True,
                suppress_nullrelationshipvalueencountered=True,
            )


class FakeQuerySet(object):
    """
    An in-memory stand-in for a QuerySet, backed by a sequence of objects
    (``results``) rather than a database query.

    Filtering, ordering and the other supported operations are evaluated in
    Python against ``results``; operations that produce a new set of results
    return a new FakeQuerySet and leave the original untouched.
    """

    def __init__(self, model, results):
        self.model = model
        self.results = results
        # field names selected via values() / values_list(); empty means "all model fields"
        self.dict_fields = []
        self.tuple_fields = []
        # class used by __iter__ to turn results into output items
        self.iterable_class = ModelIterable

    def all(self):
        """Return this queryset itself; there is nothing further to fetch."""
        return self

    def get_clone(self, results=None):
        """
        Return a new FakeQuerySet with the same model, selected fields and iterable
        class as this one, holding ``results`` (or this queryset's results if
        ``results`` is None).
        """
        new = FakeQuerySet(self.model, results if results is not None else self.results)
        new.dict_fields = self.dict_fields
        new.tuple_fields = self.tuple_fields
        new.iterable_class = self.iterable_class
        return new

    def resolve_q_object(self, q_object):
        """
        Translate a Q object (including any nested Q objects) into a single test
        function that takes an object and returns a bool.
        """
        connector = q_object.connector
        filters = []

        for child in q_object.children:
            if isinstance(child, Q):
                filters.append(self.resolve_q_object(child))
            else:
                key_clauses, val = child
                filters.append(_build_test_function_from_filter(self.model, key_clauses.split('__'), val))

        def test_inner(obj):
            # Coerce every outcome to a real bool: a test function may return any
            # truthy/falsy object, and the counting branch below would fail with a
            # TypeError if asked to add up anything that isn't a number.
            outcomes = [bool(test(obj)) for test in filters]
            if connector == Q.AND:
                result = all(outcomes)
            elif connector == Q.OR:
                result = any(outcomes)
            else:
                # any other connector (presumably XOR): exactly one child test must pass
                result = sum(outcomes) == 1
            if q_object.negated:
                return not result
            return result

        return test_inner

    def _get_filters(self, *args, **kwargs):
        """
        Build the list of test functions for a filter()/exclude() call: one per
        Q object in ``args`` and one per keyword rule in ``kwargs``.
        """
        # a list of test functions; objects must pass all tests to be included
        # in the filtered list
        filters = []

        for q_object in args:
            filters.append(self.resolve_q_object(q_object))

        for key, val in kwargs.items():
            filters.append(
                _build_test_function_from_filter(self.model, key.split('__'), val)
            )

        return filters

    def filter(self, *args, **kwargs):
        """Return a clone containing only the results that pass every given rule."""
        filters = self._get_filters(*args, **kwargs)

        clone = self.get_clone(results=[
            obj for obj in self.results
            if all([test(obj) for test in filters])
        ])
        return clone

    def exclude(self, *args, **kwargs):
        """
        Return a clone without the results that pass every given rule; results
        failing at least one rule are kept.
        """
        filters = self._get_filters(*args, **kwargs)

        clone = self.get_clone(results=[
            obj for obj in self.results
            if not all([test(obj) for test in filters])
        ])
        return clone

    def get(self, *args, **kwargs):
        """
        Return the single result matching the given rules.

        Raises the model's ``DoesNotExist`` if nothing matches and its
        ``MultipleObjectsReturned`` if more than one result matches.
        """
        clone = self.filter(*args, **kwargs)
        result_count = clone.count()

        if result_count == 0:
            raise self.model.DoesNotExist("%s matching query does not exist." % self.model._meta.object_name)
        elif result_count == 1:
            # iterate rather than index, so that the clone's iterable class is applied
            for result in clone:
                return result
        else:
            raise self.model.MultipleObjectsReturned(
                "get() returned more than one %s -- it returned %s!" % (self.model._meta.object_name, result_count)
            )

    def count(self):
        """Return the number of results."""
        return len(self.results)

    def exists(self):
        """Return True if there is at least one result."""
        return bool(self.results)

    def first(self):
        """Return the first item produced by iterating, or None if there are none."""
        for result in self:
            return result

    def last(self):
        """Return the last item, as produced by the iterable class, or None if there are none."""
        if self.results:
            # materialise the reversed results as a list so that the clone holds a
            # proper sequence rather than a one-shot iterator
            clone = self.get_clone(results=list(reversed(self.results)))
            for result in clone:
                return result

    def select_related(self, *args):
        # has no meaningful effect on non-db querysets
        return self

    def prefetch_related(self, *args):
        """Run prefetch_related_objects() over the current results and return self."""
        prefetch_related_objects(self.results, *args)
        return self

    def only(self, *args):
        # has no meaningful effect on non-db querysets
        return self

    def defer(self, *args):
        # has no meaningful effect on non-db querysets
        return self

    def values(self, *fields):
        """
        Return a clone that yields a dict per result, restricted to ``fields``
        (or covering all model fields if none are given).
        """
        clone = self.get_clone()
        clone.dict_fields = fields
        # Ensure all 'fields' are available model fields
        for f in fields:
            get_model_field(self.model, f)
        clone.iterable_class = DictIterable
        return clone

    def values_list(self, *fields, flat=None):
        """
        Return a clone that yields a tuple per result, restricted to ``fields``
        (or covering all model fields if none are given).

        With ``flat`` set, bare values are yielded instead of tuples; a
        TypeError is raised if that is combined with more than one field.
        """
        clone = self.get_clone()
        clone.tuple_fields = fields
        # Ensure all 'fields' are available model fields
        for f in fields:
            get_model_field(self.model, f)
        if flat:
            if len(fields) > 1:
                raise TypeError("'flat' is not valid when values_list is called with more than one field.")
            clone.iterable_class = FlatValuesListIterable
        else:
            clone.iterable_class = ValuesListIterable
        return clone

    def order_by(self, *fields):
        """Return a clone whose results are a copy of this queryset's, sorted by ``fields``."""
        # copy the results first, as the sort happens in place
        clone = self.get_clone(results=self.results[:])
        sort_by_fields(clone.results, fields)
        return clone

    def distinct(self, *fields):
        """
        Return a clone keeping only the first result for each distinct combination
        of values in ``fields`` (all non-primary-key model fields if none are given).
        """
        unique_results = []
        if not fields:
            fields = [field.name for field in self.model._meta.fields if not field.primary_key]
        seen_keys = set()
        for result in self.results:
            # values are compared by their string form
            key = tuple(str(extract_field_value(result, field)) for field in fields)
            if key not in seen_keys:
                seen_keys.add(key)
                unique_results.append(result)
        return self.get_clone(results=unique_results)

    # a standard QuerySet will store the results in _result_cache on running the query;
    # this is effectively the same as self.results on a FakeQuerySet, and so we'll make
    # _result_cache an alias of self.results for the benefit of Django internals that
    # exploit it
    def _get_result_cache(self):
        return self.results

    def _set_result_cache(self, val):
        self.results = list(val)

    _result_cache = property(_get_result_cache, _set_result_cache)

    def __getitem__(self, k):
        # indexes / slices the underlying results directly; the iterable class is not applied
        return self.results[k]

    def __iter__(self):
        iterator = self.iterable_class(self)
        yield from iterator

    def __bool__(self):
        return bool(self.results)

    # Python 2 spelling of __bool__, retained for any code that calls it by name
    __nonzero__ = __bool__

    def __repr__(self):
        return repr(list(self))

    def __len__(self):
        return len(self.results)

    ordered = True  # results are returned in a consistent order