1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
|
import copy
import inspect
from typing import List, Tuple, Any, Optional
import typeguard
from marshmallow import fields, Schema, ValidationError
# Exception type caught below when a typeguard type check fails.
# Newer typeguard releases export a dedicated `TypeCheckError`; on typeguard < 3
# that name does not exist, so plain `TypeError` is used under the same name.
try:
    from typeguard import TypeCheckError  # type: ignore[attr-defined]
except ImportError:
    # typeguard < 3: no TypeCheckError, alias TypeError instead
    TypeCheckError = TypeError  # type: ignore[misc, assignment]
if "argname" in inspect.signature(typeguard.check_type).parameters:
    # typeguard < 3.0.0rc2: `check_type` still takes an `argname` keyword,
    # so forward it.
    def _check_type(value, expected_type, argname: str):
        """Run typeguard's ``check_type`` on ``value``, passing ``argname``."""
        return typeguard.check_type(  # type: ignore[call-overload]
            value=value, expected_type=expected_type, argname=argname
        )

else:
    # Newer typeguard no longer accepts `argname`. Keep the same wrapper
    # signature so callers need not care which version is installed, and
    # simply drop the argument.
    def _check_type(value, expected_type, argname: str):
        """Run typeguard's ``check_type`` on ``value``; ``argname`` is unused."""
        return typeguard.check_type(value=value, expected_type=expected_type)
class Union(fields.Field):
    """A union field, composed of other `Field` instances.

    This field serializes elements based on their type, with one of its child
    fields. Child fields are tried in the order they are given.

    Example: ::

        number_or_string = Union([
            (float, fields.Float()),
            (str, fields.Str())
        ])

    :param union_fields: A list of types and their associated field instance.
    :param kwargs: The same keyword arguments that :class:`Field` receives.
    """

    def __init__(self, union_fields: List[Tuple[type, fields.Field]], **kwargs):
        super().__init__(**kwargs)
        self.union_fields = union_fields

    def _bind_to_schema(self, field_name: str, schema: Schema) -> None:
        """Bind this field to ``schema`` and bind a copy of every child field.

        Each child is deep-copied before binding, so the field instances that
        were passed to ``__init__`` are not themselves mutated. The copies are
        bound with this union as their parent and replace ``union_fields``.
        """
        super()._bind_to_schema(field_name, schema)
        bound_fields = []
        for typ, field in self.union_fields:
            bound = copy.deepcopy(field)
            bound._bind_to_schema(field_name, self)
            bound_fields.append((typ, bound))
        self.union_fields = bound_fields

    def _serialize(self, value: Any, attr: Optional[str], obj, **kwargs) -> Any:
        """Serialize ``value`` with the first child whose type it matches.

        ``None`` is returned unchanged.

        :raises TypeError: if ``value`` matches none of the union's types.
        """
        if value is None:
            return value
        argname = attr or "anonymous"
        errors = []
        for typ, field in self.union_fields:
            try:
                _check_type(value=value, expected_type=typ, argname=argname)
            except TypeCheckError as e:
                errors.append(e)
            else:
                # Serialize outside the `try`: on typeguard < 3 TypeCheckError
                # is plain TypeError, and a TypeError raised by the matching
                # child must surface instead of being mistaken for a mismatch.
                return field._serialize(value, attr, obj, **kwargs)
        raise TypeError(
            f"Unable to serialize value with any of the fields in the union: {errors}"
        )

    def _deserialize(self, value: Any, attr: Optional[str], data, **kwargs) -> Any:
        """Deserialize ``value`` with the first child field that accepts it.

        A child accepts ``value`` when its ``deserialize`` succeeds and the
        result type-checks against the child's associated type.

        :raises ValidationError: if no child accepts ``value``. Its messages
            hold each child's failure, in the order the children were tried.
        """
        argname = attr or "anonymous"
        errors: List[Any] = []
        for typ, field in self.union_fields:
            try:
                # Forward `attr` and `data` so each child receives the same
                # context it would get as a regular schema field.
                result = field.deserialize(value, attr, data, **kwargs)
                _check_type(value=result, expected_type=typ, argname=argname)
            except ValidationError as e:
                # Keep marshmallow's own message structure (str/list/dict).
                errors.append(e.messages)
            except TypeCheckError as e:
                errors.append(str(e))
            else:
                return result
        raise ValidationError(errors)
|