File: querysets.py

package info (click to toggle)
python-django-stubs 5.2.9-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 4,832 kB
  • sloc: python: 5,185; makefile: 15; sh: 8
file content (825 lines) | stat: -rw-r--r-- 35,250 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
from collections.abc import Sequence
from typing import Literal

from django.core.exceptions import FieldDoesNotExist, FieldError
from django.db.models.base import Model
from django.db.models.fields.related import RelatedField
from django.db.models.fields.related_descriptors import (
    ForwardManyToOneDescriptor,
    ManyToManyDescriptor,
    ReverseManyToOneDescriptor,
    ReverseOneToOneDescriptor,
)
from django.db.models.fields.reverse_related import ForeignObjectRel
from mypy.checker import TypeChecker
from mypy.errorcodes import NO_REDEF
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, CallExpr, Expression, ListExpr, SetExpr, TupleExpr
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, Instance, LiteralType, ProperType, TupleType, TypedDictType, TypeOfAny, get_proper_type
from mypy.types import Type as MypyType

from mypy_django_plugin.django.context import DjangoContext, LookupsAreUnsupported
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.helpers import DjangoModel
from mypy_django_plugin.transformers.models import get_annotated_type


def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
    """
    Bind a generic manager type to the model class whose body is currently being checked.

    When the active class is a model with a known self type and the manager type is
    generic, the manager's type arguments are replaced by that self type. In every
    other case mypy's default return type is returned unchanged.
    """
    manager_type = get_proper_type(ctx.default_return_type)
    assert isinstance(manager_type, Instance)

    enclosing_cls = helpers.get_typechecker_api(ctx).scope.active_class()
    if (
        enclosing_cls is not None
        and helpers.is_model_type(enclosing_cls)
        and enclosing_cls.self_type is not None
        and manager_type.type.is_generic()
    ):
        return manager_type.copy_modified(args=[enclosing_cls.self_type])
    return manager_type


def get_field_type_from_lookup(
    ctx: MethodContext,
    django_context: DjangoContext,
    model_cls: type[Model],
    *,
    method: str,
    lookup: str,
    silent_on_error: bool = False,
) -> MypyType | None:
    """
    Resolve the mypy type produced by ``lookup`` (e.g. ``"author__name"``) on ``model_cls``.

    Returns:
        - ``None`` when Django rejects the lookup with ``FieldError``; the error is
          reported at ``ctx.context`` unless ``silent_on_error`` is set
        - ``Any`` when the lookup cannot be analysed statically or resolves to no field
        - otherwise the "get" type of the resolved field for ``method``
    """
    try:
        resolved_field, resolved_model_cls = django_context.resolve_lookup_into_field(model_cls, lookup)
    except FieldError as err:
        if not silent_on_error:
            ctx.api.fail(err.args[0], ctx.context)
        return None
    except LookupsAreUnsupported:
        return AnyType(TypeOfAny.explicit)

    if resolved_field is None:
        return AnyType(TypeOfAny.implementation_artifact)

    # A forward relation addressed by its column name (e.g. "author_id") and any reverse
    # relation are typed as the primary key of the related model.
    if (isinstance(resolved_field, RelatedField) and resolved_field.column == lookup) or isinstance(
        resolved_field, ForeignObjectRel
    ):
        resolved_model_cls = django_context.get_field_related_model_cls(resolved_field)
        resolved_field = django_context.get_primary_key_field(resolved_model_cls)

    api = helpers.get_typechecker_api(ctx)
    resolved_model_info = helpers.lookup_class_typeinfo(api, resolved_model_cls)
    return django_context.get_field_get_type(api, resolved_model_info, resolved_field, method=method)


