File: getters.py

package info (click to toggle)
python-apischema 0.18.3-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,636 kB
  • sloc: python: 15,281; makefile: 3; sh: 2
file content (151 lines) | stat: -rw-r--r-- 4,137 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import inspect
from typing import (
    Any,
    Callable,
    Mapping,
    Optional,
    Sequence,
    Type,
    TypeVar,
    Union,
    cast,
    overload,
)

from apischema.cache import cache
from apischema.metadata import properties
from apischema.objects.fields import ObjectField
from apischema.objects.visitor import ObjectVisitor
from apischema.types import AnyType
from apischema.typing import _GenericAlias, get_type_hints
from apischema.utils import empty_dict
from apischema.visitor import Unsupported


@cache
def object_fields(
    tp: AnyType,
    deserialization: bool = False,
    serialization: bool = False,
    default: Optional[
        Callable[[type], Optional[Sequence[ObjectField]]]
    ] = ObjectVisitor._default_fields,
) -> Mapping[str, ObjectField]:
    """Return the fields of ``tp`` as a mapping from field name to ObjectField.

    The function is wrapped by ``apischema.cache.cache`` (presumably memoizing
    on its arguments — confirm against that helper).

    :param tp: the type whose fields are collected by an ``ObjectVisitor``.
    :param deserialization: flag consulted by the field-skip check below.
    :param serialization: flag consulted by the field-skip check below.
    :param default: callable installed as the visitor's ``_default_fields``
        hook; when ``None``, the hook returns ``None``.
    :raises TypeError: if visiting ``tp`` raises ``Unsupported`` or
        ``NotImplementedError`` (the original error is chained as the cause).
    """

    class GetFields(ObjectVisitor[Sequence[ObjectField]]):
        def _skip_field(self, field: ObjectField) -> bool:
            # NOTE(review): this returns True for fields skipped in
            # *deserialization* when ``serialization`` is requested, and for
            # fields skipped in *serialization* when ``deserialization`` is
            # requested. That reads as inverted relative to the flag names —
            # confirm the intended semantics against callers before changing.
            return (field.skip.deserialization and serialization) or (
                field.skip.serialization and deserialization
            )

        @staticmethod
        def _default_fields(cls: type) -> Optional[Sequence[ObjectField]]:
            # Delegate to the caller-supplied ``default`` captured by closure.
            return None if default is None else default(cls)

        def object(
            self, cls: Type, fields: Sequence[ObjectField]
        ) -> Sequence[ObjectField]:
            # For object types, the visit result is just the field sequence.
            return fields

    try:
        return {f.name: f for f in GetFields().visit(tp)}
    except (Unsupported, NotImplementedError) as err:
        # Chain explicitly so the traceback shows which visitor step failed.
        raise TypeError(f"{tp} doesn't have fields") from err


def object_fields2(obj: Any) -> Mapping[str, ObjectField]:
    """Return ``object_fields`` for ``obj``, or for its class when it is an instance.

    Classes and generic aliases are passed through unchanged; any other value
    is replaced by its ``__class__`` before the lookup.
    """
    if isinstance(obj, (type, _GenericAlias)):
        return object_fields(obj)
    # Not a type: inspect the class of the given instance instead.
    return object_fields(obj.__class__)


# Type variable for the ``get_field`` / ``get_alias`` overloads: the returned
# proxy is cast to ``T`` so it is typed like the inspected class/instance.
T = TypeVar("T")


class FieldGetter:
    """Attribute-access proxy mapping field names to their ``ObjectField``.

    ``getter.some_field`` returns the ``ObjectField`` named ``some_field``.
    Because ``__getattribute__`` is overridden, every explicit attribute
    lookup on an instance (including ``fields`` itself) goes through the
    field mapping; names that are not fields raise ``AttributeError``.
    """

    def __init__(self, obj: Any):
        # ``obj`` may be a type/generic alias or an instance (its class is used).
        self.fields = object_fields2(obj)

    def __getattribute__(self, name: str) -> ObjectField:
        # Go through ``object.__getattribute__`` to reach the stored mapping
        # without recursing into this override.
        try:
            return object.__getattribute__(self, "fields")[name]
        except KeyError:
            # Suppress the internal KeyError so callers only see the
            # AttributeError that describes the missing field.
            raise AttributeError(name) from None


