File: conversions.py

package info (click to toggle)
python-apischema 0.18.3-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,636 kB
  • sloc: python: 15,281; makefile: 3; sh: 2
file content (176 lines) | stat: -rw-r--r-- 5,880 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import inspect
from dataclasses import Field, replace
from types import new_class
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    Iterable,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
)

from apischema.methods import is_method, method_wrapper
from apischema.objects.fields import MISSING_DEFAULT, ObjectField, set_object_fields
from apischema.objects.getters import object_fields, parameters_as_fields
from apischema.type_names import type_name
from apischema.typing import get_type_hints
from apischema.utils import (
    empty_dict,
    substitute_type_vars,
    subtyping_substitution,
    to_pascal_case,
    with_parameters,
)

# Module-level type variable used in the signatures of the conversion
# builders below: the return type of ``object_deserialization``'s ``func`` and
# the class parameter of ``object_serialization`` (``cls: Type[T]``).
T = TypeVar("T")


def object_deserialization(
    func: Callable[..., T],
    *input_class_modifiers: Callable[[type], Any],
    parameters_metadata: Optional[Mapping[str, Mapping]] = None,
) -> Any:
    """Build a deserialization conversion from the parameters of ``func``.

    A synthetic input class is created whose object fields are the ones
    derived from ``func``'s parameters by ``parameters_as_fields``. The
    returned ``wrapper`` takes an instance of that input class and calls
    ``func`` with the keyword arguments it holds.

    :param func: callable describing the input object; it must carry a return
        annotation.
    :param input_class_modifiers: callables applied, in order, to the
        generated input class before its object fields are set.
    :param parameters_metadata: optional per-parameter metadata forwarded to
        ``parameters_as_fields``.
    :return: a one-argument function whose ``input`` annotation is the
        generated class (via ``with_parameters``) and whose ``return``
        annotation is ``func``'s return type.
    :raises TypeError: if ``func`` has no return annotation.
    """
    fields = parameters_as_fields(func, parameters_metadata)
    types = get_type_hints(func, include_extras=True)
    if "return" not in types:
        raise TypeError("Object deserialization must be typed")
    return_type = types["return"]
    bases = ()
    if getattr(return_type, "__parameters__", ()):
        # Generic return type: make the input class generic over the same
        # type variables instead of giving it a fixed type name.
        bases = (Generic[return_type.__parameters__],)  # type: ignore
    elif func.__name__ != "<lambda>":
        # Named, non-generic function: prepend a type name derived from it.
        input_class_modifiers = (
            type_name(to_pascal_case(func.__name__)),
            *input_class_modifiers,
        )

    def __init__(self, **kwargs):
        # The input class only stores the deserialized keyword arguments.
        self.kwargs = kwargs

    input_cls = new_class(
        to_pascal_case(func.__name__),
        bases,
        exec_body=lambda ns: ns.update({"__init__": __init__}),
    )
    for modifier in input_class_modifiers:
        modifier(input_cls)
    set_object_fields(input_cls, fields)
    # Name of the field flagged ``additional_properties`` (single scan).
    kwargs_param = next((f.name for f in fields if f.additional_properties), None)
    if kwargs_param is not None:

        def wrapper(input):
            kwargs = input.kwargs.copy()
            # Flatten the additional properties into the call. The custom
            # ``__init__`` above does not apply field defaults, so tolerate the
            # key being absent instead of raising KeyError.
            kwargs.update(kwargs.pop(kwargs_param, {}))
            return func(**kwargs)

    else:

        def wrapper(input):
            return func(**input.kwargs)

    wrapper.__annotations__["input"] = with_parameters(input_cls)
    wrapper.__annotations__["return"] = return_type
    return wrapper


