1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
|
from collections import defaultdict
from dataclasses import dataclass
from functools import wraps
from inspect import Parameter, signature
from typing import (
Any,
Callable,
Collection,
Dict,
Mapping,
MutableMapping,
NoReturn,
Optional,
Tuple,
Type,
TypeVar,
Union,
overload,
)
from apischema.cache import CacheAwareDict
from apischema.conversions.conversions import AnyConversion
from apischema.methods import method_registerer
from apischema.ordering import Ordering
from apischema.schemas import Schema
from apischema.types import AnyType, Undefined, UndefinedType
from apischema.typing import generic_mro, get_type_hints, is_type
from apischema.utils import (
get_args2,
get_origin_or_type,
get_origin_or_type2,
substitute_type_vars,
subtyping_substitution,
)
@dataclass(frozen=True)
class SerializedMethod:
    """Immutable record describing one registered serialized method.

    Besides holding the registration data, it knows how to compute the type
    hints of the underlying callable, taking the optional error handler and
    a (possibly generic) owner type into account.
    """

    # Callable invoked to produce the serialized value.
    func: Callable
    # Name under which the method is registered.
    alias: str
    # Optional conversion attached at registration.
    conversion: Optional[AnyConversion]
    # Callable invoked on error, or None when no handler is configured.
    error_handler: Optional[Callable]
    # Optional ordering attached at registration.
    ordering: Optional[Ordering]
    # Optional schema attached at registration.
    schema: Optional[Schema]

    def error_type(self) -> AnyType:
        """Return the annotated return type of the error handler.

        Must only be called when an error handler is set.

        Raises:
            TypeError: if the error handler has no return annotation.
        """
        assert self.error_handler is not None
        handler_hints = get_type_hints(self.error_handler, include_extras=True)
        if "return" in handler_hints:
            return handler_hints["return"]
        raise TypeError("Error handler must be typed")

    def return_type(self, return_type: AnyType) -> AnyType:
        """Combine *return_type* with what the error handler may return.

        Without an error handler, or when the handler is annotated as
        ``NoReturn``, *return_type* is returned unchanged; otherwise the
        result is the union of both types.
        """
        if self.error_handler is None:
            return return_type
        handler_result = self.error_type()
        if handler_result is NoReturn:
            return return_type
        return Union[return_type, handler_result]

    def types(self, owner: AnyType = None) -> Mapping[str, AnyType]:
        """Return the type hints of ``func``, including a ``"return"`` entry.

        When ``func`` has no return annotation and ``is_type(self.func)``
        holds, the callable itself is used as the return type. When *owner*
        carries type arguments, the hints are rewritten with the substitution
        computed between the first parameter's type and *owner*.

        Raises:
            TypeError: if the return type cannot be determined.
        """
        hints = get_type_hints(self.func, include_extras=True)
        if "return" not in hints:
            if not is_type(self.func):
                raise TypeError("Function must be typed")
            hints["return"] = self.func
        hints["return"] = self.return_type(hints["return"])
        if not get_args2(owner):
            return hints
        # The first parameter is the one matched against the owner; when it
        # is not annotated, fall back to `get_origin_or_type2(owner)`.
        receiver_name = next(iter(signature(self.func).parameters))
        receiver_type = hints.get(receiver_name, get_origin_or_type2(owner))
        substitution, _ = subtyping_substitution(receiver_type, owner)
        return {
            name: substitute_type_vars(hint, substitution)
            for name, hint in hints.items()
        }
# Registry of serialized methods: owner type -> {alias -> SerializedMethod}.
# Filled by the `register` callback of `serialized`; missing owners default to
# an empty dict. The mapping is wrapped in `CacheAwareDict` (apischema.cache).
_serialized_methods: MutableMapping[Type, Dict[str, SerializedMethod]] = CacheAwareDict(
    defaultdict(dict)
)
# Type variable bound to SerializedMethod, so `_get_methods` preserves the
# concrete method type of the registry it is given.
S = TypeVar("S", bound=SerializedMethod)
def _get_methods(
    tp: AnyType, all_methods: Mapping[Type, Mapping[str, S]]
) -> Collection[Tuple[S, Mapping[str, AnyType]]]:
    """Collect the methods registered for *tp* and the classes in its MRO.

    The entries of ``generic_mro(tp)`` are visited in reverse order and
    *all_methods* is looked up for each of them; a method found under a name
    already collected replaces the earlier one. Each method is paired with
    its type hints resolved against the MRO entry it was found on.
    """
    collected = {}
    for base in reversed(generic_mro(tp)):
        registered = all_methods[get_origin_or_type(base)]
        collected.update(
            (name, (method, method.types(base)))
            for name, method in registered.items()
        )
    return collected.values()
