1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387
|
from __future__ import annotations
import dataclasses
import functools
import itertools
import weakref
from collections.abc import Iterable
from typing import (
TYPE_CHECKING,
Any,
cast,
)
from django.db.models.query import Prefetch, QuerySet
from django.db.models.sql.where import WhereNode
from strawberry import Schema
from strawberry.types import has_object_definition
from strawberry.types.base import (
StrawberryContainer,
StrawberryObjectDefinition,
StrawberryType,
StrawberryTypeVar,
)
from strawberry.types.lazy_type import LazyType
from strawberry.types.union import StrawberryUnion
from strawberry.utils.str_converters import to_camel_case
from typing_extensions import TypeIs, assert_never
from strawberry_django.fields.types import resolve_model_field_name
from .pyutils import DictTree, dicttree_insersection_differs, dicttree_merge
from .typing import get_django_definition
if TYPE_CHECKING:
from collections.abc import Generator, Iterable
from django.db import models
from django.db.models.expressions import Expression
from django.db.models.fields import Field
from django.db.models.fields.reverse_related import ForeignObjectRel
from django.db.models.sql.query import Query
from model_utils.managers import (
InheritanceManagerMixin,
InheritanceQuerySetMixin,
)
from polymorphic.models import PolymorphicModel
@functools.lru_cache
def get_model_fields(
    model: type[models.Model],
    *,
    camel_case: bool = False,
    is_input: bool = False,
    is_filter: bool = False,
) -> dict[str, Field | ForeignObjectRel]:
    """Map every field of ``model`` to the name it is exposed under.

    Each entry of ``model._meta.get_fields()`` is keyed by the name produced by
    ``resolve_model_field_name`` (converted with ``to_camel_case`` when
    ``camel_case`` is set). If two fields resolve to the same name, the one
    that comes later in ``get_fields()`` wins.

    The result is memoized with ``functools.lru_cache``, so repeated calls with
    the same arguments return the very same dict instance; callers should not
    mutate it.
    """
    named_fields = (
        (
            cast(
                "str",
                resolve_model_field_name(
                    field, is_input=is_input, is_filter=is_filter
                ),
            ),
            field,
        )
        for field in model._meta.get_fields()
    )
    return {
        (to_camel_case(raw_name) if camel_case else raw_name): field
        for raw_name, field in named_fields
    }
def get_model_field(
    model: type[models.Model],
    field_name: str,
    *,
    camel_case: bool = False,
    is_input: bool = False,
    is_filter: bool = False,
) -> Field | ForeignObjectRel | None:
    """Return the field of ``model`` exposed as ``field_name``, or ``None``.

    The lookup goes through the (cached) mapping built by ``get_model_fields``
    using the same ``camel_case`` / ``is_input`` / ``is_filter`` options.
    """
    fields_by_name = get_model_fields(
        model,
        camel_case=camel_case,
        is_input=is_input,
        is_filter=is_filter,
    )
    return fields_by_name.get(field_name)
def get_possible_types(
    gql_type: StrawberryObjectDefinition | StrawberryType | type,
    *,
    object_definition: StrawberryObjectDefinition | None = None,
) -> Generator[type]:
    """Resolve all possible Python types for gql_type.

    Object definitions are unwrapped to their origin type, lazy types are
    resolved, containers are unwrapped and unions are expanded into their
    members. Other strawberry types (e.g. scalars) yield nothing.

    Args:
    ----
        gql_type:
            The type to retrieve possibilities from.
        object_definition:
            Optional object definition whose ``type_var_map`` is used to
            resolve type vars found in ``gql_type`` (including type vars
            nested inside containers or unions).
            This is used internally.

    Yields:
    ------
        All possibilities for the type

    """
    if isinstance(gql_type, StrawberryObjectDefinition):
        yield from get_possible_types(gql_type.origin, object_definition=gql_type)
    elif isinstance(gql_type, LazyType):
        yield from get_possible_types(gql_type.resolve_type())
    elif isinstance(gql_type, StrawberryTypeVar) and object_definition is not None:
        resolved = object_definition.type_var_map.get(gql_type.type_var.__name__, None)
        if resolved is not None:
            # The resolved type is concrete; no further type var context needed.
            yield from get_possible_types(resolved)
    elif isinstance(gql_type, StrawberryContainer):
        # Forward object_definition so a type var wrapped in a container
        # (e.g. ``list[T]``) can still be resolved.
        yield from get_possible_types(
            gql_type.of_type,
            object_definition=object_definition,
        )
    elif isinstance(gql_type, StrawberryUnion):
        for member in gql_type.types:
            yield from get_possible_types(member, object_definition=object_definition)
    elif isinstance(gql_type, StrawberryType):
        # Nothing to return here (also covers unresolved type vars)
        pass
    elif isinstance(gql_type, type):
        yield gql_type
    else:
        assert_never(gql_type)