def _fields_and_init(
    cls: type, fields_and_methods: Union[Iterable[Any], Callable[[], Iterable[Any]]]
) -> Tuple[Sequence[ObjectField], Callable[[Any, Any], None]]:
    """Resolve serialization members of ``cls`` into output fields and an init.

    ``fields_and_methods`` (or the iterable it returns, when it is callable)
    may contain:

    - ``...``: include every field returned by
      ``object_fields(cls, serialization=True)``;
    - a field name, a dataclasses ``Field`` or an ``ObjectField``;
    - a method (per ``is_method``), which is wrapped with ``method_wrapper``;
    - any other callable: its return annotation, with type variables
      substituted against ``cls``, becomes the field type, and its result on
      the object becomes the field value;
    - a ``(member, metadata)`` tuple where ``metadata`` is a ``Mapping``.

    Later members replace earlier ones with the same name.

    :return: the output fields, and an ``__init__(self, obj)`` that copies the
        field attributes of ``obj`` and the results of the methods applied to
        ``obj`` onto ``self``.
    :raises TypeError: on non-mapping metadata, an unsupported member, or a
        callable member that takes no parameter.
    """
    fields = object_fields(cls, serialization=True)
    output_fields: Dict[str, ObjectField] = {}
    methods = []
    if callable(fields_and_methods):
        fields_and_methods = fields_and_methods()
    for elt in fields_and_methods:
        if elt is ...:
            output_fields.update(fields)
            continue
        if isinstance(elt, tuple):
            elt, metadata = elt
        else:
            metadata = empty_dict
        if not isinstance(metadata, Mapping):
            raise TypeError(f"Invalid metadata {metadata}")
        if isinstance(elt, Field):
            elt = elt.name
        if isinstance(elt, str) and elt in fields:
            elt = fields[elt]
        if is_method(elt):
            elt = method_wrapper(elt)
        if isinstance(elt, ObjectField):
            if metadata:
                # Overriding metadata also resets the field's default.
                output_fields[elt.name] = replace(
                    elt, metadata={**elt.metadata, **metadata}, default=MISSING_DEFAULT
                )
            else:
                output_fields[elt.name] = elt
        elif callable(elt):
            types = get_type_hints(elt)
            first_param = next(iter(inspect.signature(elt).parameters), None)
            if first_param is None:
                # The callable is later invoked as ``method(obj)``; a bare
                # ``next`` would leak a confusing StopIteration here.
                raise TypeError(
                    f"Invalid serialization member {elt} for class {cls}:"
                    " it must take the serialized object as first parameter"
                )
            # Bind the first parameter's type to ``cls`` to resolve type vars.
            substitution, _ = subtyping_substitution(types.get(first_param, cls), cls)
            ret = substitute_type_vars(types.get("return", Any), substitution)
            output_fields[elt.__name__] = ObjectField(
                elt.__name__, ret, metadata=metadata
            )
            methods.append((elt, output_fields[elt.__name__]))
        else:
            raise TypeError(f"Invalid serialization member {elt} for class {cls}")

    # A method is serialized only if no later member replaced its field.
    serialized_methods = [m for m, f in methods if output_fields[f.name] is f]
    method_names = {m.__name__ for m in serialized_methods}
    # Keep declaration order (a set difference would make it hash-dependent).
    serialized_fields = [name for name in output_fields if name not in method_names]

    def __init__(self, obj):
        for field in serialized_fields:
            setattr(self, field, getattr(obj, field))
        for method in serialized_methods:
            setattr(self, method.__name__, method(obj))

    return tuple(output_fields.values()), __init__


def object_serialization(
    cls: Type[T],
    fields_and_methods: Union[Iterable[Any], Callable[[], Iterable[Any]]],
    *output_class_modifiers: Callable[[type], Any],
) -> Callable[[T], Any]:
    """Build a serialization conversion from ``cls`` to a synthetic output class.

    The output class, named ``f"{cls.__name__}Serialization"``, is initialized
    from an instance of ``cls``. Its object fields and initializer come from
    ``_fields_and_init(cls, fields_and_methods)``, and are resolved lazily.

    :param cls: the class to serialize; if it has type parameters, the output
        class is made generic over the same parameters.
    :param fields_and_methods: members to serialize (see ``_fields_and_init``),
        or a callable returning them. A named (non-lambda) callable also gives
        the output class a type name when ``cls`` is not generic.
    :param output_class_modifiers: callables applied, in order, to the output
        class before its object fields are set.
    :return: the generated output class, usable as a ``cls -> output`` converter.
    """
    generic, bases = cls, ()
    if getattr(cls, "__parameters__", ()):
        generic = cls[cls.__parameters__]  # type: ignore
        # ``new_class`` expects a tuple of bases; passing the bare alias would
        # make ``resolve_bases`` iterate over ``Generic[...]`` itself.
        bases = (Generic[cls.__parameters__],)  # type: ignore
    elif callable(fields_and_methods) and (
        getattr(fields_and_methods, "__name__", "<lambda>") != "<lambda>"
    ):
        # Callables without ``__name__`` (e.g. partials) are treated like lambdas.
        output_class_modifiers = (
            type_name(to_pascal_case(fields_and_methods.__name__)),
            *output_class_modifiers,
        )

    def __init__(self, obj):
        # Member resolution is deferred to the first instantiation; the
        # resolved initializer then replaces this one on the class, so the
        # resolution runs only once for instantiation.
        _, new_init = _fields_and_init(cls, fields_and_methods)
        new_init.__annotations__ = {"obj": generic}
        output_cls.__init__ = new_init  # type: ignore
        new_init(self, obj)

    __init__.__annotations__ = {"obj": generic}
    output_cls = new_class(
        f"{cls.__name__}Serialization",
        bases,
        exec_body=lambda ns: ns.update({"__init__": __init__}),
    )
    for modifier in output_class_modifiers:
        modifier(output_cls)
    # Object fields are also provided lazily, through a callable.
    set_object_fields(output_cls, lambda: _fields_and_init(cls, fields_and_methods)[0])

    return output_cls