File: fields.py

package info (click to toggle)
python-django-stubs 5.2.9-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 4,832 kB
  • sloc: python: 5,185; makefile: 15; sh: 8
file content (250 lines) | stat: -rw-r--r-- 11,274 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
from typing import TYPE_CHECKING, Any, NamedTuple, Union, cast

from django.core.exceptions import FieldDoesNotExist
from django.db.models.fields import AutoField, Field
from django.db.models.fields.related import RelatedField
from django.db.models.fields.reverse_related import ForeignObjectRel
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import AssignmentStmt, NameExpr, TypeInfo
from mypy.plugin import FunctionContext
from mypy.types import AnyType, Instance, NoneType, ProperType, TypeOfAny, UninhabitedType, UnionType, get_proper_type
from mypy.types import Type as MypyType

from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.transformers import manytomany

if TYPE_CHECKING:
    from django.contrib.contenttypes.fields import GenericForeignKey


def _get_current_field_from_assignment(
    ctx: FunctionContext, django_context: DjangoContext
) -> Union["Field[Any, Any]", ForeignObjectRel, "GenericForeignKey"] | None:
    """Resolve the runtime Django field object for the call currently being type-checked.

    The enclosing class must be a model class. Its body is searched for the first
    assignment whose right-hand side is ``ctx.context`` (the call being checked).
    The assignment's target name is then looked up via the model's ``_meta.get_field``.

    Returns None when:
      * there is no enclosing class, or it is not a model type;
      * no matching assignment exists, or its first target is not a plain name;
      * ``django_context`` does not know the model class;
      * ``_meta.get_field`` raises ``FieldDoesNotExist``.
    """
    enclosing_class = helpers.get_typechecker_api(ctx).scope.active_class()
    if enclosing_class is None or not helpers.is_model_type(enclosing_class):
        return None

    # Only the first assignment whose value is this very call matters.
    owning_assignment = next(
        (
            statement
            for statement in enclosing_class.defn.defs.body
            if isinstance(statement, AssignmentStmt) and statement.rvalue == ctx.context
        ),
        None,
    )
    if owning_assignment is None:
        return None

    assignment_target = owning_assignment.lvalues[0]
    if not isinstance(assignment_target, NameExpr):
        # e.g. tuple unpacking or attribute targets: no single field name to resolve.
        return None

    runtime_model = django_context.get_model_class_by_fullname(enclosing_class.fullname)
    if runtime_model is None:
        return None

    try:
        return runtime_model._meta.get_field(assignment_target.name)
    except FieldDoesNotExist:
        return None


def reparametrize_related_field_type(related_field_type: Instance, set_type: MypyType, get_type: MypyType) -> Instance:
    """Return a copy of ``related_field_type`` with its two type arguments rewritten.

    The first type argument is passed through ``helpers.convert_any_to_type`` together
    with ``set_type``, and the second together with ``get_type``. The input instance
    itself is not modified; ``copy_modified`` produces the result.
    """
    current_args = related_field_type.args
    new_set_arg = helpers.convert_any_to_type(current_args[0], set_type)
    new_get_arg = helpers.convert_any_to_type(current_args[1], get_type)
    return related_field_type.copy_modified(args=[new_set_arg, new_get_arg])


def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
    """Compute the type of a related-field call (e.g. ForeignKey) declared in a model body.

    The field's descriptor type from ``set_descriptor_types_for_field`` is reparametrized
    with the related model:
      * the get type uses the related model class;
      * the set type uses its ``_meta.proxy_for_model`` when that is not None.
    A model without a TypeInfo is represented as an unannotated ``Any``.

    For a self-reference declared on an abstract model, every registered non-abstract
    subclass additionally gets a symbol with the field's name. That symbol is typed
    with the subclass as both set and get type.

    Returns ``AnyType(TypeOfAny.from_error)`` when the field cannot be resolved or its
    related model is not registered.
    """
    current_field = _get_current_field_from_assignment(ctx, django_context)
    if current_field is None:
        return AnyType(TypeOfAny.from_error)

    assert isinstance(current_field, RelatedField)

    try:
        related_model_cls = django_context.get_field_related_model_cls(current_field)
    except UnregisteredModelError:
        return AnyType(TypeOfAny.from_error)

    descriptor_type = set_descriptor_types_for_field(ctx)
    checker_api = helpers.get_typechecker_api(ctx)

    def model_to_mypy_type(model: type[Any]) -> ProperType:
        # maybe no type stub: fall back to an unannotated Any
        model_info = helpers.lookup_class_typeinfo(checker_api, model)
        if model_info is None:
            return AnyType(TypeOfAny.unannotated)
        return Instance(model_info, [])

    # self reference with abstract=True on the model where ForeignKey is defined
    declaring_model_cls = current_field.model
    if declaring_model_cls._meta.abstract and declaring_model_cls == related_model_cls:
        # for all derived non-abstract classes, set variable with this name to
        # __get__/__set__ of ForeignKey of derived model
        concrete_subclasses = (
            candidate
            for candidate in django_context.all_registered_model_classes
            if issubclass(candidate, declaring_model_cls) and not candidate._meta.abstract
        )
        for subclass in concrete_subclasses:
            subclass_info = helpers.lookup_class_typeinfo(checker_api, subclass)
            if subclass_info is None:
                continue
            subclass_type = Instance(subclass_info, [])
            helpers.add_new_sym_for_info(
                subclass_info,
                name=current_field.name,
                sym_type=reparametrize_related_field_type(
                    descriptor_type, set_type=subclass_type, get_type=subclass_type
                ),
            )

    # Assignment accepts the concrete model behind a proxy; access yields the proxy itself.
    proxied_model = related_model_cls._meta.proxy_for_model
    model_accepted_on_set = related_model_cls if proxied_model is None else proxied_model

    get_side_type = model_to_mypy_type(related_model_cls)
    set_side_type = model_to_mypy_type(model_accepted_on_set)

    # replace Any with referred_to_type
    return reparametrize_related_field_type(descriptor_type, set_type=set_side_type, get_type=get_side_type)