@overload
def get_field(obj: Type[T]) -> T: ...


@overload
def get_field(obj: T) -> T: ...


# Overload because of Mypy issue
# https://github.com/python/mypy/issues/9003#issuecomment-667418520
def get_field(obj: Union[Type[T], T]) -> T:
    """Return a proxy whose attributes are the ObjectFields of ``obj``.

    ``obj`` may be a class (or generic alias) or an instance. The proxy is
    cast to ``T`` so static type checkers accept ``get_field(Cls).attr``.
    """
    proxy = FieldGetter(obj)
    return cast(T, proxy)


class AliasedStr(str):
    """``str`` subclass wrapping a field alias returned by ``AliasGetter``.

    Presumably used so aliases can be distinguished from plain strings
    elsewhere — confirm against other users of this type.
    """


class AliasGetter:
    """Attribute-access proxy mapping field names to their aliases.

    ``getter.some_field`` returns the alias of the field named
    ``some_field``, wrapped in ``AliasedStr``. Every explicit attribute
    lookup on an instance goes through the field mapping; names that are
    not fields raise ``AttributeError``.
    """

    def __init__(self, obj: Any):
        # ``obj`` may be a type/generic alias or an instance (its class is used).
        self.fields = object_fields2(obj)

    def __getattribute__(self, name: str) -> str:
        # Only the mapping lookup belongs in the ``try``: a missing key means
        # a missing attribute, anything else should surface unchanged.
        try:
            field = object.__getattribute__(self, "fields")[name]
        except KeyError:
            # Suppress the internal KeyError; callers see only AttributeError.
            raise AttributeError(name) from None
        return AliasedStr(field.alias)


@overload
def get_alias(obj: Type[T]) -> T: ...


@overload
def get_alias(obj: T) -> T: ...


# Overload because of Mypy issue
# https://github.com/python/mypy/issues/9003#issuecomment-667418520
def get_alias(obj: Union[Type[T], T]) -> T:
    """Return a proxy whose attributes are the aliases of ``obj``'s fields.

    ``obj`` may be a class (or generic alias) or an instance. The proxy is
    cast to ``T`` so static type checkers accept ``get_alias(Cls).attr``.
    """
    proxy = AliasGetter(obj)
    return cast(T, proxy)


def parameters_as_fields(
    func: Callable, parameters_metadata: Optional[Mapping[str, Mapping]] = None
) -> Sequence[ObjectField]:
    """Build ObjectFields from the parameters of ``func``.

    Positional-or-keyword and keyword-only parameters become fields typed by
    their annotation (``Any`` when unannotated); a field is required exactly
    when its parameter has no default. A ``**kwargs`` parameter becomes a
    non-required ``Mapping[str, <annotation>]`` field whose metadata is
    ``properties`` merged with any given metadata, defaulting to ``dict()``.
    ``*args`` parameters are ignored.

    :param func: the callable whose signature is inspected.
    :param parameters_metadata: optional per-parameter metadata keyed by
        parameter name.
    :raises TypeError: if ``func`` has a positional-only parameter.
    """
    metadata_by_name = parameters_metadata or {}
    hints = get_type_hints(func, include_extras=True)
    Parameter = inspect.Parameter
    named_kinds = (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
    result = []
    for name, parameter in inspect.signature(func).parameters.items():
        kind = parameter.kind
        if kind is Parameter.POSITIONAL_ONLY:
            raise TypeError("Positional only parameters are not supported")
        annotation = hints.get(name, Any)
        if kind in named_kinds:
            result.append(
                ObjectField(
                    name,
                    annotation,
                    parameter.default is Parameter.empty,
                    metadata_by_name.get(name, empty_dict),
                    default=parameter.default,
                )
            )
        elif kind == Parameter.VAR_KEYWORD:
            # ``**kwargs`` collects extra keys: model it as a string-keyed
            # mapping of the annotated value type.
            result.append(
                ObjectField(
                    name,
                    Mapping[str, annotation],  # type: ignore
                    False,
                    properties | metadata_by_name.get(name, empty_dict),
                    default_factory=dict,
                )
            )
    return result