File: test_object_types.py

package info (click to toggle)
strawberry-graphql 0.306.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 18,176 kB
  • sloc: javascript: 178,052; python: 65,643; sh: 33; makefile: 25
file content (235 lines) | stat: -rw-r--r-- 5,066 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
# type: ignore
import dataclasses
import re
from enum import Enum
from typing import Annotated, Optional, TypeVar

import pytest

import strawberry
from strawberry.types.base import get_object_definition, has_object_definition
from strawberry.types.field import StrawberryField


def test_enum():
    """Check that an enum-annotated field's type is the enum's definition."""

    @strawberry.enum
    class Count(Enum):
        TWO = "two"
        FOUR = "four"

    @strawberry.type
    class Animal:
        legs: Count

    animal_definition = get_object_definition(Animal)
    legs_field: StrawberryField = animal_definition.fields[0]

    # TODO: Remove reference to .__strawberry_definition__ with StrawberryEnumDefinition
    assert legs_field.type is Count.__strawberry_definition__


def test_forward_reference():
    """Check that a string forward reference resolves to the later class.

    ``TimeTraveler.origin`` is annotated with the string ``"FromTheFuture"``
    before that class exists. The class is declared ``global`` -- presumably
    so the string annotation can be looked up in the module namespace when
    the field type is resolved (NOTE(review): confirm against strawberry's
    annotation resolution).
    """
    global FromTheFuture

    @strawberry.type
    class TimeTraveler:
        origin: "FromTheFuture"

    @strawberry.type
    class FromTheFuture:
        year: int

    try:
        field: StrawberryField = get_object_definition(TimeTraveler).fields[0]

        assert field.type is FromTheFuture
    finally:
        # Always drop the module-level name, even when the lookup or the
        # assertion fails, so a failing run does not leak the global into
        # other tests in this module.
        del FromTheFuture


def test_list():
    """Check that a ``list[str]`` field reports a type equal to ``list[str]``."""

    @strawberry.type
    class Santa:
        making_a: list[str]

    santa_definition = get_object_definition(Santa)
    making_a_field: StrawberryField = santa_definition.fields[0]

    assert making_a_field.type == list[str]


def test_literal():
    """Check that a ``str``-annotated field reports ``str`` itself as its type."""

    @strawberry.type
    class Fabric:
        thread_type: str

    fabric_definition = get_object_definition(Fabric)
    thread_type_field: StrawberryField = fabric_definition.fields[0]

    assert thread_type_field.type is str


def test_object():
    """Check that a field annotated with another strawberry type keeps that class."""

    @strawberry.type
    class Object:
        proper_noun: bool

    @strawberry.type
    class TransitiveVerb:
        subject: Object

    verb_definition = get_object_definition(TransitiveVerb)
    subject_field: StrawberryField = verb_definition.fields[0]

    assert subject_field.type is Object


def test_optional():
    """Check that a ``bool | None`` field type compares equal to ``Optional[bool]``."""

    @strawberry.type
    class HasChoices:
        decision: bool | None

    choices_definition = get_object_definition(HasChoices)
    decision_field: StrawberryField = choices_definition.fields[0]

    assert decision_field.type == Optional[bool]


def test_type_var():
    """Check that a field annotated with a ``TypeVar`` reports an equal type."""
    T = TypeVar("T")

    @strawberry.type
    class Gossip:
        spill_the: T

    gossip_definition = get_object_definition(Gossip)
    spill_the_field: StrawberryField = gossip_definition.fields[0]

    assert spill_the_field.type == T


def test_union():
    """Check that a field annotated with a named union reports an equal type."""

    @strawberry.type
    class Europe:
        name: str

    @strawberry.type
    class UK:
        name: str

    # Union of the two object types, named "EU" through strawberry.union.
    EU = Annotated[Europe | UK, strawberry.union("EU")]

    @strawberry.type
    class WishfulThinking:
        desire: EU

    wish_definition = get_object_definition(WishfulThinking)
    desire_field: StrawberryField = wish_definition.fields[0]

    assert desire_field.type == EU


