1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
|
from contextlib import suppress
from typing import (
    Any,
    Collection,
    Iterable,
    Iterator,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
)

from apischema.conversions.conversions import AnyConversion, DefaultConversion
from apischema.conversions.visitor import (
    Conv,
    ConversionsVisitor,
    DeserializationVisitor,
    SerializationVisitor,
)
from apischema.types import AnyType
from apischema.typing import is_union
from apischema.utils import is_hashable
from apischema.visitor import Unsupported

# Only Annotated is optional; is_union is required unconditionally by
# merge_results and must not be lost if the Annotated import fails.
try:
    from apischema.typing import Annotated
except ImportError:
    Annotated = ...  # type: ignore
def merge_results(
    results: Iterable[Sequence[AnyType]], origin: AnyType
) -> Sequence[AnyType]:
    """Parameterize ``origin`` with every combination of argument results.

    ``results`` holds, for each type-argument position, the candidate types
    for that position. One ``origin[...]`` is produced per element of their
    cartesian product, with the FIRST position varying fastest. When
    ``is_union(origin)`` holds, ``typing.Union`` is subscripted instead of
    ``origin`` itself.

    An empty ``results`` produces the single empty combination
    (``origin[()]``); an empty candidate sequence at any position produces
    no combination at all.
    """
    # Grow the combinations from the last position backwards. Iterating the
    # already-built suffixes in the outer loop and the current position's
    # candidates in the inner loop makes the earliest position vary fastest.
    combinations: list = [()]
    for candidates in reversed(list(results)):
        combinations = [
            (candidate, *suffix)
            for suffix in combinations
            for candidate in candidates
        ]
    generic = Union if is_union(origin) else origin
    return [generic[combination] for combination in combinations]
