File: fields.py

package info (click to toggle)
python-apischema 0.18.3-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,636 kB
  • sloc: python: 15,281; makefile: 3; sh: 2
file content (145 lines) | stat: -rw-r--r-- 4,434 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
__all__ = ["fields_set", "is_set", "set_fields", "unset_fields", "with_fields_set"]
from dataclasses import (  # type: ignore
    _FIELD,
    _FIELD_INITVAR,
    _FIELDS,
    Field,
    is_dataclass,
)
from functools import wraps
from inspect import signature
from typing import AbstractSet, Any, Collection, Set, Type, TypeVar, cast

from apischema.objects.fields import get_field_name
from apischema.utils import PREFIX

FIELDS_SET_ATTR = f"{PREFIX}fields_set"
_ALREADY_SET = f"{PREFIX}already_set"

Cls = TypeVar("Cls", bound=Type)

_fields_set_classes: Set[type] = set()


def support_fields_set(cls: type) -> bool:
    return any(base in _fields_set_classes for base in cls.__mro__)


def with_fields_set(cls: Cls) -> Cls:
    from apischema.metadata.keys import DEFAULT_AS_SET_METADATA

    init_fields = set()
    post_init_fields = set()
    if is_dataclass(cls):
        for field in getattr(cls, _FIELDS).values():
            assert isinstance(field, Field)
            if field._field_type == _FIELD_INITVAR:  # type: ignore
                init_fields.add(field.name)
            if field._field_type == _FIELD and not field.init:  # type: ignore
                post_init_fields.add(field.name)
            if field.metadata.get(DEFAULT_AS_SET_METADATA):
                post_init_fields.add(field.name)
    params = list(signature(cls.__init__).parameters)[1:]
    old_new = cls.__new__
    old_init = cls.__init__
    old_setattr = cls.__setattr__

    def new_new(*args, **kwargs):
        if old_new is object.__new__:
            obj = object.__new__(args[0])
        else:
            obj = old_new(*args, **kwargs)
        # Initialize FIELD_SET_ATTR in order to prevent inherited class which override
        # __init__ to raise in __setattr__
        obj.__dict__[FIELDS_SET_ATTR] = set()
        return obj

    def new_init(self, *args, **kwargs):
        prev_fields_set = self.__dict__.get(FIELDS_SET_ATTR, set()).copy()
        self.__dict__[FIELDS_SET_ATTR] = set()
        try:
            old_init(self, *args, **kwargs)
        except TypeError as err:
            if str(err) == no_dataclass_init_error:
                raise RuntimeError(dataclass_before_error) from None
            else:
                raise
        arg_fields = {*params[: len(args)], *kwargs} - init_fields
        self.__dict__[FIELDS_SET_ATTR] = prev_fields_set | arg_fields | post_init_fields

    def new_setattr(self, attr, value):
        try:
            self.__dict__[FIELDS_SET_ATTR].add(attr)
        except KeyError:
            raise RuntimeError(dataclass_before_error) from None
        old_setattr(self, attr, value)  # type: ignore

    for attr, old, new in [
        ("__new__", old_new, new_new),
        ("__init__", old_init, new_init),
        ("__setattr__", old_setattr, new_setattr),
    ]:
        if hasattr(old, _ALREADY_SET):
            continue
        setattr(new, _ALREADY_SET, True)
        setattr(cls, attr, wraps(old)(new))

    _fields_set_classes.add(cls)
    return cls


no_dataclass_init_error = (
    "object.__init__() takes exactly one argument (the instance to initialize)"
)
dataclass_before_error = (
    f"{with_fields_set.__name__} must be put before dataclass decorator"
)


T = TypeVar("T")


def _field_names(fields: Collection) -> AbstractSet[str]:
    result: Set[str] = set()
    for field in fields:
        result.add(get_field_name(field))
    return result


def _fields_set(obj: Any) -> Set[str]:
    try:
        return getattr(obj, FIELDS_SET_ATTR)
    except AttributeError:
        raise TypeError(
            f"Type {obj.__class__} is not decorated" f" with {with_fields_set.__name__}"
        )


def set_fields(obj: T, *fields: Any, overwrite=False) -> T:
    if overwrite:
        _fields_set(obj).clear()
    _fields_set(obj).update(map(get_field_name, fields))
    return obj


def unset_fields(obj: T, *fields: Any) -> T:
    _fields_set(obj).difference_update(map(get_field_name, fields))
    return obj


# This could just be an alias with a specified type, but it's better handled by IDE
# like this
def fields_set(obj: Any) -> AbstractSet[str]:
    return _fields_set(obj)


class FieldIsSet:
    def __init__(self, obj: Any):
        self.fields_set = fields_set(obj)

    def __getattribute__(self, name: str) -> bool:
        return name in object.__getattribute__(self, "fields_set")


def is_set(obj: T) -> T:
    return cast(T, FieldIsSet(obj))