def test_fields_with_defaults():
    """Check construction of a type whose defaulted field precedes a required one."""

    @strawberry.type
    class Country:
        name: str = "United Kingdom"
        currency_code: str

    # Omitting ``name`` falls back to the declared default.
    with_default_name = Country(currency_code="GBP")
    assert with_default_name.name == "United Kingdom"
    assert with_default_name.currency_code == "GBP"

    # Passing ``name`` explicitly overrides the default.
    with_explicit_name = Country(
        name="United States of America",
        currency_code="USD",
    )
    assert with_explicit_name.name == "United States of America"
    assert with_explicit_name.currency_code == "USD"


def test_fields_with_defaults_inheritance():
    """Check defaults inherited from an interface in subclasses with extra fields.

    ``A`` declares a defaulted field; ``B`` adds another defaulted field and
    ``C`` adds a required one. The instance tree is converted with
    ``dataclasses.asdict`` and compared against the expected nested dict.
    """

    @strawberry.interface
    class A:
        text: str
        delay: int | None = None

    @strawberry.type
    class B(A):
        attachments: list[A] | None = None

    @strawberry.type
    class C(A):
        fields: list[B]

    nested_b = B(text="more text")
    c_inst = C(text="some text", fields=[nested_b])

    expected_nested = {
        "text": "more text",
        "attachments": None,
        "delay": None,
    }
    expected = {
        "text": "some text",
        "delay": None,
        "fields": [expected_nested],
    }

    assert dataclasses.asdict(c_inst) == expected


def test_positional_args_not_allowed():
    """Check that passing a field value positionally raises ``TypeError``."""

    @strawberry.type
    class Thing:
        name: str

    # Escaped because the message contains regex metacharacters ("()").
    expected_message = re.escape(
        "__init__() takes 1 positional argument but 2 were given"
    )

    with pytest.raises(TypeError, match=expected_message):
        Thing("something")


def test_object_preserves_annotations():
    """Check that ``__annotations__`` keeps the annotations as written.

    Covers plain and ``Annotated`` annotations, with and without a
    ``strawberry.field(graphql_type=...)`` value.
    """

    @strawberry.type
    class Object:
        a: bool
        b: Annotated[str, "something"]
        c: bool = strawberry.field(graphql_type=int)
        d: Annotated[str, "something"] = strawberry.field(graphql_type=int)

    expected_annotations = {
        "a": bool,
        "b": Annotated[str, "something"],
        "c": bool,
        "d": Annotated[str, "something"],
    }

    assert Object.__annotations__ == expected_annotations


def test_has_object_definition_returns_true_for_object_type():
    """Check ``has_object_definition`` is ``True`` for a ``@strawberry.type`` class."""

    @strawberry.type
    class Palette:
        name: str

    result = has_object_definition(Palette)

    assert result is True


def test_has_object_definition_returns_false_for_enum():
    """Check ``has_object_definition`` is ``False`` for a ``@strawberry.enum`` class."""

    @strawberry.enum
    class Color(Enum):
        RED = "red"
        GREEN = "green"

    result = has_object_definition(Color)

    assert result is False


def test_has_object_definition_returns_true_for_interface():
    """Check ``has_object_definition`` is ``True`` for a ``@strawberry.interface`` class."""

    @strawberry.interface
    class Node:
        id: str

    result = has_object_definition(Node)

    assert result is True


def test_has_object_definition_returns_true_for_input():
    """Check ``has_object_definition`` is ``True`` for a ``@strawberry.input`` class."""

    @strawberry.input
    class CreateUserInput:
        name: str

    result = has_object_definition(CreateUserInput)

    assert result is True


def test_has_object_definition_returns_false_for_scalar():
    """Check ``has_object_definition`` is ``False`` for the ``JSON`` scalar."""
    from strawberry.scalars import JSON

    result = has_object_definition(JSON)

    assert result is False