def get_serialized_methods(
    tp: AnyType,
) -> Collection[Tuple[SerializedMethod, Mapping[str, AnyType]]]:
    """Return the serialized methods of *tp*, each paired with its type hints.

    Looks *tp* up in the module-level serialized-method registry.
    """
    registry = _serialized_methods
    return _get_methods(tp, registry)
# Accepted values for the `error_handler` argument of `serialized`:
# a callable, None (replaced by `none_error_handler`), or Undefined
# (the default, meaning the method is registered without a handler).
ErrorHandler = Union[Callable, None, UndefinedType]
def none_error_handler(error: Exception, obj: Any, alias: str) -> None:
    """Error handler that ignores its arguments and always yields ``None``.

    Used by ``serialized`` when ``error_handler=None`` is requested.
    """
    return None
# Type variable constrained to callables and properties; lets the `serialized`
# overloads declare that the decorated object is returned with its own type.
MethodOrProp = TypeVar("MethodOrProp", Callable, property)
# Typing-only signature: `serialized` applied directly to a method or
# property, which is returned with the same type.
@overload
def serialized(__method_or_property: MethodOrProp) -> MethodOrProp:
    ...


# Typing-only signature: `serialized` called with options, returning a
# decorator for a method or property.
@overload
def serialized(
    alias: Optional[str] = None,
    *,
    conversion: Optional[AnyConversion] = None,
    error_handler: ErrorHandler = Undefined,
    order: Optional[Ordering] = None,
    schema: Optional[Schema] = None,
    owner: Optional[Type] = None,
) -> Callable[[MethodOrProp], MethodOrProp]:
    ...
def serialized(
    __arg=None,
    *,
    alias: Optional[str] = None,
    conversion: Optional[AnyConversion] = None,
    error_handler: ErrorHandler = Undefined,
    order: Optional[Ordering] = None,
    schema: Optional[Schema] = None,
    owner: Optional[Type] = None,
):
    """Register a callable as a serialized method of its owner type.

    The first positional argument is either the object to register or, when
    it is a string, the alias to register under. Registration itself is
    delegated to ``method_registerer``, which receives the ``register``
    callback defined below.

    Args:
        alias: name to register under; when falsy, the name supplied to the
            callback is used instead.
        conversion: stored on the resulting ``SerializedMethod``.
        error_handler: ``Undefined`` (default) registers the callable as is;
            ``None`` wraps it so any ``Exception`` yields ``None``; a callable
            wraps it so any ``Exception`` is passed to that callable as
            ``(error, obj, alias)`` and its result is returned.
        order: stored as the ``ordering`` of the ``SerializedMethod``.
        schema: stored on the resulting ``SerializedMethod``.
        owner: forwarded to ``method_registerer``.

    Raises:
        TypeError: at registration, if the callable has a required parameter
            other than its first one.
    """

    def register(func: Callable, owner: Type, alias2: str):
        # An explicit alias takes precedence over the provided name.
        alias2 = alias or alias2

        # Every parameter after the first must be optional: either variadic
        # or carrying a default value.
        variadic_kinds = {Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD}
        for param in list(signature(func).parameters.values())[1:]:
            if param.kind in variadic_kinds:
                continue
            if param.default is Parameter.empty:
                raise TypeError("Serialized method cannot have required parameter")

        if error_handler is Undefined:
            # No error handling requested: register the callable untouched.
            error_handler2 = None
        else:
            error_handler2 = (
                none_error_handler if error_handler is None else error_handler
            )
            wrapped = func

            @wraps(wrapped)
            def func(self):
                try:
                    return wrapped(self)
                except Exception as error:
                    assert (
                        error_handler2 is not None and error_handler2 is not Undefined
                    )
                    return error_handler2(error, self, alias2)

        assert not isinstance(error_handler2, UndefinedType)
        _serialized_methods[owner][alias2] = SerializedMethod(
            func=func,
            alias=alias2,
            conversion=conversion,
            error_handler=error_handler2,
            ordering=order,
            schema=schema,
        )

    # A string in first position is the alias, not the object to register.
    if isinstance(__arg, str):
        alias, __arg = __arg, None
    return method_registerer(__arg, owner, register)
|