1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
|
from dataclasses import MISSING, Field
from typing import Any, Collection, Mapping, Optional, Sequence
from apischema.aliases import Aliaser, get_class_aliaser
from apischema.conversions.conversions import AnyConversion
from apischema.dataclasses import replace
from apischema.metadata.keys import ALIAS_METADATA
from apischema.objects.fields import MISSING_DEFAULT, FieldKind, ObjectField
from apischema.types import AnyType, Undefined
from apischema.typing import get_args
from apischema.utils import get_origin_or_type, get_parameters, substitute_type_vars
from apischema.visitor import Result, Visitor
def object_field_from_field(
    field: Field, field_type: AnyType, init_var: bool
) -> ObjectField:
    """Convert a dataclass ``Field`` into an ``ObjectField``.

    :param field: the dataclass field (or init-var pseudo-field) to convert
    :param field_type: the type to attach to the resulting ``ObjectField``
    :param init_var: whether ``field`` comes from the dataclass init vars
    :return: an ``ObjectField`` that is required exactly when ``field`` has
        neither a ``default`` nor a ``default_factory``. Its kind is
        ``WRITE_ONLY`` for init vars, ``READ_ONLY`` for ``init=False`` fields
        and ``NORMAL`` otherwise.
    """
    has_default = (
        field.default is not MISSING or field.default_factory is not MISSING
    )
    if init_var:
        kind = FieldKind.WRITE_ONLY
    else:
        kind = FieldKind.NORMAL if field.init else FieldKind.READ_ONLY
    return ObjectField(
        field.name,
        field_type,
        not has_default,
        field.metadata,
        default=field.default,
        # Passed through unchanged, possibly as dataclasses.MISSING; presumably
        # ObjectField normalizes that sentinel -- confirm in objects.fields.
        default_factory=field.default_factory,  # type: ignore
        kind=kind,
    )
def _override_alias(field: ObjectField, aliaser: Aliaser) -> ObjectField:
    """Apply ``aliaser`` to ``field``'s alias when the field allows it.

    A field whose ``override_alias`` flag is false is returned untouched.
    Otherwise a copy is returned whose metadata maps ``ALIAS_METADATA`` to
    ``aliaser(field.alias)``; every other metadata entry is preserved.
    """
    if not field.override_alias:
        return field
    aliased_metadata = {**field.metadata, ALIAS_METADATA: aliaser(field.alias)}
    # NOTE(review): ``default`` is presumably an init-only parameter of
    # ObjectField. MISSING_DEFAULT is passed so replace() does not need the
    # original value -- confirm against ObjectField's definition.
    return replace(field, metadata=aliased_metadata, default=MISSING_DEFAULT)
