1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
|
# Public API of this module.
__all__ = [
    "fields_set",
    "is_set",
    "set_fields",
    "unset_fields",
    "with_fields_set",
]
from dataclasses import ( # type: ignore
_FIELD,
_FIELD_INITVAR,
_FIELDS,
Field,
is_dataclass,
)
from functools import wraps
from inspect import signature
from typing import AbstractSet, Any, Collection, Set, Type, TypeVar, cast
from apischema.objects.fields import get_field_name
from apischema.utils import PREFIX
# Name of the instance attribute (written directly into the instance __dict__)
# that holds the set of field names considered explicitly set.
FIELDS_SET_ATTR = f"{PREFIX}fields_set"
# Marker attribute put on the wrapper functions installed by with_fields_set,
# so that an already-wrapped method (e.g. inherited) is never wrapped twice.
_ALREADY_SET = f"{PREFIX}already_set"
# Type variable for class decorators: returns the very class it receives.
Cls = TypeVar("Cls", bound=Type)
# Registry of classes decorated with with_fields_set; read by support_fields_set.
_fields_set_classes: Set[type] = set()
def support_fields_set(cls: type) -> bool:
    """Tell whether *cls* or any of its bases was decorated with with_fields_set.

    :param cls: class to inspect (its ``__mro__`` is read)
    :return: True if at least one class of the MRO is registered
    """
    # A non-empty intersection between the MRO and the registry means support.
    return not _fields_set_classes.isdisjoint(cls.__mro__)
def with_fields_set(cls: Cls) -> Cls:
    """Decorate a class so that its instances track which fields were set.

    Must be placed above ``@dataclass`` in the decorator list (so it is applied
    after it); otherwise a RuntimeError is raised when instances are built or
    assigned. The class's ``__new__``, ``__init__`` and ``__setattr__`` are
    wrapped:

    * ``__new__`` seeds an empty fields set on every new instance;
    * ``__init__`` records the parameters passed positionally or by keyword
      (``InitVar`` pseudo-fields excluded), plus the fields declared with
      ``init=False`` or carrying ``DEFAULT_AS_SET_METADATA``;
    * ``__setattr__`` records every attribute successfully assigned after
      construction.

    Methods already wrapped by this decorator (e.g. inherited from a decorated
    base class) are left as they are.

    :param cls: class to decorate, normally a dataclass
    :return: ``cls`` itself, modified in place and registered in the
        fields-set registry
    """
    # NOTE(review): function-scope import, presumably to avoid an import cycle.
    from apischema.metadata.keys import DEFAULT_AS_SET_METADATA

    # InitVar pseudo-fields: accepted by __init__ but not real fields.
    init_fields: Set[str] = set()
    # Fields always considered set, whatever arguments were actually passed.
    post_init_fields: Set[str] = set()
    if is_dataclass(cls):
        for field in getattr(cls, _FIELDS).values():
            assert isinstance(field, Field)
            if field._field_type == _FIELD_INITVAR:  # type: ignore
                init_fields.add(field.name)
            if field._field_type == _FIELD and not field.init:  # type: ignore
                post_init_fields.add(field.name)
            if field.metadata.get(DEFAULT_AS_SET_METADATA):
                post_init_fields.add(field.name)
    # __init__ parameter names in positional order (``self`` dropped); used to
    # map positional arguments back to field names.
    params = list(signature(cls.__init__).parameters)[1:]
    old_new = cls.__new__
    old_init = cls.__init__
    old_setattr = cls.__setattr__

    def new_new(*args, **kwargs):
        # args[0] is the class being instantiated.
        if old_new is object.__new__:
            # object.__new__ rejects extra arguments once __new__ is
            # overridden, so only the class is forwarded to it.
            obj = object.__new__(args[0])
        else:
            obj = old_new(*args, **kwargs)
        # Seed FIELDS_SET_ATTR here so that a subclass overriding __init__
        # (and assigning attributes before/without reaching new_init) does not
        # make new_setattr raise.
        obj.__dict__[FIELDS_SET_ATTR] = set()
        return obj

    def new_init(self, *args, **kwargs):
        # Keep what was recorded before this __init__ ran (e.g. attributes
        # assigned by a subclass __init__ before calling super().__init__).
        prev_fields_set = self.__dict__.get(FIELDS_SET_ATTR, set()).copy()
        # Assignments made by the original __init__ land in a throwaway set:
        # only the arguments actually passed must count as set.
        self.__dict__[FIELDS_SET_ATTR] = set()
        try:
            old_init(self, *args, **kwargs)
        except TypeError as err:
            # object.__init__ received arguments: presumably @dataclass was
            # applied after this decorator and kept the wrapper installed here
            # instead of generating its own __init__.
            if str(err) == no_dataclass_init_error:
                raise RuntimeError(dataclass_before_error) from None
            else:
                raise
        arg_fields = {*params[: len(args)], *kwargs} - init_fields
        self.__dict__[FIELDS_SET_ATTR] = (
            prev_fields_set | arg_fields | post_init_fields
        )

    def new_setattr(self, attr, value):
        try:
            current_fields_set = self.__dict__[FIELDS_SET_ATTR]
        except KeyError:
            raise RuntimeError(dataclass_before_error) from None
        old_setattr(self, attr, value)  # type: ignore
        # Record the field only once the assignment succeeded, so a rejected
        # assignment (frozen dataclass, read-only property, ...) does not
        # leave the field marked as set.
        current_fields_set.add(attr)

    for attr, old, new in [
        ("__new__", old_new, new_new),
        ("__init__", old_init, new_init),
        ("__setattr__", old_setattr, new_setattr),
    ]:
        # Never re-wrap a method already wrapped by this decorator.
        if hasattr(old, _ALREADY_SET):
            continue
        setattr(new, _ALREADY_SET, True)
        setattr(cls, attr, wraps(old)(new))
    _fields_set_classes.add(cls)
    return cls