class FieldDescriptorTypes(NamedTuple):
    """The pair of mypy types describing a model field attribute: what it accepts and what it yields."""

    # Type accepted when assigning to the field attribute; get_field_descriptor_types reads it
    # from the field class's `_pyi_private_set_type`.
    set: MypyType
    # Type produced when reading the field attribute; get_field_descriptor_types reads it
    # from the field class's `_pyi_private_get_type`.
    get: MypyType


def get_field_descriptor_types(
    field_info: TypeInfo, *, is_set_nullable: bool, is_get_nullable: bool
) -> FieldDescriptorTypes:
    """Read the private set/get descriptor types declared on a field class.

    Args:
        field_info: TypeInfo of the field class. It is expected to declare
            ``_pyi_private_set_type`` and ``_pyi_private_get_type``.
        is_set_nullable: forwarded as ``is_nullable`` when resolving the set type.
        is_get_nullable: forwarded as ``is_nullable`` when resolving the get type.

    Returns:
        A ``FieldDescriptorTypes`` whose ``set``/``get`` come from
        ``helpers.get_private_descriptor_type``. Presumably the helper wraps each
        type in Optional when the flag is True; that is not verified here.
    """
    return FieldDescriptorTypes(
        set=helpers.get_private_descriptor_type(field_info, "_pyi_private_set_type", is_nullable=is_set_nullable),
        get=helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=is_get_nullable),
    )