def get_values_list_row_type(
    ctx: MethodContext,
    django_context: DjangoContext,
    model_cls: type[Model],
    *,
    is_annotated: bool,
    flat: bool,
    named: bool,
) -> MypyType:
    """
    Compute the row type of a ``QuerySet.values_list(*fields, flat=..., named=...)`` call.

    Args:
        - model_cls: The model class the QuerySet is over
        - is_annotated: Whether the QuerySet's model type carries annotations. When set,
          unresolvable lookups are typed as ``Any`` without reporting an error
          (presumably because they may name annotations)
        - flat / named: The ``flat`` and ``named`` arguments of the call

    Returns:
        - ``Any`` if a positional lookup is not a statically known string, or on error
        - with no positional lookups: the primary key type (``flat``), a ``Row`` NamedTuple of
          every field from ``get_model_fields`` (``named``), or a tuple of those field types
          (``tuple[Any, ...]`` when annotated)
        - otherwise: the single column type (``flat``), a ``Row`` NamedTuple (``named``), or a
          tuple of the column types in call order
    """
    field_lookups = resolve_field_lookups(ctx.args[0], django_context)
    if field_lookups is None:
        return AnyType(TypeOfAny.from_error)

    typechecker_api = helpers.get_typechecker_api(ctx)
    model_info = helpers.lookup_class_typeinfo(typechecker_api, model_cls)
    # No positional lookups: the row covers all fields of the model.
    if len(field_lookups) == 0:
        if flat:
            # `values_list(flat=True)` with no fields yields the primary key.
            primary_key_field = django_context.get_primary_key_field(model_cls)
            lookup_type = get_field_type_from_lookup(
                ctx, django_context, model_cls, lookup=primary_key_field.attname, method="values_list"
            )
            assert lookup_type is not None
            return lookup_type
        if named:
            column_types: dict[str, MypyType] = {}
            for field in django_context.get_model_fields(model_cls):
                column_type = django_context.get_field_get_type(
                    typechecker_api, model_info, field, method="values_list"
                )
                column_types[field.attname] = column_type
            if is_annotated:
                # Return a NamedTuple with a fallback so that it's possible to access any field
                return helpers.make_oneoff_named_tuple(
                    typechecker_api,
                    "Row",
                    column_types,
                    extra_bases=[typechecker_api.named_generic_type(fullnames.ANY_ATTR_ALLOWED_CLASS_FULLNAME, [])],
                )
            return helpers.make_oneoff_named_tuple(typechecker_api, "Row", column_types)
        # flat=False, named=False, all fields
        if is_annotated:
            return typechecker_api.named_generic_type("builtins.tuple", [AnyType(TypeOfAny.special_form)])
        # Fall through to the generic path below with every field's attname as a lookup.
        field_lookups = []
        for field in django_context.get_model_fields(model_cls):
            field_lookups.append(field.attname)

    if len(field_lookups) > 1 and flat:
        typechecker_api.fail("'flat' is not valid when 'values_list' is called with more than one field", ctx.context)
        return AnyType(TypeOfAny.from_error)

    # Resolve each lookup's column type; errors are silenced for annotated models and the
    # column degrades to Any instead of failing the whole row.
    column_types = {}
    for field_lookup in field_lookups:
        lookup_field_type = get_field_type_from_lookup(
            ctx, django_context, model_cls, lookup=field_lookup, method="values_list", silent_on_error=is_annotated
        )
        if lookup_field_type is None:
            if is_annotated:
                lookup_field_type = AnyType(TypeOfAny.from_omitted_generics)
            else:
                return AnyType(TypeOfAny.from_error)
        column_types[field_lookup] = lookup_field_type

    if flat:
        # Exactly one lookup is guaranteed here by the checks above.
        assert len(column_types) == 1
        row_type = next(iter(column_types.values()))
    elif named:
        row_type = helpers.make_oneoff_named_tuple(typechecker_api, "Row", column_types)
    else:
        # Since there may have been repeated field lookups, we cannot just use column_types.values here.
        # This is not the case in named above, because Django will error if duplicate fields are requested.
        resolved_column_types = [column_types[field_lookup] for field_lookup in field_lookups]
        row_type = helpers.make_tuple(typechecker_api, resolved_column_types)

    return row_type


def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
    """
    Method hook for ``QuerySet.values_list(*fields, flat=..., named=...)``.

    Substitutes the computed row type into the returned QuerySet. For non-named calls
    whose field names are statically known, those names are also recorded in the
    result's ``extra_attrs.immutable``, because a plain tuple row cannot carry them.
    """
    model = helpers.get_model_info_from_qs_ctx(ctx, django_context)
    if model is None:
        return ctx.default_return_type

    queryset_type = get_proper_type(ctx.default_return_type)
    if not isinstance(queryset_type, Instance):
        return ctx.default_return_type

    is_flat = helpers.get_bool_call_argument_by_name(ctx, "flat", default=False)
    is_named = helpers.get_bool_call_argument_by_name(ctx, "named", default=False)
    if is_flat and is_named:
        ctx.api.fail("'flat' and 'named' can't be used together", ctx.context)
        return queryset_type.copy_modified(args=[model.typ, AnyType(TypeOfAny.from_error)])

    computed_row = get_values_list_row_type(
        ctx, django_context, model.cls, is_annotated=model.is_annotated, flat=is_flat, named=is_named
    )
    result = queryset_type.copy_modified(args=[model.typ, computed_row])
    if is_named:
        # NamedTuple rows already encode their column names.
        return result

    # Remember the selected names on the QuerySet instance itself so that a later
    # annotate() can detect name conflicts with the selected columns.
    selected_names = resolve_field_lookups(ctx.args[0], django_context)
    if selected_names:
        result.extra_attrs = helpers.merge_extra_attrs(result.extra_attrs, new_immutable=set(selected_names))
    return result


def gather_kwargs(ctx: MethodContext) -> dict[str, MypyType]:
    """
    Collect the keyword arguments of a method call as a ``{name: type}`` mapping.

    A formal contributes only if every actual argument mapped to it was passed by name
    (``ARG_NAMED`` / ``ARG_NAMED_OPT``). Any other kind, e.g. a ``**mapping`` unpacking
    whose keys are not statically known, causes that formal to be skipped.

    Always returns a dict (empty when there are no named arguments), never ``None``.
    """
    named_kinds = (ARG_NAMED, ARG_NAMED_OPT)
    kwargs: dict[str, MypyType] = {}
    # arg_kinds / arg_names / arg_types are parallel per-formal lists.
    for kinds, names, types in zip(ctx.arg_kinds, ctx.arg_names, ctx.arg_types):
        if not kinds or any(kind not in named_kinds for kind in kinds):
            # Only named arguments supported
            continue
        for name, typ in zip(names, types):
            assert name is not None
            kwargs[name] = typ
    return kwargs


