1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
|
from collections import defaultdict
from enum import Enum
from typing import (
Any,
Collection,
Dict,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
from apischema.conversions.conversions import AnyConversion, DefaultConversion
from apischema.conversions.visitor import (
ConversionsVisitor,
DeserializationVisitor,
SerializationVisitor,
)
from apischema.discriminators import (
get_discriminated_parent,
get_inherited_discriminator,
)
from apischema.json_schema.conversions_resolver import WithConversionsResolver
from apischema.metadata.keys import DISCRIMINATOR_METADATA
from apischema.objects import ObjectField
from apischema.objects.visitor import (
DeserializationObjectVisitor,
ObjectVisitor,
SerializationObjectVisitor,
)
from apischema.type_names import TypeNameFactory, get_type_name
from apischema.types import AnyType
from apischema.utils import get_origin_or_type, is_hashable, replace_builtins
from apischema.visitor import Unsupported
try:
from apischema.typing import Annotated, get_origin, is_union
except ImportError:
Annotated = ... # type: ignore
# Registry of JSON schema refs: ref name -> (type registered under that name,
# number of times the ref has been counted so far). Filled by
# RefsExtractor._incr_ref.
Refs = Dict[str, Tuple[AnyType, int]]
class Recursive(Exception):
    """Exception type exposed by this module.

    It is not raised anywhere in this module's own code; presumably callers
    raise/catch it around recursive-type handling (TODO confirm).
    """
# Generic type variable. NOTE(review): not referenced by the code shown in this
# module — possibly unused or imported elsewhere; confirm before removing.
T = TypeVar("T")
class RefsExtractor(ConversionsVisitor, ObjectVisitor, WithConversionsResolver):
    """Walk a type graph and count how many times each JSON schema ref is reached.

    Counts are accumulated in ``refs`` (ref name -> (type, count)). When a ref
    has already been counted, ``_incr_ref`` reports it, and the visitor stops
    descending into that type because its content has already been walked.
    Types without a ref are protected against unbounded recursion by
    ``visit_conversion``.
    """

    def __init__(self, default_conversion: DefaultConversion, refs: Refs):
        super().__init__(default_conversion)
        self.refs = refs
        # Number of nested visits currently in progress for each
        # (type, conversion) pair; incremented before and decremented after
        # each guarded visit in visit_conversion.
        self._rec_guard: Dict[
            Tuple[AnyType, Optional[AnyConversion]], int
        ] = defaultdict(int)

    def _incr_ref(self, ref: Optional[str], tp: AnyType) -> bool:
        """Increment the counter of ``ref`` and register ``tp`` under it.

        Returns True if ``ref`` had already been counted before this call.
        Returns False if ``ref`` is None or is seen for the first time.

        Raises ValueError if ``ref`` is already registered for a different
        type (compared after ``replace_builtins``); in that case the counter
        is left unchanged.
        """
        if ref is None:
            return False
        ref_cls, count = self.refs.get(ref, (tp, 0))
        if replace_builtins(ref_cls) != replace_builtins(tp):
            # ref_cls is the stored type here: an absent ref defaults to tp,
            # which can never mismatch.
            raise ValueError(
                f"Types {tp} and {ref_cls} share same reference '{ref}'"
            )
        self.refs[ref] = (ref_cls, count + 1)
        return count > 0

    def annotated(self, tp: AnyType, annotations: Sequence[Any]):
        """Count refs carried by ``Annotated`` metadata, then visit ``tp``.

        Annotations are scanned from last to first. For each
        ``TypeNameFactory`` that yields a string JSON schema ref, the ref is
        registered for ``Annotated[tp, <annotations up to and including that
        one>]``. If the ref was already counted, the visit stops there.
        Otherwise the visit is delegated to the parent visitor.
        """
        for i, annotation in enumerate(reversed(annotations)):
            if isinstance(annotation, TypeNameFactory):
                ref = annotation.to_type_name(tp).json_schema
                if not isinstance(ref, str):
                    continue
                # Keep the annotations up to and including the current one, so
                # the registered type is the Annotated form carrying this ref.
                ref_annotations = annotations[: len(annotations) - i]
                annotated_tp = Annotated[(tp, *ref_annotations)]
                if self._incr_ref(ref, annotated_tp):
                    return
            if (
                isinstance(annotation, Mapping)
                and DISCRIMINATOR_METADATA in annotation
                and is_union(get_origin(tp))
            ):
                # Visit the discriminated union one more time to ensure its
                # ref count is > 1.
                self.visit(tp)
        return super().annotated(tp, annotations)

    # Leaf types carry no nested types, so there is nothing to count.
    def any(self):
        pass

    def collection(self, cls: Type[Collection], value_type: AnyType):
        self.visit(value_type)

    def enum(self, cls: Type[Enum]):
        pass

    def literal(self, values: Sequence[Any]):
        pass

    def mapping(self, cls: Type[Mapping], key_type: AnyType, value_type: AnyType):
        self.visit(key_type)
        self.visit(value_type)

    def object(self, tp: AnyType, fields: Sequence[ObjectField]):
        # A type with a discriminated parent also counts one reference to
        # that parent's ref.
        if parent := get_discriminated_parent(get_origin_or_type(tp)):
            self._incr_ref(get_type_name(parent).json_schema, parent)
        for field in fields:
            self.visit_with_conv(field.type, self._field_conversion(field))

    def primitive(self, cls: Type):
        pass

    def tuple(self, types: Sequence[AnyType]):
        for elt_type in types:
            self.visit(elt_type)

    def _visited_union(self, results: Sequence):
        # Nothing to combine: member visits are performed only for their
        # side effect on the ref counts.
        pass

    def union(self, types: Sequence[AnyType]):
        super().union(types)
        if get_inherited_discriminator(types):
            # Visit the discriminated union one more time to ensure its ref
            # count is > 1.
            super().union(types)

    def visit_conversion(
        self,
        tp: AnyType,
        conversion: Optional[Any],
        dynamic: bool,
        next_conversion: Optional[AnyConversion] = None,
    ):
        """Count the refs of the types ``tp`` resolves to, then visit ``tp``.

        When ``dynamic`` is false, each type from ``resolve_conversion(tp)``
        has its ref counted. As soon as one was already counted, the visit
        stops. Otherwise ``tp`` is visited under a recursion guard, which
        applies only to hashable ``tp``.

        Raises TypeError when a hashable ``tp`` is re-entered more times than
        the guard allows, meaning the recursive type needs a ref.
        """
        ref_types = []
        if not dynamic:
            for ref_tp in self.resolve_conversion(tp):
                ref_types.append(ref_tp)
                if self._incr_ref(get_type_name(ref_tp).json_schema, ref_tp):
                    return
        # Unhashable types cannot key the recursion guard; visit unguarded.
        if not is_hashable(tp):
            return super().visit_conversion(tp, conversion, dynamic, next_conversion)
        # Computed once so that the decrement in `finally` targets exactly the
        # key that was incremented, even if self._conversion changes meanwhile.
        guard_key = (tp, self._conversion)
        # 2 because the first type encountered of the recursive cycle can have no ref
        # (see test_recursive_by_conversion_schema)
        if self._rec_guard[guard_key] > 2:
            raise TypeError(
                f"Recursive type {tp} needs a ref. "
                "You can supply one using the type_name() decorator."
            )
        self._rec_guard[guard_key] += 1
        try:
            super().visit_conversion(tp, conversion, dynamic, next_conversion)
        except Unsupported:
            # Roll back the refs registered above. Each of them was seen for
            # the first time in this call (otherwise we returned early), so
            # the whole entry is removed.
            # NOTE(review): Unsupported is swallowed, not re-raised — this
            # looks deliberate (best-effort); confirm.
            for ref_tp in ref_types:
                self.refs.pop(get_type_name(ref_tp).json_schema, ...)  # type: ignore
        finally:
            self._rec_guard[guard_key] -= 1
class DeserializationRefsExtractor(
    RefsExtractor, DeserializationVisitor, DeserializationObjectVisitor
):
    """``RefsExtractor`` combined with the deserialization visitor mixins."""
class SerializationRefsExtractor(
    RefsExtractor, SerializationVisitor, SerializationObjectVisitor
):
    """``RefsExtractor`` combined with the serialization visitor mixins."""
|