
|
from collections import defaultdict
from dataclasses import dataclass
from functools import wraps
from inspect import Parameter, signature
from typing import (
Any,
Callable,
Collection,
Dict,
Mapping,
MutableMapping,
NoReturn,
Optional,
Tuple,
Type,
TypeVar,
Union,
overload,
)
from apischema.cache import CacheAwareDict
from apischema.conversions.conversions import AnyConversion
from apischema.methods import method_registerer
from apischema.ordering import Ordering
from apischema.schemas import Schema
from apischema.types import AnyType, Undefined, UndefinedType
from apischema.typing import generic_mro, get_type_hints, is_type
from apischema.utils import (
get_args2,
get_origin_or_type,
get_origin_or_type2,
substitute_type_vars,
subtyping_substitution,
)
@dataclass(frozen=True)
class SerializedMethod:
    """A method or property registered through ``serialized``.

    Holds the callable together with the options it was registered with, and
    computes the type hints used to serialize its result.
    """

    # Callable invoked to produce the serialized value; when an error handler
    # is configured, ``serialized`` stores an exception-catching wrapper here.
    func: Callable
    # Name of the serialized field.
    alias: str
    conversion: Optional[AnyConversion]
    # Called as ``error_handler(error, obj, alias)`` when ``func`` raises;
    # ``None`` means exceptions are not handled.
    error_handler: Optional[Callable]
    ordering: Optional[Ordering]
    schema: Optional[Schema]

    def error_type(self) -> AnyType:
        """Return the annotated return type of ``error_handler``.

        Must only be called when ``error_handler`` is set.

        Raises:
            TypeError: if ``error_handler`` has no return annotation.
        """
        assert self.error_handler is not None
        hints = get_type_hints(self.error_handler, include_extras=True)
        if "return" not in hints:
            raise TypeError("Error handler must be typed")
        return hints["return"]

    def return_type(self, return_type: AnyType) -> AnyType:
        """Widen *return_type* with the error handler's return type.

        Returns ``Union[return_type, error_type]``. It returns *return_type*
        unchanged when there is no error handler, or when the handler is
        annotated ``NoReturn`` (it never produces a value).
        """
        if self.error_handler is not None:
            error_type = self.error_type()
            if error_type is not NoReturn:
                return Union[return_type, error_type]
        return return_type

    def types(self, owner: Optional[AnyType] = None) -> Mapping[str, AnyType]:
        """Return the type hints of ``func``, including ``"return"``.

        The ``"return"`` entry is widened by ``return_type``. If ``func`` is
        itself a type without a return annotation, that type is used as the
        return type.

        When ``get_args2(owner)`` is non-empty (presumably *owner* is a
        parameterized generic), type variables in every hint are substituted
        according to *owner*. The generic base is the annotation of
        ``func``'s first parameter; the origin of *owner* is used when that
        parameter is unannotated or ``func`` has no parameters.

        Raises:
            TypeError: if ``func`` has no return annotation and is not a type.
        """
        hints = get_type_hints(self.func, include_extras=True)
        if "return" not in hints:
            if is_type(self.func):
                # Calling a type yields an instance of it.
                hints["return"] = self.func
            else:
                raise TypeError("Function must be typed")
        hints["return"] = self.return_type(hints["return"])
        if get_args2(owner):
            # Default of ``None`` avoids a bare StopIteration for a callable
            # without parameters; it then falls back to the owner's origin.
            first_param = next(iter(signature(self.func).parameters), None)
            if first_param is not None and first_param in hints:
                base = hints[first_param]
            else:
                base = get_origin_or_type2(owner)
            substitution, _ = subtyping_substitution(base, owner)
            hints = {
                name: substitute_type_vars(tp, substitution)
                for name, tp in hints.items()
            }
        return hints
# Registry of serialized methods: owner class -> {alias: SerializedMethod}.
# The inner defaultdict lets ``register`` write ``_serialized_methods[owner]``
# for a new owner without initializing it first (assumes CacheAwareDict
# delegates item access to the wrapped mapping — TODO confirm).
_serialized_methods: MutableMapping[Type, Dict[str, SerializedMethod]]
_serialized_methods = CacheAwareDict(defaultdict(dict))

# Method-record type accepted by the generic ``_get_methods`` helper.
S = TypeVar("S", bound=SerializedMethod)
def _get_methods(
    tp: AnyType, all_methods: Mapping[Type, Mapping[str, S]]
) -> Collection[Tuple[S, Mapping[str, AnyType]]]:
    """Collect the methods registered on *tp* and its bases, with their types.

    ``generic_mro(tp)`` is walked in reverse, so a method registered later in
    that walk replaces a same-named one from an earlier base. This assumes
    MRO order, in which *tp* comes first and therefore wins. Each method's
    type hints are resolved against the base it was found on.

    Args:
        tp: type whose registered methods are collected.
        all_methods: registry mapping a class to ``{name: method}``.

    Returns:
        ``(method, types)`` pairs, one per distinct method name.
    """
    result: Dict[str, Tuple[S, Mapping[str, AnyType]]] = {}
    for base in reversed(generic_mro(tp)):
        # ``.get`` keeps plain mappings (without a default factory) from
        # raising KeyError for bases with nothing registered; for a
        # defaultdict-backed registry it behaves like indexing.
        methods = all_methods.get(get_origin_or_type(base), {})
        for name, method in methods.items():
            result[name] = (method, method.types(base))
    return result.values()
