1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250
|
from typing import TYPE_CHECKING, Any, NamedTuple, Union, cast
from django.core.exceptions import FieldDoesNotExist
from django.db.models.fields import AutoField, Field
from django.db.models.fields.related import RelatedField
from django.db.models.fields.reverse_related import ForeignObjectRel
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import AssignmentStmt, NameExpr, TypeInfo
from mypy.plugin import FunctionContext
from mypy.types import AnyType, Instance, NoneType, ProperType, TypeOfAny, UninhabitedType, UnionType, get_proper_type
from mypy.types import Type as MypyType
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.transformers import manytomany
if TYPE_CHECKING:
from django.contrib.contenttypes.fields import GenericForeignKey
def _get_current_field_from_assignment(
    ctx: FunctionContext, django_context: DjangoContext
) -> Union["Field[Any, Any]", ForeignObjectRel, "GenericForeignKey"] | None:
    """Resolve the runtime Django field created by the call currently being type-checked.

    Scans the body of the enclosing model class for the first assignment whose right-hand side
    is ``ctx.context``. It then uses that assignment's target name to fetch the field from the
    runtime model class's ``_meta``.

    Returns ``None`` in any of these cases:
    - there is no enclosing model class;
    - no matching assignment exists;
    - the assignment target is not a plain name;
    - ``django_context`` does not know the model class;
    - the model has no field with that name.
    """
    model_info = helpers.get_typechecker_api(ctx).scope.active_class()
    if model_info is None or not helpers.is_model_type(model_info):
        return None

    assignment = next(
        (
            stmt
            for stmt in model_info.defn.defs.body
            if isinstance(stmt, AssignmentStmt) and stmt.rvalue == ctx.context
        ),
        None,
    )
    if assignment is None:
        return None

    target = assignment.lvalues[0]
    if not isinstance(target, NameExpr):
        # The target is not a plain name, so there is no single field name to look up.
        return None

    model_cls = django_context.get_model_class_by_fullname(model_info.fullname)
    if model_cls is None:
        return None

    try:
        return model_cls._meta.get_field(target.name)
    except FieldDoesNotExist:
        return None
def reparametrize_related_field_type(related_field_type: Instance, set_type: MypyType, get_type: MypyType) -> Instance:
    """Return a copy of *related_field_type* whose first two type arguments are re-filled.

    Argument 0 is passed through ``helpers.convert_any_to_type`` with *set_type*, and argument 1
    with *get_type*. The copy carries exactly these two arguments.
    """
    current_set_arg = related_field_type.args[0]
    current_get_arg = related_field_type.args[1]
    new_args = [
        helpers.convert_any_to_type(current_set_arg, set_type),
        helpers.convert_any_to_type(current_get_arg, get_type),
    ]
    return related_field_type.copy_modified(args=new_args)
def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
    """Compute the parametrized type of a related-field constructor call inside a model body.

    The returned field type reads as the related model. It accepts assignment of the related
    model, or of the model a proxy stands in for when the related model is a proxy.

    When an abstract model relates to itself, every registered concrete subclass additionally
    gets a symbol of the same name whose set and get types are that subclass.

    Returns ``AnyType(TypeOfAny.from_error)`` when the field cannot be resolved or its related
    model is not registered.
    """
    current_field = _get_current_field_from_assignment(ctx, django_context)
    if current_field is None:
        return AnyType(TypeOfAny.from_error)
    assert isinstance(current_field, RelatedField)

    try:
        target_model_cls = django_context.get_field_related_model_cls(current_field)
    except UnregisteredModelError:
        return AnyType(TypeOfAny.from_error)

    field_type = set_descriptor_types_for_field(ctx)
    api = helpers.get_typechecker_api(ctx)

    # self reference with abstract=True on the model where the related field is defined
    owner_model_cls = current_field.model
    is_abstract_self_reference = owner_model_cls._meta.abstract and owner_model_cls == target_model_cls
    if is_abstract_self_reference:
        # Re-target the field at each concrete subclass by installing a per-subclass symbol.
        for candidate_cls in django_context.all_registered_model_classes:
            if not issubclass(candidate_cls, owner_model_cls) or candidate_cls._meta.abstract:
                continue
            candidate_info = helpers.lookup_class_typeinfo(api, candidate_cls)
            if candidate_info is None:
                continue
            self_ref_type = Instance(candidate_info, [])
            subclass_field_type = reparametrize_related_field_type(
                field_type, set_type=self_ref_type, get_type=self_ref_type
            )
            helpers.add_new_sym_for_info(candidate_info, name=current_field.name, sym_type=subclass_field_type)

    def _instance_or_any(model_cls: type[Any]) -> ProperType:
        # Fall back to an unannotated Any when the model has no TypeInfo (maybe no type stub).
        model_info = helpers.lookup_class_typeinfo(api, model_cls)
        if model_info is None:
            return AnyType(TypeOfAny.unannotated)
        return Instance(model_info, [])

    set_model_cls = target_model_cls
    if target_model_cls._meta.proxy_for_model is not None:
        set_model_cls = target_model_cls._meta.proxy_for_model

    read_type = _instance_or_any(target_model_cls)
    write_type = _instance_or_any(set_model_cls)
    # replace Any in the field's generic arguments with the resolved model types
    return reparametrize_related_field_type(field_type, set_type=write_type, get_type=read_type)