class ObjectVisitor(Visitor[Result]):
    """Visitor that reduces object-like types to a sequence of ``ObjectField``.

    Dataclasses, named tuples and typed dicts are all reduced to fields and
    handed to :meth:`object`, which subclasses implement. So is any class for
    which ``settings.default_object_fields`` provides fields (via
    :meth:`unsupported`).
    """

    # Dataclass fields whose kind equals this value are dropped in
    # ``dataclass``. ``None`` keeps every kind.
    _field_kind_filtered: Optional[FieldKind] = None

    def _field_conversion(self, field: ObjectField) -> Optional[AnyConversion]:
        """Return the conversion of ``field`` for the visited direction.

        Abstract: must be overridden by subclasses.
        """
        raise NotImplementedError

    def _skip_field(self, field: ObjectField) -> bool:
        """Return whether ``field`` must be left out of the visited object.

        Abstract: must be overridden by subclasses.
        """
        raise NotImplementedError

    @staticmethod
    def _default_fields(cls: type) -> Optional[Sequence[ObjectField]]:
        """Return the fields that settings declare for ``cls``, or ``None``."""
        # Imported lazily, presumably to avoid a circular import with
        # apischema.settings.
        from apischema import settings

        return settings.default_object_fields(cls)

    def _override_fields(
        self, tp: AnyType, fields: Sequence[ObjectField]
    ) -> Sequence[ObjectField]:
        """Replace ``fields`` with the settings-provided fields of ``tp``, if any.

        If ``tp``'s origin is a class for which :meth:`_default_fields`
        returns a sequence, that sequence is returned instead of ``fields``.
        When ``tp`` is parametrized, type variables in each field's type are
        substituted with ``tp``'s arguments. Otherwise ``fields`` itself (the
        very same object) is returned; :meth:`unsupported` relies on this.
        """
        origin = get_origin_or_type(tp)
        if not isinstance(origin, type):
            return fields
        default_fields = self._default_fields(origin)
        if default_fields is None:
            return fields
        args = get_args(tp)
        if args:
            sub = dict(zip(get_parameters(origin), args))
            default_fields = [
                replace(f, type=substitute_type_vars(f.type, sub))
                for f in default_fields
            ]
        return default_fields

    def _object(self, tp: AnyType, fields: Sequence[ObjectField]) -> Result:
        """Drop skipped fields, apply the class aliaser, then call :meth:`object`."""
        fields = [f for f in fields if not self._skip_field(f)]
        aliaser = get_class_aliaser(get_origin_or_type(tp))
        # Explicit None check (get_class_aliaser presumably returns None when
        # the class has no aliaser). An aliaser is an arbitrary callable and
        # must not be ignored merely because it happens to be falsy.
        if aliaser is not None:
            fields = [_override_alias(f, aliaser) for f in fields]
        return self.object(tp, fields)

    def dataclass(
        self,
        tp: AnyType,
        types: Mapping[str, AnyType],
        fields: Sequence[Field],
        init_vars: Sequence[Field],
    ) -> Result:
        """Visit a dataclass.

        Regular fields and init vars are converted with
        ``object_field_from_field`` and ordered as the keys of ``types``.
        Fields whose kind equals ``_field_kind_filtered`` are dropped before
        settings overrides are applied.
        """
        by_name = {
            f.name: object_field_from_field(f, types[f.name], init_var)
            for field_group, init_var in [(fields, False), (init_vars, True)]
            for f in field_group
        }
        # NOTE(review): the kind filter runs before _override_fields, so fields
        # returned by settings.default_object_fields are not filtered by kind.
        # Confirm this is intended.
        object_fields = [
            by_name[name]
            for name in types
            if name in by_name and by_name[name].kind != self._field_kind_filtered
        ]
        return self._object(tp, self._override_fields(tp, object_fields))

    def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> Result:
        """Visit an object type with its final list of fields.

        Abstract: must be overridden by subclasses.
        """
        raise NotImplementedError

    def named_tuple(
        self, tp: AnyType, types: Mapping[str, AnyType], defaults: Mapping[str, Any]
    ) -> Result:
        """Visit a named tuple; a field is required iff it has no default."""
        fields = [
            ObjectField(name, type_, name not in defaults, default=defaults.get(name))
            for name, type_ in types.items()
        ]
        return self._object(tp, self._override_fields(tp, fields))

    def typed_dict(
        self, tp: AnyType, types: Mapping[str, AnyType], required_keys: Collection[str]
    ) -> Result:
        """Visit a typed dict.

        A key is required iff it is in ``required_keys``; every field is given
        ``Undefined`` as its default.
        """
        fields = [
            ObjectField(name, type_, name in required_keys, default=Undefined)
            for name, type_ in types.items()
        ]
        return self._object(tp, self._override_fields(tp, fields))

    def unsupported(self, tp: AnyType) -> Result:
        """Treat ``tp`` as an object if settings provide fields for it.

        The identity check against a fresh sentinel list distinguishes
        "no override" from an override that is itself an empty sequence.
        """
        dummy: list = []
        fields = self._override_fields(tp, dummy)
        return super().unsupported(tp) if fields is dummy else self._object(tp, fields)
class DeserializationObjectVisitor(ObjectVisitor[Result]):
    """``ObjectVisitor`` specialized for the deserialization direction.

    When visiting dataclasses, fields of kind ``READ_ONLY`` (``init=False``
    fields, as built by ``object_field_from_field``) are dropped. Skipping
    and conversions are read from the ``deserialization`` side of each field.
    """

    _field_kind_filtered = FieldKind.READ_ONLY

    @staticmethod
    def _skip_field(field: ObjectField) -> bool:
        """Return the field's deserialization skip flag."""
        skip_flags = field.skip
        return skip_flags.deserialization

    @staticmethod
    def _field_conversion(field: ObjectField) -> Optional[AnyConversion]:
        """Return the field's deserialization conversion."""
        return field.deserialization
class SerializationObjectVisitor(ObjectVisitor[Result]):
    """``ObjectVisitor`` specialized for the serialization direction.

    When visiting dataclasses, fields of kind ``WRITE_ONLY`` (init vars, as
    built by ``object_field_from_field``) are dropped. Skipping and
    conversions are read from the ``serialization`` side of each field.
    """

    _field_kind_filtered = FieldKind.WRITE_ONLY

    @staticmethod
    def _skip_field(field: ObjectField) -> bool:
        """Return the field's serialization skip flag."""
        skip_flags = field.skip
        return skip_flags.serialization

    @staticmethod
    def _field_conversion(field: ObjectField) -> Optional[AnyConversion]:
        """Return the field's serialization conversion."""
        return field.serialization
|