File: test_extensions.py

package info (click to toggle)
strawberry-graphql 0.306.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 18,176 kB
  • sloc: javascript: 178,052; python: 65,643; sh: 33; makefile: 25
file content (193 lines) | stat: -rw-r--r-- 5,554 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
from enum import Enum, auto
from typing import Annotated, cast

from graphql import (
    DirectiveLocation,
    GraphQLEnumType,
    GraphQLInputType,
    GraphQLObjectType,
    GraphQLSchema,
)

import strawberry
from strawberry.directive import DirectiveValue
from strawberry.scalars import JSON
from strawberry.schema.schema_converter import GraphQLCoreConverter
from strawberry.schema_directive import Location
from strawberry.types.base import get_object_definition

# Key in a graphql-core object's ``extensions`` mapping under which the schema
# converter stores a back-reference to the Strawberry object it was built from
# (the Strawberry ``Schema`` for the schema itself, definitions for types,
# fields, arguments, enum values and directives — see the tests below).
DEFINITION_BACKREF = GraphQLCoreConverter.DEFINITION_BACKREF


def test_extensions_schema_directive():
    """Check back-references on the schema and on a converted schema directive.

    The graphql-core schema must refer back to the Strawberry ``Schema``, and
    the graphql-core directive produced from the directive applied to
    ``Query`` must refer back to ``SchemaDirective.__strawberry_directive__``.
    """

    @strawberry.schema_directive(locations=[Location.OBJECT, Location.INPUT_OBJECT])
    class SchemaDirective:
        name: str

    @strawberry.type(directives=[SchemaDirective(name="Query")])
    class Query:
        hello: str

    schema = strawberry.Schema(query=Query)
    core_schema: GraphQLSchema = schema._schema

    # The top-level graphql-core schema points back at the Strawberry schema.
    assert core_schema.extensions[DEFINITION_BACKREF] is schema

    # TODO: schema directives show up when the schema is printed
    #       (``str(schema)``) but are not registered in
    #       ``core_schema.directives``, so something like
    #       ``core_schema.get_directive("schemaDirective")`` cannot be used
    #       here. Convert the directive applied to ``Query`` explicitly instead.
    applied = get_object_definition(Query, strict=True).directives
    assert applied is not None

    first_applied = applied[0]
    converted = schema.schema_converter.from_schema_directive(first_applied)
    expected = SchemaDirective.__strawberry_directive__
    assert converted.extensions[DEFINITION_BACKREF] is expected


def test_directive():
    """Check back-references on an operation directive and on its argument."""

    @strawberry.directive(locations=[DirectiveLocation.FIELD])
    def uppercase(value: DirectiveValue[str], foo: str):  # pragma: no cover
        return value.upper()

    @strawberry.type()
    class Query:
        hello: str

    schema = strawberry.Schema(query=Query, directives=[uppercase])
    graphql_schema: GraphQLSchema = schema._schema

    graphql_directive = graphql_schema.get_directive("uppercase")
    # ``get_directive`` returns ``None`` for an unknown name; fail with a clear
    # assertion (and narrow the Optional for type checkers) before using it.
    assert graphql_directive is not None
    assert graphql_directive.extensions[DEFINITION_BACKREF] is uppercase

    # ``value`` is annotated with ``DirectiveValue`` and is expected not to be
    # exposed as a GraphQL argument, so ``foo`` is ``uppercase.arguments[0]``.
    assert (
        graphql_directive.args["foo"].extensions[DEFINITION_BACKREF]
        is uppercase.arguments[0]
    )


def test_enum():
    """Check back-references on an enum type and on each of its values."""

    @strawberry.enum
    class ThingType(Enum):
        JSON = auto()
        STR = auto()

    @strawberry.type()
    class Query:
        hello: ThingType

    schema = strawberry.Schema(query=Query)
    core_enum = cast("GraphQLEnumType", schema._schema.get_type("ThingType"))
    enum_definition = ThingType.__strawberry_definition__

    assert core_enum.extensions[DEFINITION_BACKREF] is enum_definition

    # The test expects the definition's ``values`` to follow declaration
    # order, so position ``i`` matches the ``i``-th member name below.
    for position, member_name in enumerate(("JSON", "STR")):
        backref = core_enum.values[member_name].extensions[DEFINITION_BACKREF]
        assert backref is enum_definition.values[position]


