1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
|
import contextlib
import warnings
from typing import TYPE_CHECKING, Any, cast
from asgiref.sync import sync_to_async
from django.contrib.auth.base_user import AbstractBaseUser
from django.test.client import AsyncClient, Client
from strawberry.test import BaseGraphQLTestClient
from strawberry.test.client import Response
from typing_extensions import override
if TYPE_CHECKING:
from collections.abc import Awaitable
class TestClient(BaseGraphQLTestClient):
    """GraphQL test client backed by Django's synchronous test ``Client``.

    Every GraphQL request is sent as an HTTP POST to ``path``.
    """

    # Tells pytest not to collect this class as a test case even though its
    # name starts with ``Test``.
    __test__ = False

    def __init__(self, path: str, client: Client | None = None):
        """Create a test client.

        Args:
            path: URL path that GraphQL requests are POSTed to.
            client: Django test client to use; a fresh ``Client()`` is
                created when omitted.
        """
        self.path = path
        super().__init__(client or Client())

    @property
    def client(self) -> Client:
        """The Django test client handed to the base-class constructor."""
        return self._client

    def request(
        self,
        body: dict[str, object],
        headers: dict[str, object] | None = None,
        files: dict[str, object] | None = None,
    ):
        """POST ``body`` to ``self.path`` and return the client's response.

        Args:
            body: Request payload, passed to ``Client.post`` as ``data``.
            headers: Optional extra HTTP headers.
            files: Only its truthiness is inspected here. When truthy, the
                request is sent as multipart. Otherwise it is sent with an
                ``application/json`` content type. The file objects are not
                read by this method; presumably the caller has already
                merged them into ``body`` (e.g. via ``_build_body``). Verify
                this for custom callers.

        Returns:
            Whatever ``self.client.post`` returns.
        """
        kwargs: dict[str, object] = {"data": body, "headers": headers}
        if files:
            # NOTE(review): ``format`` is not a named ``Client.post``
            # parameter in Django. It appears to be forwarded as an extra
            # WSGI environ key, and multipart encoding comes from the default
            # content type instead. Confirm before relying on it.
            kwargs["format"] = "multipart"
        else:
            kwargs["content_type"] = "application/json"
        return self.client.post(
            self.path,
            **kwargs,  # type: ignore
        )

    @contextlib.contextmanager
    def login(self, user: AbstractBaseUser):
        """Log ``user`` in for the duration of a ``with`` block.

        The user is logged out on exit even when the block raises. This
        keeps a failing test from leaving an authenticated session on a
        reused client.
        """
        self.client.force_login(user)
        try:
            yield
        finally:
            self.client.logout()
class AsyncTestClient(TestClient):
    """Async variant of ``TestClient`` backed by Django's ``AsyncClient``."""

    def __init__(self, path: str, client: AsyncClient | None = None):
        """Create an async test client.

        Args:
            path: URL path that GraphQL requests are POSTed to.
            client: Django async test client to use; a fresh
                ``AsyncClient()`` is created when omitted.
        """
        super().__init__(
            path,
            client or AsyncClient(),  # type: ignore
        )

    @property
    def client(self) -> AsyncClient:  # type: ignore[reportIncompatibleMethodOverride]
        """The Django async test client handed to the constructor."""
        return self._client

    @override
    async def query(
        self,
        query: str,
        variables: dict[str, Any] | None = None,
        headers: dict[str, object] | None = None,
        asserts_errors: bool | None = None,
        files: dict[str, object] | None = None,
        assert_no_errors: bool | None = True,
    ) -> Response:
        """Execute a GraphQL operation and return the decoded ``Response``.

        Args:
            query: GraphQL document to send.
            variables: Optional operation variables.
            headers: Optional extra HTTP headers.
            asserts_errors: Deprecated alias of ``assert_no_errors``. Passing
                it emits a ``DeprecationWarning``, and its value then
                overrides ``assert_no_errors``.
            files: Optional files; when truthy, the request and response are
                handled as multipart.
            assert_no_errors: When truthy, assert that the response has no
                GraphQL errors.

        Raises:
            AssertionError: If error checking is enabled and the response
                contains errors. The message carries those errors.
        """
        body = self._build_body(query, variables, files)
        # The inherited ``request`` returns ``self.client.post(...)``, which
        # is awaitable for an AsyncClient. The cast only informs the type
        # checker.
        resp = await cast("Awaitable", self.request(body, headers, files))
        data = self._decode(resp, type="multipart" if files else "json")
        response = Response(
            errors=data.get("errors"),
            data=data.get("data"),
            extensions=data.get("extensions"),
        )
        if asserts_errors is not None:
            warnings.warn(
                "The `asserts_errors` argument has been renamed to `assert_no_errors`",
                DeprecationWarning,
                stacklevel=2,
            )
        # When the deprecated argument is given, it takes precedence.
        assert_no_errors = (
            assert_no_errors if asserts_errors is None else asserts_errors
        )
        if assert_no_errors:
            # Include the errors in the message so failing tests show them.
            assert response.errors is None, response.errors
        return response

    @contextlib.asynccontextmanager
    async def login(self, user: AbstractBaseUser):  # type: ignore
        """Log ``user`` in for the duration of an ``async with`` block.

        ``force_login`` and ``logout`` are wrapped with ``sync_to_async`` so
        they can be awaited from async code. The user is logged out on exit
        even when the block raises.
        """
        await sync_to_async(self.client.force_login)(user)
        try:
            yield
        finally:
            await sync_to_async(self.client.logout)()
|