def gather_expression_types(ctx: MethodContext) -> dict[str, MypyType]:
    """
    Map each keyword argument collected by ``gather_kwargs`` (e.g. ``annotate(foo=F("id"))``) to ``Any``.

    Returns an empty dict when the call has no such keyword arguments.
    """
    kwargs = gather_kwargs(ctx)
    if not kwargs:
        return {}

    # For now, we don't try to resolve the output_field of the field would be, but use Any.
    # NOTE: It's important that we use 'special_form' for 'Any' as otherwise we can
    # get stuck with mypy interpreting an overload ambiguity towards the
    # overloaded 'Field.__get__' method when its 'model' argument gets matched. This
    # is because the model argument gets matched with a model subclass that is
    # parametrized with a type that contains the 'Any' below and then mypy runs in
    # to a (false?) ambiguity, due to 'Any', and can't decide what overload/return
    # type to select
    #
    # Example:
    #   class MyModel(models.Model):
    #       field = models.CharField()
    #
    #   # Our plugin auto generates the following subclass
    #   class MyModel_WithAnnotations(MyModel, django_stubs_ext.Annotations[_Annotations]):
    #       ...
    #   # Assume
    #   x = MyModel.objects.annotate(foo=F("id")).get()
    #   reveal_type(x)  # MyModel_WithAnnotations[TypedDict({"foo": Any})]
    #   # Then, on an attribute access of 'field' like
    #   reveal_type(x.field)
    #
    # Now 'CharField.__get__', which is overloaded, is passed 'x' as the 'model'
    # argument and mypy consider it ambiguous to decide which overload method to
    # select due to the 'Any' in 'TypedDict({"foo": Any})'. But if we specify the
    # 'Any' as 'TypeOfAny.special_form' mypy doesn't consider the model instance to
    # contain 'Any' and the ambiguity goes away.
    return {name: AnyType(TypeOfAny.special_form) for name in kwargs}


def extract_proper_type_queryset_annotate(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
    """
    Method hook for ``QuerySet.annotate(**expressions)``.

    Each keyword expression that passes ``check_valid_attr_value`` (which reports
    conflicting names) is added to the model type as an ``Any``-typed annotation via
    ``get_annotated_type``. The row type parameter is then adjusted:
        - TypedDict rows (from ``values()``) become ``dict[str, Any]``
        - NamedTuple rows (``values_list(named=True)``) become ``Any``
        - other tuple rows become ``tuple[Any, ...]``
        - model rows become the annotated model type
        - any other row type is kept unchanged

    Returns ``Any`` when the QuerySet's model cannot be determined.
    """
    django_model = helpers.get_model_info_from_qs_ctx(ctx, django_context)
    if django_model is None:
        return AnyType(TypeOfAny.from_omitted_generics)

    default_return_type = get_proper_type(ctx.default_return_type)
    if not isinstance(default_return_type, Instance):
        return ctx.default_return_type

    api = helpers.get_typechecker_api(ctx)

    # Drop (and report) expressions whose names clash with existing attributes.
    expression_types = {
        attr_name: typ
        for attr_name, typ in gather_expression_types(ctx).items()
        if check_valid_attr_value(ctx, django_context, django_model, attr_name)
    }

    annotated_type: ProperType = django_model.typ
    if expression_types:
        fields_dict = helpers.make_typeddict(
            api,
            fields=expression_types,
            required_keys=set(expression_types.keys()),
            readonly_keys=set(),
        )
        annotated_type = get_annotated_type(api, django_model.typ, fields_dict=fields_dict)

    row_type: MypyType
    if len(default_return_type.args) > 1:
        # The QuerySet has an explicit row type parameter; rewrite it according to its shape.
        original_row_type = get_proper_type(default_return_type.args[1])
        row_type = original_row_type
        if isinstance(original_row_type, TypedDictType):
            row_type = api.named_generic_type(
                "builtins.dict", [api.named_generic_type("builtins.str", []), AnyType(TypeOfAny.from_omitted_generics)]
            )
        elif isinstance(original_row_type, TupleType):
            if original_row_type.partial_fallback.type.has_base("typing.NamedTuple"):
                # TODO: Use a NamedTuple which contains the known fields, but also
                #  falls back to allowing any attribute access.
                row_type = AnyType(TypeOfAny.implementation_artifact)
            else:
                row_type = api.named_generic_type("builtins.tuple", [AnyType(TypeOfAny.from_omitted_generics)])
        elif isinstance(original_row_type, Instance) and helpers.is_model_type(original_row_type.type):
            row_type = annotated_type
    else:
        # No row type parameter on the receiver type (presumably a Manager); rows are model instances.
        row_type = annotated_type
    return default_return_type.copy_modified(args=[annotated_type, row_type])


def resolve_field_lookups(lookup_exprs: Sequence[Expression], django_context: DjangoContext) -> list[str] | None:
    """
    Statically resolve each lookup expression to its string value.

    Resolution stops at the first expression that cannot be resolved, and ``None`` is
    returned. Otherwise the resolved strings are returned in call order; the list is
    empty when there are no expressions.
    """
    resolved: list[str] = []
    for expr in lookup_exprs:
        value = helpers.resolve_string_attribute_value(expr, django_context)
        if value is None:
            return None
        resolved.append(value)
    return resolved


def extract_proper_type_queryset_values(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
    """
    Extract proper return type for QuerySet.values(*fields, **expressions) method calls.

    The row type parameter becomes a TypedDict whose keys are the requested field lookups
    plus the expression names. A bare ``.values()`` selects every field returned by
    ``get_model_fields``. Expression values are typed as ``Any``.

    mypy's default return type is kept unchanged when:
        - the model is unknown or annotated, or
        - the positional lookups are not statically known strings (e.g. ``.values(*names)``).

    See https://docs.djangoproject.com/en/5.2/ref/models/querysets/#values
    """
    django_model = helpers.get_model_info_from_qs_ctx(ctx, django_context)
    if django_model is None or django_model.is_annotated:
        return ctx.default_return_type

    default_return_type = get_proper_type(ctx.default_return_type)
    if not isinstance(default_return_type, Instance):
        return ctx.default_return_type

    field_lookups = resolve_field_lookups(ctx.args[0], django_context)
    if field_lookups is None:
        # The row keys are unknown. Keep the QuerySet type from the method signature
        # instead of collapsing the whole call to `Any`, which would silently untype
        # every chained call (`values_list` keeps the QuerySet in the same situation).
        return ctx.default_return_type

    # Bare `.values()` case; `.values(**expressions)` alone selects only the expressions.
    if not field_lookups and not ctx.args[1]:
        field_lookups = [field.attname for field in django_context.get_model_fields(django_model.cls)]

    column_types: dict[str, MypyType] = {}

    # Collect `*fields` types -- `.values("id", "name")`
    for field_lookup in field_lookups:
        field_lookup_type = get_field_type_from_lookup(
            ctx, django_context, django_model.cls, lookup=field_lookup, method="values"
        )
        if field_lookup_type is None:
            # The error was already reported by get_field_type_from_lookup.
            return default_return_type.copy_modified(args=[django_model.typ, AnyType(TypeOfAny.from_error)])

        column_types[field_lookup] = field_lookup_type

    # Collect `**expressions` types -- `.values(lower_name=Lower("name"), foo=F("name"))`
    column_types.update(gather_expression_types(ctx))
    row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys()), set())
    return default_return_type.copy_modified(args=[django_model.typ, row_type])


