File: inspect.py

package info (click to toggle)
strawberry-graphql-django 0.62.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,968 kB
  • sloc: python: 27,530; sh: 17; makefile: 16
file content (387 lines) | stat: -rw-r--r-- 12,829 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
from __future__ import annotations

import dataclasses
import functools
import itertools
import weakref
from collections.abc import Iterable
from typing import (
    TYPE_CHECKING,
    Any,
    cast,
)

from django.db.models.query import Prefetch, QuerySet
from django.db.models.sql.where import WhereNode
from strawberry import Schema
from strawberry.types import has_object_definition
from strawberry.types.base import (
    StrawberryContainer,
    StrawberryObjectDefinition,
    StrawberryType,
    StrawberryTypeVar,
)
from strawberry.types.lazy_type import LazyType
from strawberry.types.union import StrawberryUnion
from strawberry.utils.str_converters import to_camel_case
from typing_extensions import TypeIs, assert_never

from strawberry_django.fields.types import resolve_model_field_name

from .pyutils import DictTree, dicttree_insersection_differs, dicttree_merge
from .typing import get_django_definition

if TYPE_CHECKING:
    from collections.abc import Generator, Iterable

    from django.db import models
    from django.db.models.expressions import Expression
    from django.db.models.fields import Field
    from django.db.models.fields.reverse_related import ForeignObjectRel
    from django.db.models.sql.query import Query
    from model_utils.managers import (
        InheritanceManagerMixin,
        InheritanceQuerySetMixin,
    )
    from polymorphic.models import PolymorphicModel


@functools.lru_cache
def get_model_fields(
    model: type[models.Model],
    *,
    camel_case: bool = False,
    is_input: bool = False,
    is_filter: bool = False,
) -> dict[str, Field | ForeignObjectRel]:
    """Return a mapping of resolved field name -> Django field for ``model``.

    Every entry of ``model._meta.get_fields()`` is keyed by the name produced
    by ``resolve_model_field_name`` (honouring ``is_input``/``is_filter``),
    optionally converted to camelCase. If two fields resolve to the same
    name, the one appearing later in ``get_fields()`` wins.

    Results are memoized per argument combination via ``functools.lru_cache``.
    """

    def _public_name(model_field: Field | ForeignObjectRel) -> str:
        resolved = cast(
            "str",
            resolve_model_field_name(
                model_field,
                is_input=is_input,
                is_filter=is_filter,
            ),
        )
        return to_camel_case(resolved) if camel_case else resolved

    return {
        _public_name(model_field): model_field
        for model_field in model._meta.get_fields()
    }


def get_model_field(
    model: type[models.Model],
    field_name: str,
    *,
    camel_case: bool = False,
    is_input: bool = False,
    is_filter: bool = False,
) -> Field | ForeignObjectRel | None:
    """Look up one field of ``model`` by its resolved name.

    ``field_name`` is matched against the keys produced by
    ``get_model_fields`` with the same ``camel_case``/``is_input``/
    ``is_filter`` options. Returns ``None`` when no field matches.
    """
    fields_by_name = get_model_fields(
        model,
        camel_case=camel_case,
        is_input=is_input,
        is_filter=is_filter,
    )
    return fields_by_name.get(field_name, None)


