1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
|
from __future__ import annotations
import dataclasses
import sys
from collections.abc import Callable, Iterable, Mapping, Sequence
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
ForwardRef,
TypeVar,
Union,
_AnnotatedAlias, # type: ignore
cast,
get_args,
overload,
)
from django.db.models.expressions import BaseExpression, Combinable
from graphql.type.definition import GraphQLResolveInfo
from strawberry.annotation import StrawberryAnnotation
from strawberry.types.auto import StrawberryAuto
from strawberry.types.base import (
StrawberryContainer,
StrawberryType,
WithStrawberryObjectDefinition,
)
from strawberry.types.lazy_type import LazyType, StrawberryLazyReference
from strawberry.utils.typing import is_classvar
from typing_extensions import Protocol, get_annotations
if TYPE_CHECKING:
from typing import Literal, TypeAlias, TypeGuard
from django.contrib.auth.base_user import AbstractBaseUser
from django.contrib.auth.models import AnonymousUser
from django.db.models import Prefetch
from strawberry_django.type import StrawberryDjangoDefinition
_T = TypeVar("_T")
# The bound is a string, so it is stored as a forward reference and never
# evaluated at import time.
_Type = TypeVar("_Type", bound="StrawberryType | type")

# NOTE: the right-hand sides below are evaluated at runtime, unlike the
# annotations in this module, which ``from __future__ import annotations``
# keeps lazy. PEP 604 ``X | Y`` on a TypeVar or on plain classes only works
# on Python 3.10+, and string forward references such as "AbstractBaseUser"
# cannot be combined with ``|`` at all. ``Union[...]`` is therefore used
# uniformly.

# A single value, or a sequence / mapping (keyed by str) / iterable of values.
TypeOrSequence: TypeAlias = Union[_T, Sequence[_T]]
TypeOrMapping: TypeAlias = Union[_T, Mapping[str, _T]]
TypeOrIterable: TypeAlias = Union[_T, Iterable[_T]]

# Either an authenticated Django user model instance or the anonymous user.
UserType: TypeAlias = Union["AbstractBaseUser", "AnonymousUser"]

# A callable that builds a ``Prefetch`` from the GraphQL resolve info.
PrefetchCallable: TypeAlias = Callable[[GraphQLResolveInfo], "Prefetch[Any]"]
# A prefetch given as a lookup string, a ``Prefetch`` object, or a callable
# producing one.
PrefetchType: TypeAlias = Union[str, "Prefetch[Any]", PrefetchCallable]

# A callable that builds an annotation expression from the GraphQL resolve info.
AnnotateCallable: TypeAlias = Callable[
    [GraphQLResolveInfo],
    Union[BaseExpression, Combinable],
]
# An annotation given as an expression directly, or as a callable producing one.
AnnotateType: TypeAlias = Union[BaseExpression, Combinable, AnnotateCallable]
class WithStrawberryDjangoObjectDefinition(
    WithStrawberryObjectDefinition,
    Protocol,
):
    """Structural type for strawberry types that also carry a Django definition.

    Extends strawberry's ``WithStrawberryObjectDefinition`` protocol with a
    class-level ``__strawberry_django_definition__`` attribute holding the
    ``StrawberryDjangoDefinition``.
    """

    __strawberry_django_definition__: ClassVar[StrawberryDjangoDefinition]
def has_django_definition(
    obj: Any,
) -> TypeGuard[type[WithStrawberryDjangoObjectDefinition]]:
    """Report whether ``obj`` exposes ``__strawberry_django_definition__``.

    Only an ``AttributeError`` counts as "absent"; this mirrors ``hasattr``.
    A ``True`` result lets type checkers narrow ``obj`` to a class that
    implements ``WithStrawberryDjangoObjectDefinition``.
    """
    try:
        getattr(obj, "__strawberry_django_definition__")
    except AttributeError:
        return False
    return True
@overload
def get_django_definition(
    obj: Any,
    *,
    strict: Literal[True],
) -> StrawberryDjangoDefinition: ...


@overload
def get_django_definition(
    obj: Any,
    *,
    strict: bool = False,
) -> StrawberryDjangoDefinition | None: ...


