1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
|
import warnings
from enum import Enum
from typing import Optional, TypeVar
import pytest
import strawberry
from strawberry.types.info import Info
def test_enum():
    """An argument annotated with a strawberry enum maps to the enum's definition."""

    @strawberry.enum
    class Locale(Enum):
        UNITED_STATES = "en_US"
        UK = "en_UK"
        AUSTRALIA = "en_AU"

    @strawberry.mutation
    def set_locale(locale: Locale) -> bool:
        del locale  # only the annotation is under test
        return True

    locale_argument = set_locale.arguments[0]

    # TODO: Replace the reference to .__strawberry_definition__ with
    # StrawberryEnumDefinition.
    assert locale_argument.type is Locale.__strawberry_definition__
def test_forward_reference():
    """A string annotation naming a later-defined input type resolves to it.

    The resolver is declared with the annotation ``"SearchInput"`` before the
    class exists. ``SearchInput`` is bound as a module-level global so the name
    can be found in the module namespace when the argument type is read.

    The global is removed in ``finally`` so it is cleaned up even when the
    assertion fails. In the original version a failure skipped the ``del`` and
    leaked ``SearchInput`` into later tests.
    """
    global SearchInput

    @strawberry.field
    def search(search_input: "SearchInput") -> bool:
        _ = search_input
        return True

    @strawberry.input
    class SearchInput:
        query: str

    # The try starts only after the class is bound, so ``del`` cannot hit an
    # unbound name and mask the real error.
    try:
        argument = search.arguments[0]
        assert argument.type is SearchInput
    finally:
        del SearchInput
def test_list():
    """A ``list[str]`` annotation is reported as ``list[str]`` on the argument."""

    @strawberry.field
    def get_longest_word(words: list[str]) -> str:
        del words  # the resolver body is irrelevant; only its signature matters
        return "I cheated"

    words_argument = get_longest_word.arguments[0]
    assert words_argument.type == list[str]
def test_literal():
    """A plain ``int`` annotation is reported as the ``int`` type itself."""

    @strawberry.field
    def get_name(id_: int) -> str:
        del id_  # only the annotation is under test
        return "Lord Buckethead"

    id_argument = get_name.arguments[0]
    assert id_argument.type is int
def test_object():
    """An argument annotated with a strawberry type resolves to that exact class."""

    @strawberry.type
    class PersonInput:
        proper_noun: bool

    @strawberry.field
    def get_id(person_input: PersonInput) -> int:
        del person_input  # only the annotation is under test
        return 0

    person_argument = get_id.arguments[0]
    assert person_argument.type is PersonInput
def test_optional():
    """An ``int | None`` annotation compares equal to ``Optional[int]``."""

    @strawberry.field
    def set_age(age: int | None) -> bool:
        del age  # only the annotation is under test
        return True

    age_argument = set_age.arguments[0]
    assert age_argument.type == Optional[int]
def test_type_var():
    """A bare ``TypeVar`` annotation is preserved as that same ``TypeVar``."""
    T = TypeVar("T")

    @strawberry.field
    def set_value(value: T) -> bool:
        del value  # only the annotation is under test
        return True

    value_argument = set_value.arguments[0]
    assert value_argument.type == T
# Type parameters that keep the Info subclass below generic over its
# context and root-value types.
ContextType = TypeVar("ContextType")
RootValueType = TypeVar("RootValueType")
class CustomInfo(Info[ContextType, RootValueType]):
    """User-defined ``Info`` subclass, still generic over context and root value.

    Used to test dependency injection: a resolver parameter annotated with this
    subclass (bare or subscripted) should be treated as the info parameter.
    """
@pytest.mark.parametrize(
    "annotation",
    [CustomInfo, CustomInfo[None, None], Info, Info[None, None]],
)
def test_custom_info(annotation):
    """A resolver whose ``info`` parameter is annotated with ``Info`` or a
    subclass of it (bare or subscripted) is handled without any warning."""
    with warnings.catch_warnings():
        # Turn every warning into an exception so any warning fails the test.
        warnings.simplefilter("error")

        def get_info(info) -> bool:
            del info  # only the annotation is under test
            return True

        # Attach the parametrized annotation object to the ``info`` parameter.
        get_info.__annotations__["info"] = annotation
        resolver_field = strawberry.field(get_info)

        # ``info`` must not surface as a regular argument of the field...
        assert not resolver_field.arguments

        # ...but must be recognised as the resolver's info parameter.
        param = resolver_field.base_resolver.info_parameter
        assert param is not None
        assert param.name == "info"
|