File: client.py

package info (click to toggle)
strawberry-graphql-django 0.67.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,988 kB
  • sloc: python: 27,682; sh: 20; makefile: 20
file content (102 lines) | stat: -rw-r--r-- 3,043 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import contextlib
import warnings
from typing import TYPE_CHECKING, Any, cast

from asgiref.sync import sync_to_async
from django.contrib.auth.base_user import AbstractBaseUser
from django.test.client import AsyncClient, Client
from strawberry.test import BaseGraphQLTestClient
from strawberry.test.client import Response
from typing_extensions import override

if TYPE_CHECKING:
    from collections.abc import Awaitable


class TestClient(BaseGraphQLTestClient):
    """Synchronous GraphQL test client backed by Django's test ``Client``.

    Every GraphQL request is POSTed to ``path``. Building the request body,
    decoding the response and the ``query`` entry point are inherited from
    strawberry's ``BaseGraphQLTestClient``.
    """

    # Opt out of pytest collection despite the ``Test`` name prefix.
    __test__ = False

    def __init__(self, path: str, client: Client | None = None):
        """Create a client that sends GraphQL requests to ``path``.

        Args:
            path: URL path of the GraphQL endpoint.
            client: Django test client to use; a fresh ``Client()`` is
                created when omitted (or falsy).
        """
        self.path = path
        super().__init__(client or Client())

    @property
    def client(self) -> Client:
        """The underlying Django test client."""
        return self._client

    def request(
        self,
        body: dict[str, object],
        headers: dict[str, object] | None = None,
        files: dict[str, object] | None = None,
    ) -> Any:
        """POST ``body`` to the GraphQL endpoint and return the raw response.

        Args:
            body: Request payload (typically produced by ``_build_body``).
            headers: Passed through to ``Client.post`` as its ``headers``
                argument.
            files: Uploaded files; when non-empty the payload is sent without
                an explicit JSON content type (multipart), otherwise it is
                sent as ``application/json``.

        Returns:
            Whatever ``self.client.post`` returns.
        """
        kwargs: dict[str, object] = {"data": body, "headers": headers}
        if files:
            # NOTE(review): Django's ``Client.post`` does not appear to have a
            # ``format`` parameter, so this is likely forwarded as an extra
            # request attribute; multipart is Django's default content type
            # anyway. Kept unchanged for compatibility — confirm before removing.
            kwargs["format"] = "multipart"
        else:
            kwargs["content_type"] = "application/json"

        return self.client.post(
            self.path,
            **kwargs,  # type: ignore
        )

    @contextlib.contextmanager
    def login(self, user: AbstractBaseUser):
        """Force-login ``user`` for the duration of the ``with`` block.

        The client is logged out on exit even when the block raises, so a
        failing test cannot leave a (possibly shared) client authenticated.
        """
        self.client.force_login(user)
        try:
            yield
        finally:
            self.client.logout()


class AsyncTestClient(TestClient):
    """Asynchronous variant of ``TestClient`` backed by Django's ``AsyncClient``.

    ``query`` is a coroutine and ``login`` is an async context manager;
    request construction is shared with ``TestClient``.
    """

    def __init__(self, path: str, client: AsyncClient | None = None):
        """Create a client that sends GraphQL requests to ``path``.

        Args:
            path: URL path of the GraphQL endpoint.
            client: Django async test client to use; a fresh
                ``AsyncClient()`` is created when omitted (or falsy).
        """
        super().__init__(
            path,
            client or AsyncClient(),  # type: ignore
        )

    @property
    def client(self) -> AsyncClient:  # type: ignore[reportIncompatibleMethodOverride]
        """The underlying Django async test client."""
        return self._client

    @override
    async def query(
        self,
        query: str,
        variables: dict[str, Any] | None = None,
        headers: dict[str, object] | None = None,
        asserts_errors: bool | None = None,
        files: dict[str, object] | None = None,
        assert_no_errors: bool | None = True,
    ) -> Response:
        """Execute a GraphQL operation and return the decoded response.

        Args:
            query: GraphQL document to execute.
            variables: Operation variables.
            headers: Passed through to ``request``.
            asserts_errors: Deprecated alias of ``assert_no_errors``; when
                given it emits a ``DeprecationWarning`` and takes precedence.
            files: Uploaded files; switches the request and decoding to
                multipart.
            assert_no_errors: When truthy, assert that the response has no
                ``errors``.

        Returns:
            A ``Response`` with the ``errors``, ``data`` and ``extensions``
            entries of the decoded payload.

        Raises:
            AssertionError: if error checking is enabled and the response
                contains errors (the errors are used as the message).
        """
        # Resolve the deprecated alias up front so the warning is emitted even
        # if the request itself fails.
        if asserts_errors is not None:
            warnings.warn(
                "The `asserts_errors` argument has been renamed to `assert_no_errors`",
                DeprecationWarning,
                stacklevel=2,
            )
            assert_no_errors = asserts_errors

        body = self._build_body(query, variables, files)

        # ``self.client`` is an ``AsyncClient`` here, so ``request`` hands back
        # an awaitable; the cast only informs the type checker.
        resp = await cast("Awaitable", self.request(body, headers, files))
        data = self._decode(resp, type="multipart" if files else "json")

        response = Response(
            errors=data.get("errors"),
            data=data.get("data"),
            extensions=data.get("extensions"),
        )

        if assert_no_errors:
            # Surface the GraphQL errors in the failure message so the failure
            # is diagnosable even without pytest's assertion rewriting.
            assert response.errors is None, response.errors

        return response

    @contextlib.asynccontextmanager
    async def login(self, user: AbstractBaseUser):  # type: ignore
        """Force-login ``user`` for the duration of the ``async with`` block.

        ``force_login`` and ``logout`` are invoked through ``sync_to_async``.
        The client is logged out on exit even when the block raises.
        """
        await sync_to_async(self.client.force_login)(user)
        try:
            yield
        finally:
            await sync_to_async(self.client.logout)()