def get_django_definition(
    obj: Any,
    *,
    strict: bool = False,
) -> StrawberryDjangoDefinition | None:
    """Fetch the ``__strawberry_django_definition__`` attached to ``obj``.

    Args:
        obj: The object (typically a class) to inspect.
        strict: When true, the attribute must exist.

    Returns:
        The stored definition. In non-strict mode, ``None`` is returned
        when the attribute is missing.

    Raises:
        AttributeError: In strict mode, when ``obj`` has no definition.
    """
    if strict:
        # No fallback: strict callers want the AttributeError to surface.
        return obj.__strawberry_django_definition__
    return getattr(obj, "__strawberry_django_definition__", None)
def is_auto(obj: Any) -> bool:
    """Tell whether ``obj`` denotes strawberry's ``auto`` marker.

    Accepted spellings:
      * a ``StrawberryAuto`` instance;
      * the string ``"auto"`` or ``"strawberry.auto"``;
      * a ``ForwardRef`` whose argument is one of those two strings.
    """
    # Unwrap a ForwardRef to its raw string argument before checking.
    candidate = obj.__forward_arg__ if isinstance(obj, ForwardRef) else obj

    if not isinstance(candidate, str):
        return isinstance(candidate, StrawberryAuto)
    return candidate in ("auto", "strawberry.auto")
def get_strawberry_annotations(cls) -> dict[str, StrawberryAnnotation]:
    """Collect the non-``ClassVar`` annotations of ``cls`` and its dataclass bases.

    The MRO is walked from the root toward ``cls``. A name annotated on a
    later (more derived) class therefore overwrites the same name from an
    earlier base. Bases that are not dataclasses are ignored, except ``cls``
    itself, which is always included.

    Each annotation is wrapped in a ``StrawberryAnnotation`` bound to the
    globals of the module defining the class that declared it. Presumably
    this lets string annotations resolve against that module (confirm in
    ``StrawberryAnnotation``).

    Raises:
        KeyError: If a visited class's ``__module__`` is not in ``sys.modules``.
    """
    collected: dict[str, StrawberryAnnotation] = {}

    relevant_bases = (
        base
        for base in reversed(cls.__mro__)
        if base is cls or dataclasses.is_dataclass(base)
    )
    for base in relevant_bases:
        module_globals = vars(sys.modules[base.__module__])
        for name, annotation in get_annotations(base).items():
            if is_classvar(cast("type", base), annotation):
                continue
            collected[name] = StrawberryAnnotation(
                annotation,
                namespace=module_globals,
            )

    return collected
@overload
def unwrap_type(type_: StrawberryContainer) -> type: ...


@overload
def unwrap_type(type_: LazyType) -> type: ...


@overload
def unwrap_type(type_: None) -> None: ...


@overload
def unwrap_type(type_: _Type) -> _Type: ...


def unwrap_type(type_):
    """Peel ``LazyType`` and ``StrawberryContainer`` layers off ``type_``.

    A ``LazyType`` is replaced by its ``resolve_type()`` result. A container
    is replaced by its ``of_type``. This repeats until the value is neither,
    and that value is returned. Anything else, ``None`` included, comes back
    unchanged.
    """
    current = type_
    while isinstance(current, (LazyType, StrawberryContainer)):
        # LazyType is checked first, matching the original precedence.
        if isinstance(current, LazyType):
            current = current.resolve_type()
        else:
            current = current.of_type
    return current
def get_type_from_lazy_annotation(type_: _AnnotatedAlias) -> type | None:
    """Resolve the type referenced by a lazy ``Annotated[...]`` annotation.

    The first argument of the ``Annotated`` alias is the annotated target.
    The remaining arguments are metadata. The first metadata item that is a
    ``StrawberryLazyReference`` is resolved against the target and unwrapped
    with ``unwrap_type``.

    Returns:
        The unwrapped resolved type, or ``None`` if the metadata holds no
        ``StrawberryLazyReference``.
    """
    target, *metadata = get_args(type_)
    lazy_ref = next(
        (item for item in metadata if isinstance(item, StrawberryLazyReference)),
        None,
    )
    if lazy_ref is None:
        return None
    return unwrap_type(lazy_ref.resolve_forward_ref(target))
|