File: visitor.py

package info (click to toggle)
python-apischema 0.18.3-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,636 kB
  • sloc: python: 15,281; makefile: 3; sh: 2
file content (151 lines) | stat: -rw-r--r-- 5,264 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from dataclasses import MISSING, Field
from typing import Any, Collection, Mapping, Optional, Sequence

from apischema.aliases import Aliaser, get_class_aliaser
from apischema.conversions.conversions import AnyConversion
from apischema.dataclasses import replace
from apischema.metadata.keys import ALIAS_METADATA
from apischema.objects.fields import MISSING_DEFAULT, FieldKind, ObjectField
from apischema.types import AnyType, Undefined
from apischema.typing import get_args
from apischema.utils import get_origin_or_type, get_parameters, substitute_type_vars
from apischema.visitor import Result, Visitor


def object_field_from_field(
    field: Field, field_type: AnyType, init_var: bool
) -> ObjectField:
    """Build the ``ObjectField`` equivalent of a dataclass ``Field``.

    Args:
        field: The dataclass field to convert.
        field_type: The type to record on the resulting ``ObjectField``.
        init_var: Whether ``field`` comes from the dataclass init-vars.

    Returns:
        An ``ObjectField`` carrying the name, metadata, default and default
        factory of ``field``. It is required when ``field`` has neither a
        default nor a default factory.
    """
    # Init-var status takes precedence over ``init=False`` when picking the kind.
    if init_var:
        kind = FieldKind.WRITE_ONLY
    else:
        kind = FieldKind.NORMAL if field.init else FieldKind.READ_ONLY
    has_no_default = field.default is MISSING and field.default_factory is MISSING
    return ObjectField(
        field.name,
        field_type,
        has_no_default,
        field.metadata,
        default=field.default,
        default_factory=field.default_factory,  # type: ignore
        kind=kind,
    )


def _override_alias(field: ObjectField, aliaser: Aliaser) -> ObjectField:
    """Return ``field`` with its alias run through ``aliaser`` when allowed.

    Fields whose ``override_alias`` is falsy are returned untouched (same
    object). Otherwise a copy is made whose metadata stores
    ``aliaser(field.alias)`` under ``ALIAS_METADATA``.
    """
    if not field.override_alias:
        return field
    aliased_metadata = {**field.metadata, ALIAS_METADATA: aliaser(field.alias)}
    # NOTE(review): ``default=MISSING_DEFAULT`` is passed explicitly to
    # ``replace``; presumably ``default`` is an init-only argument of
    # ``ObjectField`` -- confirm against its definition.
    return replace(field, metadata=aliased_metadata, default=MISSING_DEFAULT)