class FieldDescriptorTypes(NamedTuple):
    """The pair of mypy types describing a model-field descriptor.

    ``set`` is the type on the assignment (``__set__``) side, and ``get`` is the type on the
    attribute-access (``__get__``) side.
    """

    set: MypyType  # assignment-side type (from ``_pyi_private_set_type``)
    get: MypyType  # access-side type (from ``_pyi_private_get_type``)
def get_field_descriptor_types(
    field_info: TypeInfo, *, is_set_nullable: bool, is_get_nullable: bool
) -> FieldDescriptorTypes:
    """Read a field class's private set/get descriptor types.

    The types are looked up from the ``_pyi_private_set_type`` and ``_pyi_private_get_type``
    attributes of *field_info*. Each is made nullable according to its respective flag.
    """
    return FieldDescriptorTypes(
        set=helpers.get_private_descriptor_type(field_info, "_pyi_private_set_type", is_nullable=is_set_nullable),
        get=helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=is_get_nullable),
    )
def set_descriptor_types_for_field_callback(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
    """Fill descriptor types for a plain (non-related, non-array) model field call.

    The set type is made nullable when the resolved field is an ``AutoField``. Unresolved
    fields (``None``) use the defaults, since ``isinstance(None, AutoField)`` is False.
    """
    resolved_field = _get_current_field_from_assignment(ctx, django_context)
    return set_descriptor_types_for_field(ctx, is_set_nullable=isinstance(resolved_field, AutoField))
def set_descriptor_types_for_field(
    ctx: FunctionContext, *, is_set_nullable: bool = False, is_get_nullable: bool = False
) -> Instance:
    """Fill in the set/get type arguments of a model-field constructor call.

    Reads the call's ``null``, ``primary_key`` and ``default`` keyword arguments and looks up the
    field class's private descriptor types. It then reconciles them with the type arguments the
    class maps onto the base ``Field`` class.

    Args:
        ctx: mypy function-hook context for the field constructor call.
        is_set_nullable: request a set type that accepts ``None``. This is only ever widened
            here, never cleared.
        is_get_nullable: request a get type that includes ``None``.

    Returns:
        ``ctx.default_return_type`` re-parametrized with ``[set_type, get_type]``. If the field
        is nullable but its mapped get type is not optional, an error is reported on
        ``ctx.context`` and the type is still returned.
    """
    default_return_type = cast("Instance", ctx.default_return_type)
    is_nullable = helpers.get_bool_call_argument_by_name(ctx, "null", default=False)
    is_primary_key = helpers.get_bool_call_argument_by_name(ctx, "primary_key", default=False)
    # Allow setting field value to `None` when a field is primary key and has a default that can
    # produce a value. Only widen the flag: a caller that already asked for a nullable set type
    # (e.g. the AutoField path) must not have that request discarded just because `default=` is given.
    default_expr = helpers.get_call_argument_by_name(ctx, "default")
    if default_expr is not None and is_primary_key:
        is_set_nullable = True

    set_type, get_type = get_field_descriptor_types(
        default_return_type.type,
        is_set_nullable=is_set_nullable or is_nullable,
        is_get_nullable=is_get_nullable or is_nullable,
    )

    # reconcile set and get types with the base field class
    base_field_type = next(base for base in default_return_type.type.mro if base.fullname == fullnames.FIELD_FULLNAME)
    mapped_instance = map_instance_to_supertype(default_return_type, base_field_type)
    mapped_set_type, mapped_get_type = tuple(get_proper_type(arg) for arg in mapped_instance.args)

    # bail if either mapped_set_type or mapped_get_type have type Never
    if not (isinstance(mapped_set_type, UninhabitedType) or isinstance(mapped_get_type, UninhabitedType)):
        # always replace set_type and get_type with (non-Any) mapped types
        set_type = helpers.convert_any_to_type(mapped_set_type, set_type)
        get_type = get_proper_type(helpers.convert_any_to_type(mapped_get_type, get_type))

        # the get_type must be optional if the field is nullable
        if (is_get_nullable or is_nullable) and not (
            isinstance(get_type, NoneType) or helpers.is_optional(get_type) or isinstance(get_type, AnyType)
        ):
            ctx.api.fail(
                f"{default_return_type.type.name} is nullable but its generic get type parameter is not optional",
                ctx.context,
            )

    return default_return_type.copy_modified(args=[set_type, get_type])
def _drop_combinable(typ: MypyType) -> MypyType | None:
    """Return *typ* with every ``Combinable``-derived member removed, or ``None`` if nothing remains.

    A union is rebuilt from its surviving members, preserving the original union's line, column,
    ``is_evaluated`` and PEP 604 flags. It collapses to the lone survivor when exactly one remains.
    """
    proper = get_proper_type(typ)
    if isinstance(proper, Instance) and proper.type.has_base(fullnames.COMBINABLE_EXPRESSION_FULLNAME):
        return None
    if not isinstance(proper, UnionType):
        return proper

    survivors = [kept for kept in map(_drop_combinable, proper.items) if kept is not None]
    if not survivors:
        return None
    if len(survivors) == 1:
        return survivors[0]
    return UnionType(
        survivors,
        line=proper.line,
        column=proper.column,
        is_evaluated=proper.is_evaluated,
        uses_pep604_syntax=proper.uses_pep604_syntax,
    )


def determine_type_of_array_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
    """Type an ``ArrayField`` call from the type of its ``base_field`` argument.

    Starts from the array field's own descriptor types. When ``base_field`` is an ``Instance``,
    its two type arguments are substituted into the array field's arguments via
    ``helpers.convert_any_to_type``, after any ``Combinable`` members are stripped. An error is
    reported if an argument is nothing but ``Combinable``; that argument is then used unchanged.
    """
    array_type = set_descriptor_types_for_field(ctx)
    element_field_type = get_proper_type(helpers.get_call_argument_type_by_name(ctx, "base_field"))
    if not (element_field_type and isinstance(element_field_type, Instance)):
        return array_type

    # Both base_field and return type should derive from Field and thus expect 2 arguments
    assert len(element_field_type.args) == len(array_type.args) == 2

    merged_args: list[MypyType] = []
    for element_arg, array_arg in zip(element_field_type.args, array_type.args, strict=False):
        stripped = _drop_combinable(element_arg)
        if stripped is None:
            ctx.api.fail(
                f"Can't have ArrayField expecting {fullnames.COMBINABLE_EXPRESSION_FULLNAME!r} as data type",
                ctx.context,
            )
            stripped = element_arg
        merged_args.append(helpers.convert_any_to_type(array_arg, stripped))
    return array_type.copy_modified(args=merged_args)
def transform_into_proper_return_type(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
    """Route a field constructor call to the transformer matching its field class.

    Calls made outside a model class body keep ``ctx.default_return_type``. Otherwise the field
    class is checked, in this order, for a ManyToMany base, any related-field base, and an
    ArrayField base. Anything else goes to the plain-field callback.
    """
    field_type = get_proper_type(ctx.default_return_type)
    assert isinstance(field_type, Instance)

    model_info = helpers.get_typechecker_api(ctx).scope.active_class()
    if model_info is None or not helpers.is_model_type(model_info):
        return ctx.default_return_type
    assert isinstance(model_info, TypeInfo)

    field_info = field_type.type
    # ManyToMany is tested first; presumably it would also match RELATED_FIELDS_CLASSES.
    if field_info.has_base(fullnames.MANYTOMANY_FIELD_FULLNAME):
        return manytomany.fill_model_args_for_many_to_many_field(
            ctx=ctx, model_info=model_info, django_context=django_context
        )
    elif helpers.has_any_of_bases(field_info, fullnames.RELATED_FIELDS_CLASSES):
        return fill_descriptor_types_for_related_field(ctx, django_context)
    elif field_info.has_base(fullnames.ARRAY_FIELD_FULLNAME):
        return determine_type_of_array_field(ctx, django_context)
    else:
        return set_descriptor_types_for_field_callback(ctx, django_context)
|