def get_possible_type_definitions(
    gql_type: StrawberryObjectDefinition | StrawberryType | type,
) -> Generator[StrawberryObjectDefinition]:
    """Resolve all possible strawberry object definitions for gql_type.

    An object definition passed in directly is yielded unchanged. Otherwise
    every type produced by ``get_possible_types`` is inspected and its
    strawberry object definition is yielded when it has one.

    Args:
    ----
        gql_type:
            The type to retrieve possibilities from.

    Yields:
    ------
        All possibilities for the type

    """
    if isinstance(gql_type, StrawberryObjectDefinition):
        yield gql_type
    else:
        for candidate in get_possible_types(gql_type):
            if isinstance(candidate, StrawberryObjectDefinition):
                yield candidate
            elif has_object_definition(candidate):
                yield candidate.__strawberry_definition__
try:
    # PolymorphicModel itself cannot be imported at module load time because
    # it requires the Django app registry to be ready, so only probe whether
    # the ``polymorphic`` package is installed.
    import polymorphic  # noqa: F401
except ImportError:

    def is_polymorphic_model(v: type) -> TypeIs[type[PolymorphicModel]]:
        """Always False: django-polymorphic is not installed."""
        return False

else:

    def is_polymorphic_model(v: type) -> TypeIs[type[PolymorphicModel]]:
        """Return True if ``v`` carries the django-polymorphic model marker.

        Only an attribute ``polymorphic_model_marker`` that is exactly ``True``
        counts; any other value (or a missing attribute) yields False.
        """
        marker = getattr(v, "polymorphic_model_marker", False)
        return marker is True
try:
    from model_utils.managers import InheritanceManagerMixin, InheritanceQuerySetMixin
except ImportError:
    # django-model-utils is not installed: nothing can be an inheritance
    # manager or queryset.

    def is_inheritance_manager(
        v: Any,
    ) -> TypeIs[InheritanceManagerMixin]:
        """Always False: django-model-utils is not installed."""
        return False

    def is_inheritance_qs(
        v: Any,
    ) -> TypeIs[InheritanceQuerySetMixin]:
        """Always False: django-model-utils is not installed."""
        return False

else:

    def is_inheritance_manager(
        v: Any,
    ) -> TypeIs[InheritanceManagerMixin]:
        """Return True if ``v`` is a django-model-utils InheritanceManagerMixin."""
        return isinstance(v, InheritanceManagerMixin)

    def is_inheritance_qs(
        v: Any,
    ) -> TypeIs[InheritanceQuerySetMixin]:
        """Return True if ``v`` is a django-model-utils InheritanceQuerySetMixin."""
        return isinstance(v, InheritanceQuerySetMixin)
def _can_optimize_subtypes(model: type[models.Model]) -> bool:
    """Return whether querysets of ``model`` may return subclass instances.

    True for django-polymorphic models, or for models whose default manager is
    a django-model-utils InheritanceManager.
    """
    if is_polymorphic_model(model):
        return True
    return is_inheritance_manager(model._default_manager)
# Per-schema cache used by get_possible_concrete_types: for each interface's
# object definition, the list of object definitions in that schema's type map
# whose origin subclasses the interface's origin.
# Keys are held weakly, so this cache alone does not keep a Schema alive.
_interfaces: weakref.WeakKeyDictionary[
Schema,
dict[StrawberryObjectDefinition, list[StrawberryObjectDefinition]],
] = weakref.WeakKeyDictionary()
def _get_interface_implementations(
    schema: Schema,
    interface: StrawberryObjectDefinition,
) -> list[StrawberryObjectDefinition]:
    """Return (and cache per schema) the definitions whose origin subclasses ``interface``'s origin."""
    schema_cache = _interfaces.setdefault(schema, {})
    cached = schema_cache.get(interface)
    if cached is not None:
        return cached

    candidates = (t.definition for t in schema.schema_converter.type_map.values())
    implementations = [
        definition
        for definition in candidates
        if isinstance(definition, StrawberryObjectDefinition)
        and issubclass(definition.origin, interface.origin)
    ]
    schema_cache[interface] = implementations
    return implementations