class ObjectVisitor(Visitor[Result]):
    """Visitor turning object-like types into sequences of ``ObjectField``.

    Dataclasses, named tuples and typed dicts are each converted to a list of
    ``ObjectField`` and forwarded to :meth:`object`. Subclasses must implement
    :meth:`object`, :meth:`_field_conversion` and :meth:`_skip_field`.
    """

    # Field kind that ``dataclass`` leaves out of the visited fields.
    _field_kind_filtered: Optional[FieldKind] = None

    def _field_conversion(self, field: ObjectField) -> Optional[AnyConversion]:
        """Return the conversion to use for ``field``; subclasses implement it."""
        raise NotImplementedError

    def _skip_field(self, field: ObjectField) -> bool:
        """Tell whether ``field`` must be ignored; subclasses implement it."""
        raise NotImplementedError

    @staticmethod
    def _default_fields(cls: type) -> Optional[Sequence[ObjectField]]:
        """Return ``settings.default_object_fields(cls)``."""
        # NOTE(review): imported at call time rather than at module level --
        # presumably to avoid a circular import; confirm before moving it.
        from apischema import settings

        return settings.default_object_fields(cls)

    def _override_fields(
        self, tp: AnyType, fields: Sequence[ObjectField]
    ) -> Sequence[ObjectField]:
        """Swap ``fields`` for the default fields registered for ``tp``, if any.

        When the origin of ``tp`` is a class and :meth:`_default_fields` yields
        a result for it, that result is returned, with the type variables of
        each field type substituted by the arguments of ``tp`` when ``tp`` has
        arguments. In every other case the ``fields`` object itself is
        returned.
        """
        origin = get_origin_or_type(tp)
        if not isinstance(origin, type):
            return fields
        default_fields = self._default_fields(origin)
        if default_fields is None:
            return fields
        type_args = get_args(tp)
        if not type_args:
            return default_fields
        # Pair each type parameter of the origin with the matching argument.
        substitution = dict(zip(get_parameters(origin), type_args))
        specialized = []
        for default_field in default_fields:
            specialized.append(
                replace(
                    default_field,
                    type=substitute_type_vars(default_field.type, substitution),
                )
            )
        return specialized

    def _object(self, tp: AnyType, fields: Sequence[ObjectField]) -> Result:
        """Drop skipped fields, apply the class aliaser, then call ``object``."""
        kept = []
        for candidate in fields:
            if self._skip_field(candidate):
                continue
            kept.append(candidate)
        aliaser = get_class_aliaser(get_origin_or_type(tp))
        if aliaser:
            kept = [_override_alias(candidate, aliaser) for candidate in kept]
        return self.object(tp, kept)

    def dataclass(
        self,
        tp: AnyType,
        types: Mapping[str, AnyType],
        fields: Sequence[Field],
        init_vars: Sequence[Field],
    ) -> Result:
        """Visit a dataclass as an object.

        Regular fields and init-vars are converted with
        ``object_field_from_field``; the result follows the key order of
        ``types`` and omits fields whose kind equals ``_field_kind_filtered``.
        """
        by_name = {}
        # Init-vars are handled last, so they win over a same-named field.
        for group, is_init_var in ((fields, False), (init_vars, True)):
            for f in group:
                by_name[f.name] = object_field_from_field(
                    f, types[f.name], is_init_var
                )
        object_fields = []
        for name in types:
            if name not in by_name:
                continue
            converted = by_name[name]
            if converted.kind != self._field_kind_filtered:
                object_fields.append(converted)
        return self._object(tp, self._override_fields(tp, object_fields))

    def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> Result:
        """Process ``tp`` given its final fields; subclasses implement it."""
        raise NotImplementedError

    def named_tuple(
        self, tp: AnyType, types: Mapping[str, AnyType], defaults: Mapping[str, Any]
    ) -> Result:
        """Visit a named tuple as an object.

        A field is required when its name is absent from ``defaults``; its
        default is ``defaults.get(name)``.
        """
        object_fields = []
        for name, type_ in types.items():
            is_required = name not in defaults
            object_fields.append(
                ObjectField(name, type_, is_required, default=defaults.get(name))
            )
        return self._object(tp, self._override_fields(tp, object_fields))

    def typed_dict(
        self, tp: AnyType, types: Mapping[str, AnyType], required_keys: Collection[str]
    ) -> Result:
        """Visit a typed dict as an object.

        A field is required when its name is in ``required_keys``; every field
        gets ``Undefined`` as default.
        """
        object_fields = []
        for name, type_ in types.items():
            is_required = name in required_keys
            object_fields.append(
                ObjectField(name, type_, is_required, default=Undefined)
            )
        return self._object(tp, self._override_fields(tp, object_fields))

    def unsupported(self, tp: AnyType) -> Result:
        """Visit ``tp`` as an object if default fields exist for it.

        Otherwise the call is delegated to the parent class's ``unsupported``.
        """
        # ``_override_fields`` returns the very list it was given when it has
        # nothing to substitute, so an identity check detects "no override".
        sentinel: list = []
        overridden = self._override_fields(tp, sentinel)
        if overridden is sentinel:
            return super().unsupported(tp)
        return self._object(tp, overridden)


class DeserializationObjectVisitor(ObjectVisitor[Result]):
    """``ObjectVisitor`` variant reading the deserialization side of fields."""

    # Read-only fields are the kind left out by ``ObjectVisitor.dataclass``.
    _field_kind_filtered = FieldKind.READ_ONLY

    @staticmethod
    def _skip_field(field: ObjectField) -> bool:
        """Return the deserialization flag of ``field.skip``."""
        return field.skip.deserialization

    @staticmethod
    def _field_conversion(field: ObjectField) -> Optional[AnyConversion]:
        """Return the ``deserialization`` attribute of ``field``."""
        return field.deserialization


class SerializationObjectVisitor(ObjectVisitor[Result]):
    """``ObjectVisitor`` variant reading the serialization side of fields."""

    # Write-only fields are the kind left out by ``ObjectVisitor.dataclass``.
    _field_kind_filtered = FieldKind.WRITE_ONLY

    @staticmethod
    def _skip_field(field: ObjectField) -> bool:
        """Return the serialization flag of ``field.skip``."""
        return field.skip.serialization

    @staticmethod
    def _field_conversion(field: ObjectField) -> Optional[AnyConversion]:
        """Return the ``serialization`` attribute of ``field``."""
        return field.serialization