File: refs.py

package info (click to toggle)
python-apischema 0.18.3-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,636 kB
  • sloc: python: 15,281; makefile: 3; sh: 2
file content (171 lines) | stat: -rw-r--r-- 5,638 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from collections import defaultdict
from enum import Enum
from typing import (
    Any,
    Collection,
    Dict,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
)

from apischema.conversions.conversions import AnyConversion, DefaultConversion
from apischema.conversions.visitor import (
    ConversionsVisitor,
    DeserializationVisitor,
    SerializationVisitor,
)
from apischema.discriminators import (
    get_discriminated_parent,
    get_inherited_discriminator,
)
from apischema.json_schema.conversions_resolver import WithConversionsResolver
from apischema.metadata.keys import DISCRIMINATOR_METADATA
from apischema.objects import ObjectField
from apischema.objects.visitor import (
    DeserializationObjectVisitor,
    ObjectVisitor,
    SerializationObjectVisitor,
)
from apischema.type_names import TypeNameFactory, get_type_name
from apischema.types import AnyType
from apischema.utils import get_origin_or_type, is_hashable, replace_builtins
from apischema.visitor import Unsupported

try:
    from apischema.typing import Annotated, get_origin, is_union
except ImportError:
    Annotated = ...  # type: ignore

# Maps a JSON-schema reference name to (type registered under that name,
# number of times the reference has been encountered); maintained by
# RefsExtractor._incr_ref.
Refs = Dict[str, Tuple[AnyType, int]]


class Recursive(Exception):
    """Marker exception type.

    Nothing in this module raises or catches it.
    NOTE(review): check external users before removing it.
    """


# Generic type variable; not referenced anywhere else in this module.
T = TypeVar("T")