def get_possible_concrete_types(
    model: type[models.Model],
    schema: Schema,
    strawberry_type: StrawberryObjectDefinition | StrawberryType,
) -> Iterable[StrawberryObjectDefinition]:
    """Return the object definitions the optimizer should look at when optimizing a model.

    Non-interface definitions reachable from ``strawberry_type`` are yielded
    as-is. For interfaces, every object definition in ``schema`` that
    subclasses the interface is considered, and it is yielded when its django
    model is ``model`` or one of its supertypes.

    If the model is one that supports polymorphism, by returning subtypes from
    its queryset, subtypes are also looked at. Currently, this is supported for
    django-polymorphic and django-model-utils InheritanceManager.
    """
    for definition in get_possible_type_definitions(strawberry_type):
        if not definition.is_interface:
            yield definition
            continue

        for candidate in _get_interface_implementations(schema, definition):
            dj_definition = get_django_definition(candidate.origin)
            if not dj_definition:
                continue

            is_supertype = issubclass(model, dj_definition.model)
            if is_supertype or (
                _can_optimize_subtypes(model)
                and issubclass(dj_definition.model, model)
            ):
                yield candidate
@dataclasses.dataclass(eq=True)
class PrefetchInspector:
    """Read and rewrite the query options of a ``Prefetch``.

    Exposes parts of the prefetch queryset's underlying query (``only`` /
    ``defer``, ``select_related``, ``prefetch_related``, ``annotations``,
    ``extra`` and ``where``) as read/write properties, so that two prefetches
    targeting the same attribute can be combined with :meth:`merge`.

    Setters write straight into ``self.qs`` / ``self.query``, i.e. into the
    prefetch's own queryset objects, so changes are visible through
    ``self.prefetch``. Equality only compares ``prefetch``.
    """

    prefetch: Prefetch
    # The prefetch's queryset and its underlying query; both are derived in
    # __post_init__ and excluded from __init__ and from equality comparisons.
    qs: QuerySet = dataclasses.field(init=False, compare=False)
    query: Query = dataclasses.field(init=False, compare=False)

    def __post_init__(self):
        # NOTE(review): assumes the Prefetch was built with an explicit
        # queryset; presumably a bare ``Prefetch("lookup")`` has
        # ``queryset=None`` and the ``.query`` access below would fail.
        self.qs = cast("QuerySet", self.prefetch.queryset)  # type: ignore
        self.query = self.qs.query

    # ``query.deferred_loading`` is handled as a ``(field_names, flag)`` pair:
    # flag False -> the names form an ``only`` set,
    # flag True  -> the names form a ``defer`` set.

    @property
    def only(self) -> frozenset[str] | None:
        """Fields restricted with ``only``, or None when in defer mode."""
        if self.query.deferred_loading[1]:
            return None
        return frozenset(self.query.deferred_loading[0])

    @only.setter
    def only(self, value: Iterable[str | None] | None):
        # ``None`` entries are dropped. An empty result is stored as an empty
        # defer set (flag True), i.e. "defer nothing".
        value = frozenset(v for v in (value or []) if v is not None)
        self.query.deferred_loading = (value, len(value) == 0)

    @property
    def defer(self) -> frozenset[str] | None:
        """Deferred fields, or None when in ``only`` mode."""
        if not self.query.deferred_loading[1]:
            return None
        return frozenset(self.query.deferred_loading[0])

    @defer.setter
    def defer(self, value: Iterable[str | None] | None):
        # ``None`` entries are dropped; always switches to defer mode.
        value = frozenset(v for v in (value or []) if v is not None)
        self.query.deferred_loading = (value, True)

    @property
    def select_related(self) -> DictTree | None:
        """The ``select_related`` tree, or None when it is not a dict."""
        if not isinstance(self.query.select_related, dict):
            return None
        return self.query.select_related

    @select_related.setter
    def select_related(self, value: DictTree | None):
        self.query.select_related = value or {}

    @property
    def prefetch_related(self) -> list[Prefetch | str]:
        """A copy of the queryset's prefetch_related lookups."""
        return list(self.qs._prefetch_related_lookups)  # type: ignore

    @prefetch_related.setter
    def prefetch_related(self, value: Iterable[Prefetch | str] | None):
        self.qs._prefetch_related_lookups = tuple(value or [])  # type: ignore

    @property
    def annotations(self) -> dict[str, Expression]:
        return self.query.annotations

    @annotations.setter
    def annotations(self, value: dict[str, Expression] | None):
        self.query.annotations = value or {}  # type: ignore

    @property
    def extra(self) -> DictTree:
        return self.query.extra

    @extra.setter
    def extra(self, value: DictTree | None):
        self.query.extra = value or {}  # type: ignore

    @property
    def where(self) -> WhereNode:
        return self.query.where

    @where.setter
    def where(self, value: WhereNode | None):
        self.query.where = value or WhereNode()

    def merge(self, other: PrefetchInspector, *, allow_unsafe_ops: bool = False):
        """Merge ``other``'s query options into this inspector's query in place.

        Args:
        ----
            other:
                The inspector whose options are merged into ``self``.
            allow_unsafe_ops:
                When True, skip the consistency checks listed under *Raises*
                and merge anyway. The flag is also forwarded to the recursive
                merge of nested Prefetches that share a lookup path.

        Returns:
        -------
            ``self``, to allow chaining.

        Raises:
        ------
            ValueError: if ``allow_unsafe_ops`` is False and the two queries
                have different filters, mix ``only`` with ``defer``, or have
                overlapping annotations/extras with different values (also
                raised from nested Prefetch merges).

        """
        if not allow_unsafe_ops and self.where != other.where:
            raise ValueError(
                "Tried to prefetch 2 queries with different filters to the "
                "same attribute. Use `to_attr` in this case...",
            )

        # Merge select_related
        self.select_related = dicttree_merge(
            self.select_related or {},
            other.select_related or {},
        )

        # Merge only/deferred
        if not allow_unsafe_ops and (self.defer is None) != (other.defer is None):
            raise ValueError(
                "Tried to prefetch 2 queries with different deferred "
                "operations. Use only `only` or `deferred`, not both...",
            )
        if self.only is not None and other.only is not None:
            self.only |= other.only
        elif self.defer is not None and other.defer is not None:
            self.defer |= other.defer
        else:
            # One has defer, the other only. In this case, defer nothing
            self.defer = frozenset()

        # Merge annotations
        s_annotations = self.annotations
        o_annotations = other.annotations
        if not allow_unsafe_ops:
            for k in set(s_annotations) & set(o_annotations):
                if s_annotations[k] != o_annotations[k]:
                    raise ValueError(
                        "Tried to prefetch 2 queries with overlapping annotations.",
                    )
        self.annotations = {**s_annotations, **o_annotations}

        # Merge extra
        s_extra = self.extra
        o_extra = other.extra
        if not allow_unsafe_ops and dicttree_insersection_differs(s_extra, o_extra):
            raise ValueError("Tried to prefetch 2 queries with overlapping extras.")
        self.extra = {**s_extra, **o_extra}

        # Merge prefetch_related: plain string lookups are de-duplicated;
        # a Prefetch replaces a string lookup for the same path, and two
        # Prefetches for the same path are merged recursively.
        prefetch_related: dict[str, str | Prefetch] = {}
        for p in itertools.chain(self.prefetch_related, other.prefetch_related):
            if isinstance(p, str):
                if p not in prefetch_related:
                    prefetch_related[p] = p
                continue

            path = p.prefetch_to
            existing = prefetch_related.get(path)
            if not existing or isinstance(existing, str):
                prefetch_related[path] = p
                continue

            inspector_cls = type(self)
            inspector = inspector_cls(existing).merge(
                inspector_cls(p),
                # Honour the caller's opt-in for nested merges as well.
                allow_unsafe_ops=allow_unsafe_ops,
            )
            prefetch_related[path] = inspector.prefetch

        self.prefetch_related = prefetch_related
        return self
|