def get_possible_types(
    gql_type: StrawberryObjectDefinition | StrawberryType | type,
    *,
    object_definition: StrawberryObjectDefinition | None = None,
) -> Generator[type]:
    """Resolve all possible Python types for gql_type.

    Args:
    ----
        gql_type:
            The type to retrieve possibilities from.
        object_definition:
            Optional object definition whose ``type_var_map`` is used to
            resolve a ``StrawberryTypeVar`` passed as ``gql_type``. If the
            type var's name is not in that map, nothing is yielded for it.
            This is used internally.

    Yields:
    ------
        Every plain ``type`` reachable from gql_type by unwrapping object
        definitions (their ``origin``), lazy types, containers and unions.
        Other ``StrawberryType`` instances yield nothing.

    Raises:
    ------
        AssertionError: via ``assert_never`` when gql_type is none of the
        handled kinds.

    """
    # NOTE(review): ``object_definition`` is only consulted in the type-var
    # branch; it is not forwarded when recursing into lazy types, containers
    # or unions.
    if isinstance(gql_type, StrawberryObjectDefinition):
        yield from get_possible_types(gql_type.origin, object_definition=gql_type)
    elif isinstance(gql_type, LazyType):
        yield from get_possible_types(gql_type.resolve_type())
    elif isinstance(gql_type, StrawberryTypeVar) and object_definition is not None:
        resolved = object_definition.type_var_map.get(gql_type.type_var.__name__, None)
        if resolved is not None:
            yield from get_possible_types(resolved)
    elif isinstance(gql_type, StrawberryContainer):
        yield from get_possible_types(gql_type.of_type)
    elif isinstance(gql_type, StrawberryUnion):
        yield from itertools.chain.from_iterable(
            (get_possible_types(t) for t in gql_type.types),
        )
    elif isinstance(gql_type, StrawberryType):
        # Nothing to return here
        pass
    elif isinstance(gql_type, type):
        yield gql_type
    else:
        assert_never(gql_type)


def get_possible_type_definitions(
    gql_type: StrawberryObjectDefinition | StrawberryType | type,
) -> Generator[StrawberryObjectDefinition]:
    """Yield every Strawberry object definition that gql_type may resolve to.

    Args:
    ----
        gql_type:
            The type to inspect. An object definition is yielded as-is;
            anything else is expanded through ``get_possible_types``.

    Yields:
    ------
        The object definitions found. Candidate types without a Strawberry
        object definition are skipped.

    """
    if isinstance(gql_type, StrawberryObjectDefinition):
        yield gql_type
    else:
        for candidate in get_possible_types(gql_type):
            if isinstance(candidate, StrawberryObjectDefinition):
                definition = candidate
            elif has_object_definition(candidate):
                definition = candidate.__strawberry_definition__
            else:
                continue
            yield definition


try:
    # Importing PolymorphicModel itself would need the Django app registry to
    # be ready, so only probe that the ``polymorphic`` package is installed.
    import polymorphic  # noqa: F401
except ImportError:

    def is_polymorphic_model(v: type) -> TypeIs[type[PolymorphicModel]]:
        """Return False: django-polymorphic is not installed."""
        return False

else:

    def is_polymorphic_model(v: type) -> TypeIs[type[PolymorphicModel]]:
        """Return True when ``v`` carries ``polymorphic_model_marker = True``."""
        marker = getattr(v, "polymorphic_model_marker", False)
        return marker is True


try:
    from model_utils.managers import InheritanceManagerMixin, InheritanceQuerySetMixin
except ImportError:
    # django-model-utils is unavailable, so nothing can be one of its
    # inheritance managers/querysets.

    def is_inheritance_manager(
        v: Any,
    ) -> TypeIs[InheritanceManagerMixin]:
        """Return False: django-model-utils is not installed."""
        return False

    def is_inheritance_qs(
        v: Any,
    ) -> TypeIs[InheritanceQuerySetMixin]:
        """Return False: django-model-utils is not installed."""
        return False

else:

    def is_inheritance_manager(
        v: Any,
    ) -> TypeIs[InheritanceManagerMixin]:
        """Return whether ``v`` is an ``InheritanceManagerMixin`` instance."""
        manager_cls = InheritanceManagerMixin
        return isinstance(v, manager_cls)

    def is_inheritance_qs(
        v: Any,
    ) -> TypeIs[InheritanceQuerySetMixin]:
        """Return whether ``v`` is an ``InheritanceQuerySetMixin`` instance."""
        queryset_cls = InheritanceQuerySetMixin
        return isinstance(v, queryset_cls)


def _can_optimize_subtypes(model: type[models.Model]) -> bool:
    """Tell whether ``model``'s querysets can return subtype instances.

    True for django-polymorphic models, and for models whose default manager
    is a django-model-utils inheritance manager.
    """
    if is_polymorphic_model(model):
        return True
    return is_inheritance_manager(model._default_manager)


# Per-schema cache: interface definition -> object definitions (from the
# schema's type map) whose origin subclasses the interface's origin. Keys are
# held weakly, so a discarded Schema does not stay alive because of this cache.
_interfaces: weakref.WeakKeyDictionary[
    Schema,
    dict[StrawberryObjectDefinition, list[StrawberryObjectDefinition]],
] = weakref.WeakKeyDictionary()