class RefsExtractor(ConversionsVisitor, ObjectVisitor, WithConversionsResolver):
    """Count how often each JSON-schema reference is reached while visiting a type.

    Results accumulate in ``refs`` (the mapping given to the constructor). It maps
    a reference name to ``(type owning the reference, occurrence count)``.
    Discriminated unions are deliberately visited one extra time "in order to
    ensure ref count > 1". Callers presumably treat a count above 1 as "emit a
    shared definition".
    NOTE(review): confirm this against the schema generator.
    """

    def __init__(self, default_conversion: DefaultConversion, refs: Refs):
        super().__init__(default_conversion)
        self.refs = refs
        # Nesting depth per (type, self._conversion) pair. visit_conversion uses
        # it to detect recursive types that have no ref to break the cycle.
        self._rec_guard: Dict[
            Tuple[AnyType, Optional[AnyConversion]], int
        ] = defaultdict(int)

    def _incr_ref(self, ref: Optional[str], tp: AnyType) -> bool:
        """Record one more occurrence of reference ``ref`` for type ``tp``.

        :param ref: reference name, or ``None`` when the type has none
        :param tp: type reached under this reference
        :return: ``True`` if ``ref`` had already been recorded, in which case
            callers stop descending into ``tp``. ``False`` if ``ref`` is ``None``
            or this is its first occurrence.
        :raises ValueError: if ``ref`` is already registered for a different
            type (types are compared after ``replace_builtins``)
        """
        if ref is None:
            return False
        ref_cls, count = self.refs.get(ref, (tp, 0))
        if replace_builtins(ref_cls) != replace_builtins(tp):
            # A mismatch is only possible when ``ref`` is already in
            # ``self.refs``, so ``ref_cls`` is the previously registered type.
            raise ValueError(
                f"Types {tp} and {ref_cls} share same reference '{ref}'"
            )
        self.refs[ref] = (ref_cls, count + 1)
        return count > 0

    def annotated(self, tp: AnyType, annotations: Sequence[Any]):
        """Handle ``Annotated[tp, *annotations]``.

        Annotations are scanned from last to first:

        * A ``TypeNameFactory`` whose JSON-schema name is a string registers a
          ref for ``Annotated`` of ``tp`` with every annotation up to and
          including that factory. If that ref was already seen, visiting stops.
        * A mapping carrying discriminator metadata on a union type triggers one
          extra visit of ``tp``.

        Visiting then continues through the parent implementation.
        """
        for index, annotation in enumerate(reversed(annotations)):
            if isinstance(annotation, TypeNameFactory):
                ref = annotation.to_type_name(tp).json_schema
                if not isinstance(ref, str):
                    continue
                # Annotations up to and including the current factory.
                ref_annotations = annotations[: len(annotations) - index]
                annotated_tp = Annotated[(tp, *ref_annotations)]
                if self._incr_ref(ref, annotated_tp):
                    return
            if (
                isinstance(annotation, Mapping)
                and DISCRIMINATOR_METADATA in annotation
                and is_union(get_origin(tp))
            ):
                # Visit one more time discriminated union in order to ensure ref count > 1
                self.visit(tp)
        return super().annotated(tp, annotations)

    def any(self):
        """``Any`` carries no refs."""

    def collection(self, cls: Type[Collection], value_type: AnyType):
        """Visit the element type of a collection."""
        self.visit(value_type)

    def enum(self, cls: Type[Enum]):
        """Enums carry no nested refs."""

    def literal(self, values: Sequence[Any]):
        """Literals carry no nested refs."""

    def mapping(self, cls: Type[Mapping], key_type: AnyType, value_type: AnyType):
        """Visit both the key type and the value type of a mapping."""
        self.visit(key_type)
        self.visit(value_type)

    def object(self, tp: AnyType, fields: Sequence[ObjectField]):
        """Visit every field of an object type.

        If the type has a discriminated parent, reaching it also counts as one
        occurrence of the parent's ref.
        """
        if parent := get_discriminated_parent(get_origin_or_type(tp)):
            self._incr_ref(get_type_name(parent).json_schema, parent)
        for field in fields:
            self.visit_with_conv(field.type, self._field_conversion(field))

    def primitive(self, cls: Type):
        """Primitive types carry no refs."""

    def tuple(self, types: Sequence[AnyType]):
        """Visit each element type of a tuple."""
        for element_type in types:
            self.visit(element_type)

    def _visited_union(self, results: Sequence):
        """Ignore the per-member results; refs are collected as a side effect."""

    def union(self, types: Sequence[AnyType]):
        """Visit the members of a union, twice when it has an inherited discriminator."""
        super().union(types)
        if get_inherited_discriminator(types):
            # Visit one more time discriminated union in order to ensure ref count > 1
            super().union(types)

    def visit_conversion(
        self,
        tp: AnyType,
        conversion: Optional[Any],
        dynamic: bool,
        next_conversion: Optional[AnyConversion] = None,
    ):
        """Count the refs of ``tp``'s resolved conversions, then visit ``tp``.

        :raises TypeError: when a hashable type is re-entered more than twice
            under the same conversion (a recursive type without a ref)
        """
        ref_types = []
        if not dynamic:
            for ref_tp in self.resolve_conversion(tp):
                ref_types.append(ref_tp)
                # Ref already counted: stop descending. This is also what
                # terminates recursion through types that do have a ref.
                if self._incr_ref(get_type_name(ref_tp).json_schema, ref_tp):
                    return
        if not is_hashable(tp):
            # Cannot be used as a _rec_guard key; visit without recursion guard.
            return super().visit_conversion(tp, conversion, dynamic, next_conversion)
        guard_key = (tp, self._conversion)
        # 2 because the first type encountered of the recursive cycle can have no ref
        # (see test_recursive_by_conversion_schema)
        if self._rec_guard[guard_key] > 2:
            raise TypeError(
                f"Recursive type {tp} needs a ref. "
                "You can supply one using the type_name() decorator."
            )
        self._rec_guard[guard_key] += 1
        try:
            super().visit_conversion(tp, conversion, dynamic, next_conversion)
        except Unsupported:
            # Unsupported is not re-raised here. Instead, the refs registered
            # above for this conversion are rolled back, so the unsupported
            # type contributes no refs.
            for ref_tp in ref_types:
                self.refs.pop(get_type_name(ref_tp).json_schema, ...)  # type: ignore
        finally:
            self._rec_guard[guard_key] -= 1


class DeserializationRefsExtractor(
    RefsExtractor,
    DeserializationVisitor,
    DeserializationObjectVisitor,
):
    """RefsExtractor for the deserialization direction.

    It combines the deserialization conversion visitor and the deserialization
    object visitor.
    """


class SerializationRefsExtractor(
    RefsExtractor,
    SerializationVisitor,
    SerializationObjectVisitor,
):
    """RefsExtractor for the serialization direction.

    It combines the serialization conversion visitor and the serialization
    object visitor.
    """