def _infer_prefetch_queryset_type(queryset_expr: Expression, api: TypeChecker) -> Instance | None:
    """
    Infer the type of the ``queryset=`` expression of ``Prefetch(...)``.

    The expression is type-checked and its proper type is returned when it is an
    ``Instance``, otherwise ``None``. Any exception raised while checking is swallowed
    deliberately, because this is best-effort inference and a failure only means the
    type stays unknown.
    """
    try:
        inferred = get_proper_type(api.expr_checker.accept(queryset_expr))
    except Exception:
        return None
    return inferred if isinstance(inferred, Instance) else None


def _resolve_prefetch_string_argument(
    type_arg: MypyType,
    expr: Expression | None,
    django_context: DjangoContext,
    arg_name: str,
) -> str | None:
    """
    Recover a string argument (such as ``lookup`` or ``to_attr``) of a ``Prefetch`` object.

    A literal string specialized into the ``Prefetch`` type argument takes precedence.
    Otherwise, when ``expr`` is an inline call, the ``arg_name`` argument is read from
    it. Returns ``None`` when neither source yields a string.
    """
    literal_value = helpers.get_literal_str_type(type_arg)
    if literal_value is not None:
        return literal_value

    if not isinstance(expr, CallExpr):
        return None
    argument_expr = helpers.get_class_init_argument_by_name(expr, arg_name)
    if not argument_expr:
        return None
    return helpers.resolve_string_attribute_value(argument_expr, django_context)


def _resolve_prefetch_queryset_argument(
    type_arg: MypyType,
    expr: Expression | None,
    api: TypeChecker,
) -> Instance | None:
    """
    Recover the queryset type of a ``Prefetch`` object.

    The specialized type argument is used when its element model is a concrete model,
    i.e. not the base ``Model`` class. Otherwise, when ``expr`` is an inline call, its
    ``queryset`` argument is type-checked. Returns ``None`` when neither source applies.
    """
    candidate = get_proper_type(type_arg)
    if isinstance(candidate, Instance):
        element_model = helpers.extract_model_type_from_queryset(candidate, api)
        # A bare `Model` element type carries no useful information, so fall through.
        if element_model is not None and element_model.type.fullname != fullnames.MODEL_CLASS_FULLNAME:
            return candidate

    if not isinstance(expr, CallExpr):
        return None
    queryset_expr = helpers.get_class_init_argument_by_name(expr, "queryset")
    if queryset_expr is None:
        return None
    return _infer_prefetch_queryset_type(queryset_expr, api)


def _specialize_string_arg_to_literal(
    ctx: FunctionContext, django_context: DjangoContext, arg_name: str, fallback_type: MypyType
) -> MypyType:
    """
    Narrow the call argument ``arg_name`` to a ``Literal[str]`` type when its value is known.

    This lets later analysis phases read the concrete string back from the type.
    ``fallback_type`` is returned when the argument is absent or does not resolve to a
    non-empty string.
    """
    arg_expr = helpers.get_call_argument_by_name(ctx, arg_name)
    if not arg_expr:
        return fallback_type

    arg_value = helpers.resolve_string_attribute_value(arg_expr, django_context)
    if not arg_value:
        return fallback_type

    str_type = helpers.get_typechecker_api(ctx).named_generic_type("builtins.str", [])
    return LiteralType(value=arg_value, fallback=str_type)