def get_possible_concrete_types(
    model: type[models.Model],
    schema: Schema,
    strawberry_type: StrawberryObjectDefinition | StrawberryType,
) -> Iterable[StrawberryObjectDefinition]:
    """Yield the object definitions the optimizer should inspect for ``model``.

    Non-interface definitions reachable from ``strawberry_type`` are yielded
    unconditionally. An interface is expanded into every object definition in
    ``schema`` whose origin subclasses it (cached in ``_interfaces``). Each
    expansion is kept when its Django model is ``model`` or a supertype of it.

    If the model is one that supports polymorphism, by returning subtypes from
    its queryset, subtypes are kept too. Currently, this is supported for
    django-polymorphic and django-model-utils InheritanceManager.
    """
    for definition in get_possible_type_definitions(strawberry_type):
        if not definition.is_interface:
            yield definition
            continue

        per_schema = _interfaces.setdefault(schema, {})
        implementations = per_schema.get(definition)
        if implementations is None:
            implementations = [
                type_definition
                for type_definition in (
                    entry.definition
                    for entry in schema.schema_converter.type_map.values()
                )
                if isinstance(type_definition, StrawberryObjectDefinition)
                and issubclass(type_definition.origin, definition.origin)
            ]
            per_schema[definition] = implementations

        for candidate in implementations:
            dj_definition = get_django_definition(candidate.origin)
            if not dj_definition:
                continue
            if issubclass(model, dj_definition.model):
                yield candidate
            elif _can_optimize_subtypes(model) and issubclass(
                dj_definition.model, model
            ):
                yield candidate


