1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235
|
# type: ignore
import dataclasses
import re
from enum import Enum
from typing import Annotated, Optional, TypeVar
import pytest
import strawberry
from strawberry.types.base import get_object_definition, has_object_definition
from strawberry.types.field import StrawberryField
def test_enum():
    """An enum-typed field resolves to the enum's Strawberry definition object."""

    @strawberry.enum
    class Count(Enum):
        TWO = "two"
        FOUR = "four"

    @strawberry.type
    class Animal:
        legs: Count

    legs_field: StrawberryField = get_object_definition(Animal).fields[0]

    # TODO: stop reaching into `.__strawberry_definition__` here and compare
    # against a StrawberryEnumDefinition instead.
    assert legs_field.type is Count.__strawberry_definition__
def test_forward_reference():
    """A string forward reference resolves to a class defined after the referrer.

    ``FromTheFuture`` is declared ``global`` so that the string annotation
    ``"FromTheFuture"`` on ``TimeTraveler`` can be resolved from module scope.
    The global is removed in a ``finally`` block so that a failing assertion
    does not leak it into the module namespace for other tests.
    """
    global FromTheFuture

    @strawberry.type
    class TimeTraveler:
        origin: "FromTheFuture"

    # Defined *after* TimeTraveler, so resolution of the annotation must be
    # deferred past decoration for the assertion below to hold.
    @strawberry.type
    class FromTheFuture:
        year: int

    try:
        origin_field: StrawberryField = get_object_definition(TimeTraveler).fields[0]
        assert origin_field.type is FromTheFuture
    finally:
        # Always clean up the module-level name, even when the assert fails.
        del FromTheFuture
def test_list():
    """A ``list[str]`` annotation is exposed unchanged as the field type."""

    @strawberry.type
    class Santa:
        making_a: list[str]

    definition = get_object_definition(Santa)
    making_a_field: StrawberryField = definition.fields[0]
    assert making_a_field.type == list[str]
def test_literal():
    """A plain ``str`` annotation is exposed as the ``str`` type itself.

    Despite the test's name, this checks a built-in scalar annotation, not
    ``typing.Literal``.
    """

    @strawberry.type
    class Fabric:
        thread_type: str

    definition = get_object_definition(Fabric)
    thread_field: StrawberryField = definition.fields[0]
    assert thread_field.type is str
def test_object():
    """A field annotated with another Strawberry type keeps that exact class."""

    @strawberry.type
    class Object:
        proper_noun: bool

    @strawberry.type
    class TransitiveVerb:
        subject: Object

    definition = get_object_definition(TransitiveVerb)
    subject_field: StrawberryField = definition.fields[0]
    assert subject_field.type is Object
def test_optional():
    """A ``bool | None`` annotation compares equal to ``Optional[bool]``."""

    @strawberry.type
    class HasChoices:
        decision: bool | None

    definition = get_object_definition(HasChoices)
    decision_field: StrawberryField = definition.fields[0]
    assert decision_field.type == Optional[bool]
def test_type_var():
    """A field annotated with a bare ``TypeVar`` keeps the ``TypeVar`` as its type."""
    T = TypeVar("T")

    @strawberry.type
    class Gossip:
        spill_the: T

    definition = get_object_definition(Gossip)
    spill_field: StrawberryField = definition.fields[0]
    assert spill_field.type == T
def test_union():
    """A field annotated with an ``Annotated`` named union keeps that annotation."""

    @strawberry.type
    class Europe:
        name: str

    @strawberry.type
    class UK:
        name: str

    # Named union built with ``Annotated[... , strawberry.union(name)]``.
    EU = Annotated[Europe | UK, strawberry.union("EU")]

    @strawberry.type
    class WishfulThinking:
        desire: EU

    definition = get_object_definition(WishfulThinking)
    desire_field: StrawberryField = definition.fields[0]
    assert desire_field.type == EU
def test_fields_with_defaults():
    """A defaulted field may precede a required one, and its default can be overridden."""

    @strawberry.type
    class Country:
        name: str = "United Kingdom"
        currency_code: str

    # (constructor kwargs, (expected name, expected currency code))
    cases = [
        ({"currency_code": "GBP"}, ("United Kingdom", "GBP")),
        (
            {"name": "United States of America", "currency_code": "USD"},
            ("United States of America", "USD"),
        ),
    ]
    for kwargs, (expected_name, expected_code) in cases:
        country = Country(**kwargs)
        assert country.name == expected_name
        assert country.currency_code == expected_code
def test_fields_with_defaults_inheritance():
    """Defaults declared on an interface apply to types that inherit from it.

    ``C`` adds a required field after the inherited defaulted ``delay`` field,
    and nested ``B`` instances fill in both ``delay`` and ``attachments``
    from their defaults.
    """

    @strawberry.interface
    class A:
        text: str
        delay: int | None = None

    @strawberry.type
    class B(A):
        attachments: list[A] | None = None

    @strawberry.type
    class C(A):
        fields: list[B]

    inner = B(text="more text")
    outer = C(text="some text", fields=[inner])

    expected_inner = {"text": "more text", "attachments": None, "delay": None}
    expected_outer = {"text": "some text", "delay": None, "fields": [expected_inner]}
    assert dataclasses.asdict(outer) == expected_outer
def test_positional_args_not_allowed():
    """The generated ``__init__`` accepts no positional arguments besides ``self``."""

    @strawberry.type
    class Thing:
        name: str

    expected_message = re.escape(
        "__init__() takes 1 positional argument but 2 were given"
    )
    with pytest.raises(TypeError, match=expected_message):
        Thing("something")
def test_object_preserves_annotations():
    """Decorating a class leaves its ``__annotations__`` untouched.

    This holds even when ``strawberry.field(graphql_type=...)`` overrides the
    GraphQL type, and when the annotation is wrapped in ``Annotated``.
    """

    @strawberry.type
    class Object:
        a: bool
        b: Annotated[str, "something"]
        c: bool = strawberry.field(graphql_type=int)
        d: Annotated[str, "something"] = strawberry.field(graphql_type=int)

    expected_annotations = {
        "a": bool,
        "b": Annotated[str, "something"],
        "c": bool,
        "d": Annotated[str, "something"],
    }
    assert Object.__annotations__ == expected_annotations
def test_has_object_definition_returns_true_for_object_type():
    """``has_object_definition`` recognises a ``@strawberry.type`` class."""

    @strawberry.type
    class Palette:
        name: str

    result = has_object_definition(Palette)
    assert result is True
def test_has_object_definition_returns_false_for_enum():
    """``has_object_definition`` is ``False`` for a ``@strawberry.enum`` class."""

    @strawberry.enum
    class Color(Enum):
        RED = "red"
        GREEN = "green"

    result = has_object_definition(Color)
    assert result is False
def test_has_object_definition_returns_true_for_interface():
    """``has_object_definition`` recognises a ``@strawberry.interface`` class."""

    @strawberry.interface
    class Node:
        id: str

    result = has_object_definition(Node)
    assert result is True
def test_has_object_definition_returns_true_for_input():
    """``has_object_definition`` recognises a ``@strawberry.input`` class."""

    @strawberry.input
    class CreateUserInput:
        name: str

    result = has_object_definition(CreateUserInput)
    assert result is True
def test_has_object_definition_returns_false_for_scalar():
    """``has_object_definition`` is ``False`` for the built-in ``JSON`` scalar."""
    from strawberry.scalars import JSON

    result = has_object_definition(JSON)
    assert result is False
|