# Text of the TypeError raised (CPython wording) by object.__init__ when it
# receives arguments; with_fields_set's __init__ wrapper compares against it
# to detect that @dataclass did not install its own __init__.
no_dataclass_init_error = (
    "object.__init__() takes exactly one argument (the instance to initialize)"
)
# Defined after with_fields_set because it embeds the decorator's name.
dataclass_before_error = (
    f"{with_fields_set.__name__} must be put before dataclass decorator"
)
# Type of the object passed to (and handed back by) the helpers below.
T = TypeVar("T")
def _field_names(fields: Collection) -> AbstractSet[str]:
    """Resolve every item of *fields* to a field name with get_field_name.

    :param fields: field references in any form accepted by get_field_name
    :return: the set of resolved names (duplicates collapse)
    """
    return {get_field_name(field) for field in fields}
def _fields_set(obj: Any) -> Set[str]:
    """Return the live (mutable, not copied) fields set stored on *obj*.

    :param obj: instance of a class decorated with with_fields_set
    :return: the set stored under FIELDS_SET_ATTR
    :raises TypeError: if *obj* has no such attribute, i.e. its class was not
        decorated with with_fields_set
    """
    try:
        return getattr(obj, FIELDS_SET_ATTR)
    except AttributeError:
        # The missing-attribute error is an implementation detail: drop it
        # from the traceback, like the module's other re-raises do.
        raise TypeError(
            f"Type {obj.__class__} is not decorated"
            f" with {with_fields_set.__name__}"
        ) from None
def set_fields(obj: T, *fields: Any, overwrite: bool = False) -> T:
    """Mark *fields* as set on *obj*.

    :param obj: instance of a class decorated with with_fields_set
    :param fields: fields to mark, in any form accepted by get_field_name
    :param overwrite: if true, every previously set field is forgotten first
    :return: *obj* itself
    :raises TypeError: if *obj*'s class is not decorated with with_fields_set
    """
    current = _fields_set(obj)
    # Resolve every name before mutating anything, so that a failing
    # get_field_name leaves the fields set untouched (notably, not cleared).
    names = _field_names(fields)
    if overwrite:
        current.clear()
    current.update(names)
    return obj
def unset_fields(obj: T, *fields: Any) -> T:
    """Mark *fields* as not set on *obj*.

    :param obj: instance of a class decorated with with_fields_set
    :param fields: fields to unmark, in any form accepted by get_field_name
    :return: *obj* itself
    :raises TypeError: if *obj*'s class is not decorated with with_fields_set
    """
    current = _fields_set(obj)
    # All names are resolved before the set is touched, so a failing
    # get_field_name cannot leave a partial removal behind.
    current.difference_update(_field_names(fields))
    return obj
def fields_set(obj: Any) -> AbstractSet[str]:
    """Return the names of the fields set on *obj*.

    The result is the instance's own live set (not a copy), exposed as an
    AbstractSet so that type checkers discourage mutating it directly; use
    set_fields / unset_fields to change it.

    :raises TypeError: if *obj*'s class is not decorated with with_fields_set
    """
    # Written as a real function rather than an alias of _fields_set because
    # IDEs handle the narrower return annotation better this way.
    names: AbstractSet[str] = _fields_set(obj)
    return names
class FieldIsSet:
    """Attribute-style view answering "was this field set?".

    Reading any attribute ``name`` on an instance returns whether ``name``
    belongs to the fields set captured from the wrapped object.
    """

    def __init__(self, obj: Any):
        object.__setattr__(self, "fields_set", fields_set(obj))

    def __getattribute__(self, name: str) -> bool:
        # Go through object.__getattribute__ to reach the stored set;
        # a plain attribute read would re-enter this very method.
        names = object.__getattribute__(self, "fields_set")
        return name in names
def is_set(obj: T) -> T:
    """Return a view whose attributes tell whether *obj*'s fields were set.

    ``is_set(obj).x`` is True when ``x`` is in *obj*'s fields set. The view is
    typed as ``T`` (presumably so that IDEs complete *obj*'s attribute names);
    at runtime it is a FieldIsSet instance.
    """
    view = FieldIsSet(obj)
    return cast(T, view)
|