1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
|
import operator
import re
import sys
from base64 import b64decode, b64encode
from collections import deque
from datetime import date, datetime, time
from decimal import Decimal
from ipaddress import (
IPv4Address,
IPv4Interface,
IPv4Network,
IPv6Address,
IPv6Interface,
IPv6Network,
)
from pathlib import (
Path,
PosixPath,
PurePath,
PurePosixPath,
PureWindowsPath,
WindowsPath,
)
from typing import TypeVar
from uuid import UUID
from apischema import ValidationError, deserializer, schema, serializer, type_name
from apischema.conversions import Conversion, as_str, catch_value_error
# =================== bytes =====================
# Register the bytes deserializer: an incoming str is passed through
# base64.b64decode to produce the bytes value.
deserializer(Conversion(b64decode, source=str, target=bytes))
@serializer
def to_base64(b: bytes) -> str:
    """Return *b* base64-encoded, as text rather than bytes."""
    encoded = b64encode(b)
    return encoded.decode()
# Give bytes the GraphQL type name "Bytes" and attach encoding="base64" to
# its schema.
type_name(graphql="Bytes")(bytes)
schema(encoding="base64")(bytes)
# ================ collections ==================
T = TypeVar("T")

# A deque is built from a list when deserializing and turned back into a list
# when serializing. Older interpreters use the typing aliases; from 3.10 on
# the builtin generics are subscripted directly.
if sys.version_info < (3, 10):
    from typing import Deque, List

    deserializer(Conversion(deque, source=List[T], target=Deque[T]))
    serializer(Conversion(list, source=Deque[T], target=List[T]))
else:
    deserializer(Conversion(deque, source=list[T], target=deque[T]))  # type: ignore
    serializer(Conversion(list, source=deque[T], target=list[T]))  # type: ignore
# ================== datetime ===================
# Each temporal class is parsed with its own fromisoformat() and written with
# its own isoformat(); the mapping value is the schema "format" to attach.
# NOTE: the loop names (including `format`, which shadows the builtin) are
# module-level globals and are kept as they were.
for cls, format in {date: "date", datetime: "date-time", time: "time"}.items():
    # Wrapped with catch_value_error — presumably so a malformed string is
    # reported as a validation error instead of a raw ValueError; confirm.
    fromisoformat = catch_value_error(cls.fromisoformat)  # type: ignore
    deserializer(Conversion(fromisoformat, source=str, target=cls))
    serializer(Conversion(cls.isoformat, source=cls, target=str))  # type: ignore
    # GraphQL name is the capitalized class name: "Date", "Datetime", "Time".
    type_name(graphql=cls.__name__.capitalize())(cls)
    schema(format=format)(cls)
# ================== decimal ====================
# Decimal is converted from and to float in both directions.
# NOTE(review): Decimal(float) keeps the float's exact binary value (e.g.
# Decimal(0.1) is not equal to Decimal("0.1")) — confirm this is intended.
deserializer(Conversion(catch_value_error(Decimal), source=float, target=Decimal))
serializer(Conversion(float, source=Decimal, target=float))
# type_name(None) — presumably leaves Decimal without a named type/schema
# reference; confirm against the apischema documentation.
type_name(None)(Decimal)
# ================= ipaddress ===================
# The address, interface and network classes of one IP family are all
# registered through as_str and share that family's schema format; the
# GraphQL type name is the class name itself.
for format, classes in {
    "ipv4": (IPv4Address, IPv4Interface, IPv4Network),
    "ipv6": (IPv6Address, IPv6Interface, IPv6Network),
}.items():
    for cls in classes:
        as_str(cls)
        type_name(graphql=cls.__name__)(cls)
        schema(format=format)(cls)
# ==================== path =====================
# Every pathlib flavour, pure and concrete, is registered through as_str and
# given type_name(None). Registration order: pure classes first, then the
# concrete ones.
for cls in (PurePath, PurePosixPath, PureWindowsPath) + (
    Path,
    PosixPath,
    WindowsPath,
):
    as_str(cls)
    type_name(None)(cls)
# =================== pattern ===================
@deserializer
def _compile(pattern: str) -> re.Pattern:
    """Compile *pattern* into a regular expression object.

    Decorated with ``deserializer``, so it acts as the ``str`` ->
    ``re.Pattern`` conversion.

    Raises:
        ValidationError: if *pattern* is not a valid regular expression. The
            message is the text of the underlying ``re.error``, which is
            attached as the exception's cause.
    """
    try:
        return re.compile(pattern)
    except re.error as err:
        # Chain explicitly so the original regex error is kept as __cause__.
        raise ValidationError(str(err)) from err
# Serialize a compiled pattern by reading back its `pattern` attribute
# (the source string it was compiled from).
serializer(Conversion(operator.attrgetter("pattern"), source=re.Pattern, target=str))
# type_name(None) — presumably leaves re.Pattern without a named type/schema
# reference; confirm against the apischema documentation.
type_name(None)(re.Pattern)
# ==================== uuid =====================
# UUID is registered through as_str, named "UUID" for GraphQL, and its schema
# carries format="uuid".
as_str(UUID)
# type_name(...) only builds a decorator; it must be applied to UUID to have
# any effect (same call shape as the other type_name registrations here).
type_name(graphql="UUID")(UUID)
schema(format="uuid")(UUID)
|