File: conversions_resolver.py

package info (click to toggle)
python-apischema 0.18.3-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,636 kB
  • sloc: python: 15,281; makefile: 3; sh: 2
file content (134 lines) | stat: -rw-r--r-- 4,406 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from contextlib import suppress
from typing import (
    Any,
    Collection,
    Iterable,
    Iterator,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
)

from apischema.conversions.conversions import AnyConversion, DefaultConversion
from apischema.conversions.visitor import (
    Conv,
    ConversionsVisitor,
    DeserializationVisitor,
    SerializationVisitor,
)
from apischema.types import AnyType
from apischema.utils import is_hashable
from apischema.visitor import Unsupported

# Guarded import of the typing helpers used below.
try:
    from apischema.typing import Annotated, is_union
except ImportError:
    # NOTE(review): only ``Annotated`` receives a placeholder here; ``is_union``
    # stays unbound on this path, so ``merge_results`` would raise NameError if
    # it were ever called after a failed import. Presumably this branch only
    # exists to tolerate a partially initialised module — confirm.
    Annotated = ...  # type: ignore


def merge_results(
    results: Iterable[Sequence[AnyType]], origin: AnyType
) -> Sequence[AnyType]:
    """Parametrize ``origin`` with every combination of the given results.

    ``results`` holds, for each type-argument position, the candidate types
    for that position.  One combination is built per way of picking a single
    candidate at each position, and ``origin`` (or ``Union`` when
    ``is_union(origin)`` is true) is subscripted with it.

    Combinations are ordered so that the first position varies fastest and
    the last position varies slowest.  With no positions at all there is a
    single, empty combination; if any position has no candidate the result
    is empty.
    """
    per_position = list(results)

    # Grow the combinations from the last position towards the first: each
    # already-built tail is prefixed with every candidate of the current
    # position, which keeps earlier positions cycling faster than later ones.
    combinations: list = [()]
    for candidates in reversed(per_position):
        combinations = [
            (candidate, *tail) for tail in combinations for candidate in candidates
        ]

    merged = []
    for combination in combinations:
        generic = Union if is_union(origin) else origin
        merged.append(generic[combination])
    return merged