def set_descriptor_types_for_field_callback(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
    """Default type computation for a plain (non-related, non-array) field call.

    If the assigned field resolves to an ``AutoField``, its set type is made nullable.
    Presumably this is so a value can be left as None before the database assigns
    one; confirm. Every other field, including ones that cannot be resolved, uses the
    default non-nullable set type.
    """
    resolved_field = _get_current_field_from_assignment(ctx, django_context)
    # isinstance(None, AutoField) is False, so an unresolved field takes the default path.
    nullable_on_set = isinstance(resolved_field, AutoField)
    return set_descriptor_types_for_field(ctx, is_set_nullable=nullable_on_set)


def set_descriptor_types_for_field(
    ctx: FunctionContext, *, is_set_nullable: bool = False, is_get_nullable: bool = False
) -> Instance:
    """Fill in the set/get type arguments of the field instance created by the call in ``ctx``.

    The starting descriptor types come from the field class's private
    ``_pyi_private_set_type`` / ``_pyi_private_get_type`` declarations. Each is made
    nullable when the call passes ``null=True`` or the corresponding ``is_*_nullable``
    flag is set. The set type is also made nullable for a ``primary_key=True`` field
    that has a ``default=``.

    The field instance is mapped onto the base ``Field`` class. If neither mapped type
    argument is ``Never``, the mapped arguments win, combined with the descriptor types
    via ``helpers.convert_any_to_type``. In that case an error is reported when the
    field is nullable but the resulting get type is not optional.

    Args:
        ctx: mypy function-hook context for the field constructor call.
        is_set_nullable: request a nullable set (assignment) type. This flag is only
            ever widened here, never cleared.
        is_get_nullable: request a nullable get (access) type.

    Returns:
        ``ctx.default_return_type`` with its two type arguments replaced by the
        computed set and get types.
    """
    default_return_type = cast("Instance", ctx.default_return_type)

    is_nullable = helpers.get_bool_call_argument_by_name(ctx, "null", default=False)
    is_primary_key = helpers.get_bool_call_argument_by_name(ctx, "primary_key", default=False)
    # Allow setting field value to `None` when a field is primary key and has a default that can produce a value.
    # This may only widen the flag: assigning `is_primary_key` unconditionally would discard a nullable set type
    # the caller explicitly requested (set_descriptor_types_for_field_callback passes True for AutoField).
    default_expr = helpers.get_call_argument_by_name(ctx, "default")
    if default_expr is not None and is_primary_key:
        is_set_nullable = True

    set_type, get_type = get_field_descriptor_types(
        default_return_type.type,
        is_set_nullable=is_set_nullable or is_nullable,
        is_get_nullable=is_get_nullable or is_nullable,
    )

    # reconcile set and get types with the base field class
    # (the field class is expected to derive from fullnames.FIELD_FULLNAME; next() raises StopIteration otherwise)
    base_field_type = next(base for base in default_return_type.type.mro if base.fullname == fullnames.FIELD_FULLNAME)
    mapped_instance = map_instance_to_supertype(default_return_type, base_field_type)
    mapped_set_type, mapped_get_type = tuple(get_proper_type(arg) for arg in mapped_instance.args)

    # bail if either mapped_set_type or mapped_get_type have type Never
    if not (isinstance(mapped_set_type, UninhabitedType) or isinstance(mapped_get_type, UninhabitedType)):
        # always replace set_type and get_type with (non-Any) mapped types
        set_type = helpers.convert_any_to_type(mapped_set_type, set_type)
        get_type = get_proper_type(helpers.convert_any_to_type(mapped_get_type, get_type))

        # the get_type must be optional if the field is nullable
        if (is_get_nullable or is_nullable) and not (
            isinstance(get_type, NoneType) or helpers.is_optional(get_type) or isinstance(get_type, AnyType)
        ):
            ctx.api.fail(
                f"{default_return_type.type.name} is nullable but its generic get type parameter is not optional",
                ctx.context,
            )

    return default_return_type.copy_modified(args=[set_type, get_type])


def determine_type_of_array_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
    """Compute the type of an ``ArrayField(...)`` call from its ``base_field`` argument.

    The descriptor type from ``set_descriptor_types_for_field`` is refined with the two
    type arguments of the ``base_field`` instance, combined via
    ``helpers.convert_any_to_type``. Combinable members are removed from those
    arguments first. An argument that is nothing but Combinable triggers a type error,
    and that argument is then used unchanged. If ``base_field`` is missing or not an
    ``Instance``, the unrefined descriptor type is returned.
    """
    field_type = set_descriptor_types_for_field(ctx)

    base_field = get_proper_type(helpers.get_call_argument_type_by_name(ctx, "base_field"))
    if not (base_field and isinstance(base_field, Instance)):
        return field_type

    def strip_combinable(candidate: MypyType) -> MypyType | None:
        """Return ``candidate`` (as a proper type) without Combinable parts, or None if nothing remains."""
        proper = get_proper_type(candidate)
        if isinstance(proper, Instance):
            # A Combinable instance is dropped outright; any other instance is kept.
            return None if proper.type.has_base(fullnames.COMBINABLE_EXPRESSION_FULLNAME) else proper
        if not isinstance(proper, UnionType):
            return proper

        survivors = [item for item in map(strip_combinable, proper.items) if item is not None]
        if not survivors:
            return None
        if len(survivors) == 1:
            return survivors[0]
        # Rebuild the union, carrying over the original's position and syntax metadata.
        return UnionType(
            survivors,
            line=proper.line,
            column=proper.column,
            is_evaluated=proper.is_evaluated,
            uses_pep604_syntax=proper.uses_pep604_syntax,
        )

    # Both base_field and return type should derive from Field and thus expect 2 arguments
    assert len(base_field.args) == len(field_type.args) == 2

    refined_args = []
    for base_arg, default_arg in zip(base_field.args, field_type.args, strict=False):
        usable = strip_combinable(base_arg)
        if usable is None:
            ctx.api.fail(
                f"Can't have ArrayField expecting {fullnames.COMBINABLE_EXPRESSION_FULLNAME!r} as data type",
                ctx.context,
            )
            # Still produce a type after reporting: fall back to the untouched argument.
            usable = base_arg
        refined_args.append(helpers.convert_any_to_type(default_arg, usable))

    return field_type.copy_modified(args=refined_args)


def transform_into_proper_return_type(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
    """Entry point: refine the return type of a field constructor call made inside a model body.

    Outside a model class, the default return type is returned unchanged. Otherwise the
    call is dispatched by the field class's bases, first match wins:
    ManyToManyField, then the related-field classes, then ArrayField, then the generic
    field callback.
    """
    field_instance = get_proper_type(ctx.default_return_type)
    assert isinstance(field_instance, Instance)

    model_info = helpers.get_typechecker_api(ctx).scope.active_class()
    if model_info is None or not helpers.is_model_type(model_info):
        return ctx.default_return_type

    assert isinstance(model_info, TypeInfo)

    field_class_info = field_instance.type
    if field_class_info.has_base(fullnames.MANYTOMANY_FIELD_FULLNAME):
        return manytomany.fill_model_args_for_many_to_many_field(
            ctx=ctx, model_info=model_info, django_context=django_context
        )
    if helpers.has_any_of_bases(field_class_info, fullnames.RELATED_FIELDS_CLASSES):
        return fill_descriptor_types_for_related_field(ctx, django_context)
    if field_class_info.has_base(fullnames.ARRAY_FIELD_FULLNAME):
        return determine_type_of_array_field(ctx, django_context)
    return set_descriptor_types_for_field_callback(ctx, django_context)