@dataclasses.dataclass(eq=True)
class PrefetchInspector:
    """Inspect and mutate the query state of a ``Prefetch``'s queryset.

    The properties read from and write to ``prefetch.queryset.query`` (and the
    queryset's ``_prefetch_related_lookups``) directly. Setting them therefore
    mutates the wrapped ``Prefetch`` in place. ``merge`` uses them to combine
    two prefetches that target the same attribute.
    """

    prefetch: Prefetch
    qs: QuerySet = dataclasses.field(init=False, compare=False)
    query: Query = dataclasses.field(init=False, compare=False)

    def __post_init__(self):
        # NOTE(review): assumes the Prefetch carries an explicit queryset; if
        # ``prefetch.queryset`` is None, ``.query`` below raises AttributeError.
        self.qs = cast("QuerySet", self.prefetch.queryset)  # type: ignore
        self.query = self.qs.query

    @property
    def only(self) -> frozenset[str] | None:
        """Field names restricted via only-mode, or None when in defer mode."""
        if self.query.deferred_loading[1]:
            return None
        return frozenset(self.query.deferred_loading[0])

    @only.setter
    def only(self, value: Iterable[str | None] | None):
        # None entries are dropped. An empty result switches back to defer
        # mode with an empty set, i.e. nothing deferred.
        value = frozenset(v for v in (value or []) if v is not None)
        self.query.deferred_loading = (value, len(value) == 0)

    @property
    def defer(self) -> frozenset[str] | None:
        """Deferred field names, or None when the query is in only-mode."""
        if not self.query.deferred_loading[1]:
            return None
        return frozenset(self.query.deferred_loading[0])

    @defer.setter
    def defer(self, value: Iterable[str | None] | None):
        value = frozenset(v for v in (value or []) if v is not None)
        self.query.deferred_loading = (value, True)

    @property
    def select_related(self) -> DictTree | None:
        """The select_related tree, or None when it is not stored as a dict."""
        if not isinstance(self.query.select_related, dict):
            return None
        return self.query.select_related

    @select_related.setter
    def select_related(self, value: DictTree | None):
        self.query.select_related = value or {}

    @property
    def prefetch_related(self) -> list[Prefetch | str]:
        """A copy of the queryset's prefetch_related lookups."""
        return list(self.qs._prefetch_related_lookups)  # type: ignore

    @prefetch_related.setter
    def prefetch_related(self, value: Iterable[Prefetch | str] | None):
        # Iterates ``value``: pass a sequence of lookups, not a mapping
        # (iterating a dict would keep only its keys).
        self.qs._prefetch_related_lookups = tuple(value or [])  # type: ignore

    @property
    def annotations(self) -> dict[str, Expression]:
        return self.query.annotations

    @annotations.setter
    def annotations(self, value: dict[str, Expression] | None):
        self.query.annotations = value or {}  # type: ignore

    @property
    def extra(self) -> DictTree:
        return self.query.extra

    @extra.setter
    def extra(self, value: DictTree | None):
        self.query.extra = value or {}  # type: ignore

    @property
    def where(self) -> WhereNode:
        return self.query.where

    @where.setter
    def where(self, value: WhereNode | None):
        self.query.where = value or WhereNode()

    def merge(
        self,
        other: PrefetchInspector,
        *,
        allow_unsafe_ops: bool = False,
    ) -> PrefetchInspector:
        """Merge ``other`` into this inspector in place and return ``self``.

        The merged query is meant to satisfy both sides. It unions
        select_related trees, only-fields, annotations, extras and prefetch
        lookups, and keeps a field deferred only when both sides defer it.
        Prefetches targeting the same path are merged recursively.

        Args:
            other: The inspector whose query state is folded into this one.
            allow_unsafe_ops: Skip the consistency checks below. It is also
                forwarded to nested prefetch merges.

        Raises:
            ValueError: Unless ``allow_unsafe_ops`` is set, when the filters
                differ, when one side uses only-mode and the other defer-mode,
                or when annotations or extras overlap with different values.
        """
        if not allow_unsafe_ops and self.where != other.where:
            raise ValueError(
                "Tried to prefetch 2 queries with different filters to the "
                "same attribute. Use `to_attr` in this case...",
            )

        # Merge select_related
        self.select_related = dicttree_merge(
            self.select_related or {},
            other.select_related or {},
        )

        # Merge only/deferred
        if not allow_unsafe_ops and (self.defer is None) != (other.defer is None):
            raise ValueError(
                "Tried to prefetch 2 queries with different deferred "
                "operations. Use only `only` or `deferred`, not both...",
            )
        if self.only is not None and other.only is not None:
            # Load every field either side restricted itself to.
            self.only |= other.only
        elif self.defer is not None and other.defer is not None:
            # A field may stay deferred only if *both* sides defer it.
            # Otherwise a field one side needs would be deferred for it.
            self.defer &= other.defer
        else:
            # One has defer, the other only. In this case, defer nothing
            self.defer = frozenset()

        # Merge annotations
        s_annotations = self.annotations
        o_annotations = other.annotations
        if not allow_unsafe_ops:
            for k in set(s_annotations) & set(o_annotations):
                if s_annotations[k] != o_annotations[k]:
                    raise ValueError(
                        "Tried to prefetch 2 queries with overlapping annotations.",
                    )
        self.annotations = {**s_annotations, **o_annotations}

        # Merge extra
        s_extra = self.extra
        o_extra = other.extra
        if not allow_unsafe_ops and dicttree_insersection_differs(s_extra, o_extra):
            raise ValueError("Tried to prefetch 2 queries with overlapping extras.")
        self.extra = {**s_extra, **o_extra}

        # Merge prefetch_related, keyed by the attribute path each lookup fills.
        merged_lookups: dict[str, str | Prefetch] = {}
        for lookup in itertools.chain(self.prefetch_related, other.prefetch_related):
            if isinstance(lookup, str):
                merged_lookups.setdefault(lookup, lookup)
                continue

            path = lookup.prefetch_to
            existing = merged_lookups.get(path)
            if existing is None or isinstance(existing, str):
                # A Prefetch object supersedes a plain string lookup.
                merged_lookups[path] = lookup
                continue

            merged = self.__class__(existing).merge(
                self.__class__(lookup),
                allow_unsafe_ops=allow_unsafe_ops,
            )
            merged_lookups[path] = merged.prefetch

        # Store the lookups themselves. Assigning the dict would go through
        # ``tuple(dict)`` in the setter and keep only the path strings,
        # dropping every Prefetch's custom queryset / to_attr.
        self.prefetch_related = list(merged_lookups.values())

        return self