class ConversionsResolver(ConversionsVisitor[Conv, Sequence[AnyType]]):
    """Visitor whose result for a type is a sequence of types.

    Generic containers (collections, mappings, tuples, unions, ``Annotated``)
    are rebuilt from the results obtained for their parameters, using
    ``merge_results`` to enumerate the combinations.  ``visit_conversion``
    decides whether the visited type itself and/or the results of the parent
    visitor end up in the returned sequence.
    """

    def __init__(self, default_conversion: DefaultConversion):
        super().__init__(default_conversion)
        # While True, a call to ``visit_conversion`` with a non-None
        # conversion returns immediately; the flag is cleared by the first
        # call that gets past that check.
        self._skip_conversion = True
        # (type, conversion) pairs whose visit is currently in progress; used
        # by ``visit_conversion`` to avoid re-entering the same pair.
        self._rec_guard: Set[Tuple[AnyType, Conv]] = set()

    def annotated(self, tp: AnyType, annotations: Sequence[Any]) -> Sequence[AnyType]:
        """Re-wrap every result of the parent ``annotated`` with ``annotations``."""
        return [
            Annotated[(res, *annotations)] for res in super().annotated(tp, annotations)
        ]

    def collection(
        self, cls: Type[Collection], value_type: AnyType
    ) -> Sequence[AnyType]:
        """Return ``Collection[...]`` for each result of ``value_type``.

        ``cls`` is not used: results are always parametrized on
        ``typing.Collection``.
        """
        return merge_results([self.visit(value_type)], Collection)

    def mapping(
        self, cls: Type[Mapping], key_type: AnyType, value_type: AnyType
    ) -> Sequence[AnyType]:
        """Return ``Mapping[..., ...]`` for each key/value result combination.

        ``cls`` is not used: results are always parametrized on
        ``typing.Mapping``.
        """
        return merge_results([self.visit(key_type), self.visit(value_type)], Mapping)

    def new_type(self, tp: AnyType, super_type: AnyType) -> Sequence[AnyType]:
        """Not handled by this visitor.

        Raises:
            NotImplementedError: always.  ``visit_conversion`` suppresses this
                exception when it propagates out of the parent visit.
        """
        raise NotImplementedError

    def tuple(self, types: Sequence[AnyType]) -> Sequence[AnyType]:
        """Return ``Tuple[...]`` for each combination of the items' results."""
        return merge_results(map(self.visit, types), Tuple)

    def _visited_union(self, results: Sequence[Sequence[AnyType]]) -> Sequence[AnyType]:
        """Return ``Union[...]`` for each combination of the alternatives' results."""
        return merge_results(results, Union)

    def visit_conversion(
        self,
        tp: AnyType,
        conversion: Any,
        dynamic: bool,
        next_conversion: Optional[AnyConversion] = None,
    ) -> Sequence[AnyType]:
        """Visit ``tp`` under ``conversion`` and return the resulting types.

        Args:
            tp: type being visited.
            conversion: conversion attached to ``tp``, or ``None``.
            dynamic: forwarded to the parent visitor; when true, ``tp`` itself
                is never added to the result.
            next_conversion: forwarded unchanged to the parent visitor.

        Returns:
            The types produced by the parent ``visit_conversion`` (empty when
            it raised ``NotImplementedError``/``Unsupported`` or when the
            ``(tp, conversion)`` pair is already being visited), prefixed with
            ``tp`` when the visit is not dynamic and either a conversion is
            present or nothing else was produced.
        """
        if conversion is not None and self._skip_conversion:
            return [] if dynamic else [tp]
        self._skip_conversion = False
        results: Sequence[AnyType] = []
        guard_key = (tp, conversion)
        # The recursion guard stores (tp, conversion) pairs, so the *pair* has
        # to be hashable, not just ``tp``: an unhashable conversion would
        # otherwise raise TypeError on the set lookup below.  Unhashable pairs
        # are visited without recursion protection.
        # (``is_hashable`` is assumed to report whether ``hash()`` succeeds.)
        guarded = is_hashable(guard_key)
        if not guarded or guard_key not in self._rec_guard:
            if guarded:
                self._rec_guard.add(guard_key)
            try:
                # Types the parent visitor cannot handle contribute no result.
                with suppress(NotImplementedError, Unsupported):
                    results = super().visit_conversion(
                        tp, conversion, dynamic, next_conversion
                    )
            finally:
                # Release the guard even if an unexpected exception escapes.
                if guarded:
                    self._rec_guard.discard(guard_key)
        if not dynamic and (conversion is not None or not results):
            results = [tp, *results]
        return results


class WithConversionsResolver:
    """Mixin that equips visitor subclasses with ``resolve_conversion``.

    When a subclass also derives from ``DeserializationVisitor`` or
    ``SerializationVisitor``, a matching ``ConversionsResolver`` subclass is
    created for it and ``resolve_conversion`` is replaced by a function that
    runs that resolver.  Other subclasses keep the placeholder below.
    """

    def resolve_conversion(self, tp: AnyType) -> Sequence[AnyType]:
        """Placeholder overridden per subclass by ``__init_subclass__``."""
        raise NotImplementedError

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)

        # Pick the visitor flavour of the new subclass; deserialization is
        # checked first.  Subclasses of neither are left untouched.
        for visitor_base in (DeserializationVisitor, SerializationVisitor):
            if issubclass(cls, visitor_base):
                break
        else:
            return

        class Resolver(ConversionsResolver, visitor_base):  # type: ignore
            pass

        def resolve_conversion(
            self: ConversionsVisitor, tp: AnyType
        ) -> Sequence[AnyType]:
            # A fresh resolver per call, sharing the visitor's default
            # conversion and starting from its current conversion.
            resolver = Resolver(self.default_conversion)
            return resolver.visit_with_conv(tp, self._conversion)

        assert issubclass(cls, WithConversionsResolver)
        cls.resolve_conversion = resolve_conversion  # type: ignore