def specialize_prefetch_type(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
    """
    Function hook for ``Prefetch(...)`` that narrows its ``lookup`` and ``to_attr`` type arguments.

    The type arguments are (lookup, queryset, to_attr). The first and last become string
    literals when the corresponding call argument is statically known. The queryset
    argument is kept unchanged.
    """
    prefetch_type = get_proper_type(ctx.default_return_type)
    if not isinstance(prefetch_type, Instance):
        return ctx.default_return_type

    type_args = prefetch_type.args
    narrowed_lookup = _specialize_string_arg_to_literal(ctx, django_context, "lookup", type_args[0])
    narrowed_to_attr = _specialize_string_arg_to_literal(ctx, django_context, "to_attr", type_args[2])
    new_args = [narrowed_lookup, type_args[1], narrowed_to_attr]
    return prefetch_type.copy_modified(args=new_args)


def gather_flat_args(ctx: MethodContext) -> list[tuple[Expression | None, ProperType]]:
    """
    Flatten all arguments into a uniform list of (expr, typ) pairs.

    Arguments passed to the first formal (``*lookups``) are handled as follows:
        - Non-starred arguments are kept with their expression.
        - A starred argument whose type is a ``TupleType`` with statically known items
          is expanded into one ``(None, item_type)`` pair per item.
        - Other starred arguments (e.g. a ``list``) cannot be expanded statically and
          are dropped.

    mypy maps a starred tuple actual onto a ``*args`` formal once per tuple item, so
    the same starred expression appears repeatedly in ``ctx.args[0]``. The tuple
    position is tracked per starred expression. A single counter shared by all starred
    arguments would mis-index, or raise ``IndexError``, as soon as a call has more than
    one starred argument, e.g. ``prefetch_related(*("a", "b"), *("c",))``.
    """
    lookups: list[tuple[Expression | None, ProperType]] = []
    # Next tuple position to read for each starred actual, keyed by expression identity.
    star_positions: dict[int, int] = {}
    for expr, typ, kind in zip(ctx.args[0], ctx.arg_types[0], ctx.arg_kinds[0], strict=False):
        ptyp = get_proper_type(typ)
        if kind != ARG_STAR:
            lookups.append((expr, ptyp))
            continue
        position = star_positions.get(id(expr), 0)
        star_positions[id(expr)] = position + 1
        # Expand starred tuple items if statically known; bounds-checked in case the
        # tuple contains variadic items mypy mapped differently.
        if isinstance(ptyp, TupleType) and position < len(ptyp.items):
            lookups.append((None, get_proper_type(ptyp.items[position])))
    return lookups


def _get_selected_fields_from_queryset_type(qs_type: Instance) -> set[str] | None:
    """
    Derive selected field names from a QuerySet type.

    Sources:
      - values(): encoded in the row TypedDict keys
      - values_list(named=True): row is a NamedTuple; extract field names from fallback TypeInfo
      - values_list(named=False): the row (a plain tuple, or the single column type when
        ``flat=True``) carries no names, so they are read from ``qs_type.extra_attrs.immutable``

    Returns:
        ``None`` when no column restriction is known: the row is still a model instance,
        or the type has no row parameter and nothing was recorded. Otherwise returns the
        set of selected names, which is empty when the row shape is restricted but its
        names are unknown.
    """
    # Names explicitly recorded on the QuerySet Instance (see values_list handling).
    recorded: set[str] | None = None
    extra_attrs = qs_type.extra_attrs
    if extra_attrs and extra_attrs.immutable and isinstance(extra_attrs.immutable, set):
        recorded = set(extra_attrs.immutable)

    if len(qs_type.args) > 1:
        row_type = get_proper_type(qs_type.args[1])
        if isinstance(row_type, Instance) and helpers.is_model_type(row_type.type):
            return None
        if isinstance(row_type, TypedDictType):
            return set(row_type.items.keys())
        if isinstance(row_type, TupleType) and row_type.partial_fallback.type.has_base("typing.NamedTuple"):
            return {name for name, sym in row_type.partial_fallback.type.names.items() if sym.plugin_generated}
        # Plain tuple rows and flat rows don't encode column names. Consult the recorded
        # metadata here; checking it only when there is no row parameter would make it
        # unreachable for QuerySets, which always carry a row parameter.
        return recorded if recorded is not None else set()

    # Fallback to explicit metadata attached to the QuerySet Instance
    return recorded


def check_valid_attr_value(
    ctx: MethodContext,
    django_context: DjangoContext,
    model: DjangoModel,
    attr_name: str,
    *,
    new_attr_names: set[str] | None = None,
) -> bool:
    """
    Check if adding `attr_name` would conflict with existing symbols on `model`.

    On conflict a ``no-redef`` error is reported at ``ctx.context``.

    Args:
        - ctx: The method call being analyzed. If its receiver type restricts the selected
          columns (a prior .values/.values_list call), de-selected model fields do not conflict
          and selected non-field names do
        - model: The Django model being analyzed
        - attr_name: The name of the attribute to be added
        - new_attr_names: Names of attributes already being added in the current processing.
          The set is only read; it is never modified

    Returns:
        ``True`` if ``attr_name`` can be added, ``False`` if it conflicts.
    """
    deselected_fields: set[str] | None = None
    # Work on a private copy so that the caller's set is never mutated.
    pending_attr_names: set[str] = set(new_attr_names) if new_attr_names is not None else set()
    if isinstance(ctx.type, Instance):
        selected_fields = _get_selected_fields_from_queryset_type(ctx.type)
        if selected_fields is not None:
            model_field_names = {f.name for f in django_context.get_model_fields(model.cls)}
            deselected_fields = model_field_names - selected_fields
            # Selected names that are not model fields (e.g. `.values(alias=...)` expressions)
            # also occupy a name on the resulting rows.
            pending_attr_names.update(selected_fields - model_field_names)

    is_conflicting_attr_value = bool(
        # 1. Conflict with another symbol on the model (If not de-selected via a prior .values/.values_list call).
        # Ex:
        #     User.objects.prefetch_related(Prefetch(..., to_attr="id"))
        (model.typ.type.get(attr_name) and (deselected_fields is None or attr_name not in deselected_fields))
        # 2. Conflict with a previous annotation.
        # Ex:
        #     User.objects.annotate(foo=...).prefetch_related(Prefetch(...,to_attr="foo"))
        #     User.objects.prefetch_related(Prefetch(...,to_attr="foo")).prefetch_related(Prefetch(...,to_attr="foo"))
        or (model.typ.extra_attrs and attr_name in model.typ.extra_attrs.attrs)
        # 3. Conflict with another symbol added in the current processing.
        # Ex:
        #     User.objects.prefetch_related(
        #        Prefetch("groups", Group.objects.filter(name="test"), to_attr="new_attr"),
        #        Prefetch("groups", Group.objects.all(), to_attr="new_attr"), # E: Not OK!
        #     )
        or attr_name in pending_attr_names
    )
    if is_conflicting_attr_value:
        ctx.api.fail(
            f'Attribute "{attr_name}" already defined on "{model.typ}"',
            ctx.context,
            code=NO_REDEF,
        )
    return not is_conflicting_attr_value


def check_valid_prefetch_related_lookup(
    ctx: MethodContext,
    lookup: str,
    django_model: DjangoModel,
    django_context: DjangoContext,
    *,
    is_generic_prefetch: bool = False,
) -> bool:
    """
    Check if a lookup string resolves to something that can be prefetched.

    The lookup is walked one ``__``-separated part at a time, following the relation
    descriptor found on the current model class. At the first part that cannot be found
    or cannot be prefetched, an error is reported at ``ctx.context`` and ``False`` is
    returned. Reaching a ``GenericForeignKey`` ends the walk successfully, since it may
    point to any model.

    For ``GenericPrefetch`` (``is_generic_prefetch=True``; enforced only when
    ``django.contrib.contenttypes`` is installed), the final part must be a
    ``GenericForeignKey``. Intermediate parts are traversed like any other relation.
    """
    current_model_cls = django_model.cls
    contenttypes_installed = django_context.apps_registry.is_installed("django.contrib.contenttypes")
    through_attrs = lookup.split("__")
    last_index = len(through_attrs) - 1
    for index, through_attr in enumerate(through_attrs):
        rel_obj_descriptor = getattr(current_model_cls, through_attr, None)
        if rel_obj_descriptor is None:
            ctx.api.fail(
                (
                    f'Cannot find "{through_attr}" on "{current_model_cls.__name__}" object, '
                    f'"{lookup}" is an invalid parameter to "prefetch_related()"'
                ),
                ctx.context,
            )
            return False
        if contenttypes_installed and is_generic_prefetch and index == last_index:
            from django.contrib.contenttypes.fields import GenericForeignKey

            if not isinstance(rel_obj_descriptor, GenericForeignKey):
                ctx.api.fail(
                    f'"{through_attr}" on "{current_model_cls.__name__}" is not a GenericForeignKey, '
                    f"GenericPrefetch can only be used with GenericForeignKey fields",
                    ctx.context,
                )
                return False
            # A GenericForeignKey can point to any model; nothing further to traverse.
            return True
        if isinstance(rel_obj_descriptor, ForwardManyToOneDescriptor):
            current_model_cls = rel_obj_descriptor.field.remote_field.model
        elif isinstance(rel_obj_descriptor, ReverseOneToOneDescriptor):
            current_model_cls = rel_obj_descriptor.related.related_model  # type:ignore[assignment] # Can't be 'self' for non abstract models
        elif isinstance(rel_obj_descriptor, ManyToManyDescriptor):
            current_model_cls = (
                rel_obj_descriptor.rel.related_model if rel_obj_descriptor.reverse else rel_obj_descriptor.rel.model  # type:ignore[assignment] # Can't be 'self' for non abstract models
            )
        elif isinstance(rel_obj_descriptor, ReverseManyToOneDescriptor):
            if contenttypes_installed:
                from django.contrib.contenttypes.fields import ReverseGenericManyToOneDescriptor

                if isinstance(rel_obj_descriptor, ReverseGenericManyToOneDescriptor):
                    current_model_cls = rel_obj_descriptor.rel.model
                    continue
            current_model_cls = rel_obj_descriptor.rel.related_model  # type:ignore[assignment] # Can't be 'self' for non abstract models
        else:
            if contenttypes_installed:
                from django.contrib.contenttypes.fields import GenericForeignKey

                if isinstance(rel_obj_descriptor, GenericForeignKey):
                    # Generic foreign keys can point to any model, so we use Model as the base type
                    return True
            ctx.api.fail(
                (
                    f'"{lookup}" does not resolve to an item that supports prefetching '
                    '- this is an invalid parameter to "prefetch_related()"'
                ),
                ctx.context,
            )
            return False
    return True


def check_conflicting_lookups(
    ctx: MethodContext,
    observed_attr: str,
    qs_types: dict[str, Instance | None],
    queryset_type: Instance | None,
) -> bool:
    """
    Report a lookup that was already seen with a different queryset type.

    ``qs_types`` maps lookups seen so far to their queryset type; it is only read here.
    Returns ``True`` (after reporting a ``no-redef`` error) when ``observed_attr`` is
    already present with a different queryset type, otherwise ``False``.
    """
    seen_before = observed_attr in qs_types
    if not (seen_before and qs_types[observed_attr] != queryset_type):
        return False

    ctx.api.fail(
        f'Lookup "{observed_attr}" was already seen with a different queryset',
        ctx.context,
        code=NO_REDEF,
    )
    return True


def extract_prefetch_related_annotations(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
    """
    Extract annotated attributes via `prefetch_related(Prefetch(..., to_attr=...))`

    For every argument:
        - plain string lookups are validated and checked for conflicting querysets;
        - ``Prefetch`` objects with ``to_attr`` add a ``list[<model>]`` attribute to the model
          (``list[Any]`` when the prefetched model cannot be determined);
        - ``Prefetch`` objects without ``to_attr`` have their lookup validated.

    When at least one attribute is added, the returned QuerySet's model type (and a
    model-instance row type) become the annotated model; otherwise mypy's default return
    type is returned unchanged.

    See https://docs.djangoproject.com/en/5.2/ref/models/querysets/#prefetch-objects
    """
    api = helpers.get_typechecker_api(ctx)

    if not (
        isinstance(ctx.type, Instance)
        and isinstance((default_return_type := get_proper_type(ctx.default_return_type)), Instance)
        and (qs_model := helpers.get_model_info_from_qs_ctx(ctx, django_context)) is not None
        and ctx.args
        and ctx.arg_types
        and ctx.arg_types[0]
        # Only process the correct overload, i.e.
        #     def prefetch_related(self, *lookups: str | Prefetch[_PrefetchedQuerySetT, _ToAttrT]) -> Self: ...
        and None not in ctx.callee_arg_names
    ):
        return ctx.default_return_type

    new_attrs: dict[str, MypyType] = {}  # A mapping of field_name / types to add to the model
    qs_types: dict[str, Instance | None] = {}  # A mapping of field_name / associated queryset type
    for expr, typ in gather_flat_args(ctx):
        if not (isinstance(typ, Instance) and typ.type.has_base(fullnames.PREFETCH_CLASS_FULLNAME)):
            # Handle plain string lookups (not Prefetch instances)
            lookup = helpers.get_literal_str_type(typ)
            queryset_type = None
            if lookup is not None:
                check_valid_prefetch_related_lookup(ctx, lookup, qs_model, django_context)
                check_conflicting_lookups(ctx, lookup, qs_types, queryset_type)
                qs_types[lookup] = queryset_type
            continue

        # 1) Extract lookup value from specialized type arg or call expression
        lookup = _resolve_prefetch_string_argument(typ.args[0], expr, django_context, "lookup")

        # 2) Extract to_attr value from specialized type arg or call expression
        to_attr = _resolve_prefetch_string_argument(typ.args[2], expr, django_context, "to_attr")
        if to_attr is None and lookup is None:
            # Nothing statically known about this Prefetch; skip it.
            continue

        # 3.a) Determine queryset type from specialized type arg or call expression
        queryset_type = _resolve_prefetch_queryset_argument(typ.args[1], expr, api)

        # 3.b) Extract model type from queryset type (or from the lookup value)
        elem_model: Instance | None = None
        if queryset_type is not None and isinstance(queryset_type, Instance):
            elem_model = helpers.extract_model_type_from_queryset(queryset_type, api)
        elif lookup:
            # Without a queryset, use the model the lookup path ends on.
            try:
                observed_model_cls = django_context.resolve_lookup_into_field(qs_model.cls, lookup)[1]
                if model_info := helpers.lookup_class_typeinfo(api, observed_model_cls):
                    elem_model = Instance(model_info, [])
            except (FieldError, LookupsAreUnsupported):
                pass

        # 4) With to_attr: add a list attribute unless its name conflicts (conflicts are reported).
        if to_attr and check_valid_attr_value(
            ctx, django_context, qs_model, to_attr, new_attr_names=set(new_attrs.keys())
        ):
            new_attrs[to_attr] = api.named_generic_type(
                "builtins.list",
                [elem_model if elem_model is not None else AnyType(TypeOfAny.special_form)],
            )
            qs_types[to_attr] = queryset_type
        # 5) Without to_attr: the results are stored under the lookup itself, so validate it.
        if not to_attr and lookup:
            check_valid_prefetch_related_lookup(
                ctx,
                lookup,
                qs_model,
                django_context,
                is_generic_prefetch=typ.type.has_base(fullnames.GENERIC_PREFETCH_CLASS_FULLNAME),
            )
            check_conflicting_lookups(ctx, lookup, qs_types, queryset_type)
            qs_types[lookup] = queryset_type

    if not new_attrs:
        return ctx.default_return_type

    fields_dict = helpers.make_typeddict(
        api,
        fields=new_attrs,
        required_keys=set(new_attrs.keys()),
        readonly_keys=set(),
    )

    annotated_model = get_annotated_type(api, qs_model.typ, fields_dict=fields_dict)

    # Keep row shape; if row is a model instance, update it to annotated
    # Todo: consolidate with `extract_proper_type_queryset_annotate` row handling above.
    if len(default_return_type.args) > 1:
        original_row = get_proper_type(default_return_type.args[1])
        row_type: MypyType = original_row
        if isinstance(original_row, Instance) and helpers.is_model_type(original_row.type):
            row_type = annotated_model
    else:
        row_type = annotated_model

    return default_return_type.copy_modified(args=[annotated_model, row_type])


def _get_select_related_field_choices(model_cls: type[Model]) -> set[str]:
    """
    Collect the names ``select_related()`` accepts for a single hop starting at ``model_cls``.

    Mirrors ``_get_field_choices`` inside Django's ``SQLCompiler.get_related_selections``:
    forward fields that are relations contribute their ``name``, and reverse relations whose
    originating field is unique contribute that field's ``related_query_name()``.
    """
    meta = model_cls._meta

    # Forward relations declared on this model.
    choices: set[str] = {field.name for field in meta.fields if field.is_relation}

    # Reverse relations: only those backed by a unique field (a single related row) qualify.
    choices.update(rel.field.related_query_name() for rel in meta.related_objects if rel.field.unique)
    return choices


def _validate_select_related_lookup(
    ctx: MethodContext,
    django_context: DjangoContext,
    model_cls: type[Model],
    lookup: str,
) -> bool:
    """
    Check a single ``select_related()`` lookup string one ``__``-separated hop at a time.

    Each segment must be one of the selectable relation names of the model reached by the
    preceding segments (see ``_get_select_related_field_choices``). The first invalid segment, or
    a blank lookup, is reported through ``ctx.api.fail``.

    Returns True if every segment is valid. Returns False on the first problem, including an
    unreported failure to resolve an intermediate segment into a field.
    """
    if lookup.strip() == "":
        ctx.api.fail(
            f'Invalid field name "{lookup}" in select_related lookup',
            ctx.context,
        )
        return False

    segments = lookup.split("__")
    final_position = len(segments) - 1
    current_model = model_cls

    for position, segment in enumerate(segments):
        allowed = _get_select_related_field_choices(current_model)
        if segment not in allowed:
            listed = ", ".join(sorted(allowed)) or "(none)"
            ctx.api.fail(
                f'Invalid field name "{segment}" in select_related lookup. Choices are: {listed}',
                ctx.context,
            )
            return False

        if position == final_position:
            # Nothing left to descend into.
            break

        # Walk to the related model so the next segment is checked against its relations.
        try:
            resolved_field, current_model = django_context.resolve_lookup_into_field(current_model, segment)
        except (FieldError, LookupsAreUnsupported):
            # Should be unreachable: `segment` was just validated against the allowed names.
            return False
        if resolved_field is None:
            return False

    return True


def validate_select_related(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
    """
    Check that every literal string given to ``select_related()`` names a real relation chain.

    Arguments that are not string literals are skipped. Problems are reported through
    ``ctx.api.fail``; the inferred return type is always passed through unchanged.

    Adapted from ``django.db.models.sql.compiler.SQLCompiler.get_related_selections``.
    """
    if not isinstance(ctx.type, Instance):
        return ctx.default_return_type

    django_model = helpers.get_model_info_from_qs_ctx(ctx, django_context)
    if django_model is None or not ctx.arg_types or not ctx.arg_types[0]:
        return ctx.default_return_type

    # Lazily map each positional argument type to its literal string value (or None).
    literal_lookups = (helpers.get_literal_str_type(get_proper_type(arg_type)) for arg_type in ctx.arg_types[0])
    for lookup in literal_lookups:
        if lookup is None:
            continue
        _validate_select_related_lookup(ctx, django_context, django_model.cls, lookup)

    return ctx.default_return_type


def _validate_bulk_update_field(
    ctx: MethodContext, model_cls: type[Model], field_name: str, method: Literal["bulk_update", "abulk_update"]
) -> bool:
    """
    Check that ``field_name`` may appear in the ``fields`` argument of ``bulk_update()``.

    Mirrors the runtime checks in ``django.db.models.query.QuerySet.bulk_update``. The field must
    exist on ``model_cls``, must be concrete and not many-to-many, and must not be a primary key
    field of the model or of any of its parents. Each failure is reported through ``ctx.api.fail``.

    Returns True if the field is acceptable, False after an error has been reported.
    """
    opts = model_cls._meta
    try:
        field = opts.get_field(field_name)
    except FieldDoesNotExist as e:
        ctx.api.fail(str(e), ctx.context)
        return False

    if not field.concrete or field.many_to_many:
        ctx.api.fail(f'"{method}()" can only be used with concrete fields. Got "{field_name}"', ctx.context)
        return False

    # `Options.pk_fields` / `Options.all_parents` may be missing on some Django versions; fall back
    # to `Options.pk` / `Options.get_parent_list()` then. The fallbacks are resolved only when the
    # attribute is absent: as `getattr` default arguments they would always be evaluated.
    own_pk_fields = getattr(opts, "pk_fields", None)
    all_pk_fields = set(own_pk_fields if own_pk_fields is not None else [opts.pk])

    parents = getattr(opts, "all_parents", None)
    if parents is None:
        parents = opts.get_parent_list()
    for parent in parents:
        parent_pk_fields = getattr(parent._meta, "pk_fields", None)
        all_pk_fields.update(parent_pk_fields if parent_pk_fields is not None else [parent._meta.pk])

    if field in all_pk_fields:
        ctx.api.fail(f'"{method}()" cannot be used with primary key fields. Got "{field_name}"', ctx.context)
        return False

    return True


def validate_bulk_update(
    ctx: MethodContext, django_context: DjangoContext, method: Literal["bulk_update", "abulk_update"]
) -> MypyType:
    """
    Type check the ``fields`` argument passed to ``QuerySet.bulk_update(...)`` / ``abulk_update(...)``.

    Only a literal list, tuple or set display is inspected. An empty display is reported. Each
    item that resolves to a string is checked with ``_validate_bulk_update_field``; other items
    are skipped. The inferred return type is always passed through unchanged.

    Extracted and adapted from ``django.db.models.query.QuerySet.bulk_update``.
    Mirrors tests from ``django/tests/queries/test_bulk_update.py``.
    """
    if not isinstance(ctx.type, Instance):
        return ctx.default_return_type

    django_model = helpers.get_model_info_from_qs_ctx(ctx, django_context)
    # The second formal argument (after `objs`) is `fields`.
    if django_model is None or len(ctx.args) < 2 or not ctx.args[1]:
        return ctx.default_return_type

    fields_expr = ctx.args[1][0]
    if not isinstance(fields_expr, (ListExpr, TupleExpr, SetExpr)):
        return ctx.default_return_type

    if not fields_expr.items:
        ctx.api.fail(f'Field names must be given to "{method}()"', ctx.context)
        return ctx.default_return_type

    for item in fields_expr.items:
        name = helpers.resolve_string_attribute_value(item, django_context)
        if name is None:
            continue
        _validate_bulk_update_field(ctx, django_model.cls, name, method)

    return ctx.default_return_type