def test_scalar():
    """Check the ``JSON`` scalar's GraphQL type points at its registry entry.

    The back-reference must be the very object stored for ``JSON`` in
    ``DEFAULT_SCALAR_REGISTRY``.
    """
    from strawberry.schema.types.scalar import DEFAULT_SCALAR_REGISTRY

    @strawberry.type()
    class Query:
        hello: JSON
        hi: str

    schema = strawberry.Schema(query=Query)
    graphql_schema: GraphQLSchema = schema._schema

    graphql_json = graphql_schema.get_type("JSON")
    # ``get_type`` returns ``None`` for an unknown name; fail with a clear
    # assertion rather than an AttributeError on ``None``.
    assert graphql_json is not None
    assert (
        graphql_json.extensions[DEFINITION_BACKREF]
        is DEFAULT_SCALAR_REGISTRY[JSON]
    )


def test_interface():
    """Check that an interface's GraphQL type points back at its definition."""

    @strawberry.interface
    class Thing:
        name: str

    @strawberry.type()
    class Query:
        hello: Thing

    schema = strawberry.Schema(query=Query)
    graphql_schema: GraphQLSchema = schema._schema

    graphql_thing = graphql_schema.get_type("Thing")
    # ``get_type`` returns ``None`` for an unknown name; fail with a clear
    # assertion rather than an AttributeError on ``None``.
    assert graphql_thing is not None
    assert (
        graphql_thing.extensions[DEFINITION_BACKREF]
        is Thing.__strawberry_definition__
    )


def test_union():
    """Check that a named union's GraphQL type carries a back-reference.

    The back-referenced definition is verified through its ``graphql_name``
    (the name passed to ``strawberry.union``) and its unset ``description``.
    """

    @strawberry.type
    class JsonThing:
        value: JSON

    @strawberry.type
    class StrThing:
        value: str

    SomeThing = Annotated[JsonThing | StrThing, strawberry.union("SomeThing")]

    @strawberry.type()
    class Query:
        hello: SomeThing

    schema = strawberry.Schema(query=Query)
    graphql_schema: GraphQLSchema = schema._schema
    graphql_type = graphql_schema.get_type("SomeThing")
    # ``get_type`` returns ``None`` for an unknown name; fail with a clear
    # assertion rather than an AttributeError on ``None``.
    assert graphql_type is not None

    definition = graphql_type.extensions[DEFINITION_BACKREF]
    assert definition.graphql_name == "SomeThing"
    assert definition.description is None


def test_object_types():
    """Check back-references on object/input types, their fields and arguments.

    Covers the ``Query`` object type, its ``hello`` field and that field's
    ``input`` argument, plus the ``Input`` input type and its ``name`` field.
    """
    from graphql import GraphQLInputObjectType

    @strawberry.input
    class Input:
        name: str

    @strawberry.type()
    class Query:
        @strawberry.field
        def hello(self, input: Input) -> str: ...

    schema = strawberry.Schema(query=Query)
    graphql_schema: GraphQLSchema = schema._schema

    # ``GraphQLInputType`` is a union that also covers scalars, enums and
    # wrapping types, none of which have ``fields``; narrow to the concrete
    # input-object class (this also fails clearly if the lookup returned None).
    graphql_input = cast("GraphQLInputType", graphql_schema.get_type("Input"))
    assert isinstance(graphql_input, GraphQLInputObjectType)
    input_definition = Input.__strawberry_definition__
    assert graphql_input.extensions[DEFINITION_BACKREF] is input_definition

    graphql_query = graphql_schema.get_type("Query")
    assert isinstance(graphql_query, GraphQLObjectType)
    query_definition = Query.__strawberry_definition__
    assert graphql_query.extensions[DEFINITION_BACKREF] is query_definition

    # ``get_field`` may return ``None``; guard it so the identity checks below
    # cannot pass vacuously on ``None is None``.
    hello_field = query_definition.get_field("hello")
    assert hello_field is not None
    graphql_hello = graphql_query.fields["hello"]
    assert graphql_hello.extensions[DEFINITION_BACKREF] is hello_field
    assert (
        graphql_hello.args["input"].extensions[DEFINITION_BACKREF]
        is hello_field.arguments[0]
    )

    name_field = input_definition.get_field("name")
    assert name_field is not None
    assert graphql_input.fields["name"].extensions[DEFINITION_BACKREF] is name_field