class ConversionsResolver(ConversionsVisitor[Conv, Sequence[AnyType]]):
    """Visitor computing the types a type may resolve to through conversions.

    Every visit returns a sequence of candidate types. Collections, mappings,
    tuples and unions are rebuilt from the combinations of their arguments'
    candidates via ``merge_results``; ``Annotated`` metadata is re-applied to
    each candidate of the annotated type.
    """

    def __init__(self, default_conversion: DefaultConversion):
        super().__init__(default_conversion)
        # While True, visit_conversion returns immediately for any non-None
        # conversion; it is cleared by the first visit_conversion call that
        # gets past that check and is never set again.
        self._skip_conversion = True
        # (type, conversion) pairs currently being visited, used to stop
        # recursion on self-referencing types. Only hashable types are tracked.
        self._rec_guard: Set[Tuple[AnyType, Conv]] = set()

    def annotated(self, tp: AnyType, annotations: Sequence[Any]) -> Sequence[AnyType]:
        """Wrap each candidate of ``tp`` back into ``Annotated`` with the same metadata."""
        return [
            Annotated[(res, *annotations)] for res in super().annotated(tp, annotations)
        ]

    def collection(
        self, cls: Type[Collection], value_type: AnyType
    ) -> Sequence[AnyType]:
        """Return ``Collection[...]`` for each candidate of ``value_type``.

        The concrete collection class ``cls`` is not preserved; results are
        always expressed with ``typing.Collection``.
        """
        return merge_results([self.visit(value_type)], Collection)

    def mapping(
        self, cls: Type[Mapping], key_type: AnyType, value_type: AnyType
    ) -> Sequence[AnyType]:
        """Return ``Mapping[K, V]`` for each combination of key/value candidates.

        The concrete mapping class ``cls`` is not preserved; results are always
        expressed with ``typing.Mapping``.
        """
        return merge_results([self.visit(key_type), self.visit(value_type)], Mapping)

    def new_type(self, tp: AnyType, super_type: AnyType) -> Sequence[AnyType]:
        """Not supported by this resolver.

        Raises:
            NotImplementedError: always; visit_conversion suppresses this
                exception around the parent visit.
        """
        raise NotImplementedError

    def tuple(self, types: Sequence[AnyType]) -> Sequence[AnyType]:
        """Return ``Tuple[...]`` for each combination of element candidates."""
        # Method bodies still resolve ``tuple`` to the builtin; this name only
        # shadows it inside the class namespace.
        return merge_results(map(self.visit, types), Tuple)

    def _visited_union(self, results: Sequence[Sequence[AnyType]]) -> Sequence[AnyType]:
        """Return a ``Union`` for each combination of the members' candidates."""
        return merge_results(results, Union)

    def visit_conversion(
        self,
        tp: AnyType,
        conversion: Optional[Conv],
        dynamic: bool,
        next_conversion: Optional[AnyConversion] = None,
    ) -> Sequence[AnyType]:
        """Return the candidate types for ``tp`` visited with ``conversion``.

        While the skip flag is set, a non-None ``conversion`` is not followed:
        ``[tp]`` is returned (``[]`` when ``dynamic``). Otherwise the parent
        visit runs, guarded against re-entering the same hashable
        ``(tp, conversion)`` pair; ``NotImplementedError`` and ``Unsupported``
        from the parent visit yield no candidates. When not ``dynamic``,
        ``tp`` itself is prepended if a conversion was applied or if no other
        candidate was found.
        """
        if conversion is not None and self._skip_conversion:
            return [] if dynamic else [tp]
        self._skip_conversion = False
        results: Sequence[AnyType] = []
        # Unhashable types cannot be stored in the guard, so they are always
        # visited, without cycle protection.
        guard_key = (tp, conversion) if is_hashable(tp) else None
        if guard_key is None or guard_key not in self._rec_guard:
            if guard_key is not None:
                self._rec_guard.add(guard_key)
            try:
                with suppress(NotImplementedError, Unsupported):
                    results = super().visit_conversion(
                        tp, conversion, dynamic, next_conversion
                    )
            finally:
                # Release the guard even if the visit raised something not
                # suppressed above; otherwise the pair would stay marked as
                # "in progress" and later visits of it would be skipped.
                if guard_key is not None:
                    self._rec_guard.discard(guard_key)
        if not dynamic and (conversion is not None or not results):
            results = [tp, *results]
        return results
class WithConversionsResolver:
    """Mixin providing ``resolve_conversion`` to conversions visitors.

    Subclasses that also derive from ``DeserializationVisitor`` or
    ``SerializationVisitor`` have ``resolve_conversion`` replaced by an
    implementation backed by a ``ConversionsResolver`` of the matching
    direction. Other subclasses are left unchanged.
    """

    def resolve_conversion(self, tp: AnyType) -> Sequence[AnyType]:
        """Return the candidate types for ``tp``; replaced in visitor subclasses.

        Raises:
            NotImplementedError: when the subclass is neither a
                deserialization nor a serialization visitor.
        """
        raise NotImplementedError

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        deserializing = issubclass(cls, DeserializationVisitor)
        if not (deserializing or issubclass(cls, SerializationVisitor)):
            # Neither direction applies: leave resolve_conversion untouched.
            return
        Resolver: Type[ConversionsResolver]
        # Deserialization is checked first, so it wins if cls derives from both.
        if deserializing:

            class Resolver(ConversionsResolver, DeserializationVisitor):  # type: ignore
                pass

        else:

            class Resolver(ConversionsResolver, SerializationVisitor):  # type: ignore
                pass

        def resolve_conversion(
            self: ConversionsVisitor, tp: AnyType
        ) -> Sequence[AnyType]:
            # A new resolver per call, so its skip flag and recursion guard
            # start from their initial state every time.
            resolver = Resolver(self.default_conversion)
            return resolver.visit_with_conv(tp, self._conversion)

        assert issubclass(cls, WithConversionsResolver)
        cls.resolve_conversion = resolve_conversion  # type: ignore
|