def get_serialized_methods(
    tp: AnyType,
) -> Collection[Tuple[SerializedMethod, Mapping[str, AnyType]]]:
    """Return ``(method, types)`` pairs for the serialized methods of *tp*.

    Methods registered on the bases returned by ``generic_mro(tp)`` are
    included, with type hints resolved against the base they were found on.
    """
    collected = _get_methods(tp, all_methods=_serialized_methods)
    return collected
# Values accepted by ``serialized(error_handler=...)``:
# - a callable, invoked as ``handler(error, obj, alias)`` when the method raises;
# - ``None``, which selects ``none_error_handler``;
# - ``Undefined`` (the default), which leaves the method unwrapped, so its
#   exceptions are not caught here.
ErrorHandler = Union[Callable, None, UndefinedType]
def none_error_handler(error: Exception, obj: Any, alias: str) -> None:
    """Discard the exception and serialize ``None`` instead.

    ``serialized(error_handler=None)`` uses this handler, so a method that
    raises produces ``None`` for its field. All arguments are ignored.
    """
# What ``@serialized`` decorates: a plain callable or a ``property``. The
# overloads below declare that the decorator returns the same kind of object.
MethodOrProp = TypeVar("MethodOrProp", Callable, property)
# Bare form: ``@serialized`` applied directly to a method or property.
@overload
def serialized(__method_or_property: MethodOrProp) -> MethodOrProp:
    ...
# Factory form: ``@serialized("alias")`` or ``@serialized(alias=..., ...)``
# returns a decorator that registers the method or property.
@overload
def serialized(
    alias: Optional[str] = None,
    *,
    conversion: Optional[AnyConversion] = None,
    error_handler: ErrorHandler = Undefined,
    order: Optional[Ordering] = None,
    schema: Optional[Schema] = None,
    owner: Optional[Type] = None,
) -> Callable[[MethodOrProp], MethodOrProp]:
    ...
def serialized(
    __arg=None,
    *,
    alias: Optional[str] = None,
    conversion: Optional[AnyConversion] = None,
    error_handler: ErrorHandler = Undefined,
    order: Optional[Ordering] = None,
    schema: Optional[Schema] = None,
    owner: Optional[Type] = None,
):
    """Register a method or property as an additional serialized field.

    Usable bare (``@serialized``), with a positional alias
    (``@serialized("name")``), or with keyword options.

    Args:
        alias: serialized field name. When omitted, the name supplied to
            ``register`` by ``method_registerer`` is used (presumably the
            attribute name).
        conversion, order, schema: stored on the resulting
            ``SerializedMethod``.
        error_handler: ``Undefined`` (default) installs no handler. ``None``
            uses ``none_error_handler``. Any other callable is called as
            ``handler(error, obj, alias)`` when the method raises, and its
            result is returned in place of the method's.
        owner: forwarded to ``method_registerer``.

    Raises:
        TypeError: from ``register``, if a parameter after the first one is
            neither variadic nor defaulted.
    """

    def register(func: Callable, owner: Type, alias2: str):
        # An explicit ``alias`` takes precedence over the registerer's name.
        alias2 = alias or alias2
        # Presumably the method is called with the instance only (the error
        # wrapper below does so), hence every other non-variadic parameter
        # must be optional.
        variadic_kinds = {Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD}
        extra_params = list(signature(func).parameters.values())[1:]
        if any(
            param.kind not in variadic_kinds and param.default is Parameter.empty
            for param in extra_params
        ):
            raise TypeError("Serialized method cannot have required parameter")
        # ``None`` is shorthand for "serialize None when the method raises".
        handler = none_error_handler if error_handler is None else error_handler
        if handler is Undefined:
            # No handler: the method is registered unwrapped.
            handler = None
        else:
            undecorated = func

            @wraps(undecorated)
            def func(self):
                try:
                    return undecorated(self)
                except Exception as error:
                    assert handler is not None and handler is not Undefined
                    return handler(error, self, alias2)

        assert not isinstance(handler, UndefinedType)
        _serialized_methods[owner][alias2] = SerializedMethod(
            func, alias2, conversion, handler, order, schema
        )

    # ``@serialized("name")``: the positional argument is the alias rather
    # than the decorated object.
    if isinstance(__arg, str):
        alias, __arg = __arg, None
    return method_registerer(__arg, owner, register)
|