1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
|
from dataclasses import Field, fields, is_dataclass
from inspect import isclass
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Protocol, final, runtime_checkable
if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Set as AbstractSet
from typing_extensions import TypeAlias, TypeGuard
# Names exported by ``from <module> import *``; everything else here is internal.
__all__ = (
    "DataclassProtocol",
    "Empty",
    "EmptyType",
    "extract_dataclass_fields",
    "extract_dataclass_items",
    "is_dataclass_class",
    "is_dataclass_instance",
    "simple_asdict",
)
@final
class Empty:
    """Placeholder sentinel.

    The class object itself is the sentinel value: test for it with ``value is Empty``
    rather than instantiating it. Marked ``@final`` so it cannot be subclassed.
    """


EmptyType: "TypeAlias" = type[Empty]
"""Type alias naming the :class:`~advanced_alchemy.utils.dataclass.Empty` sentinel class itself."""
@runtime_checkable
class DataclassProtocol(Protocol):
    """Structural type matching any dataclass.

    Every class decorated with ``@dataclass`` carries a ``__dataclass_fields__``
    class attribute, so ``isinstance(obj, DataclassProtocol)`` succeeds for
    dataclass instances at runtime (``@runtime_checkable``).
    """

    __dataclass_fields__: "ClassVar[dict[str, Any]]"
def extract_dataclass_fields(
    dt: "DataclassProtocol",
    exclude_none: bool = False,
    exclude_empty: bool = False,
    include: "Optional[AbstractSet[str]]" = None,
    exclude: "Optional[AbstractSet[str]]" = None,
) -> "tuple[Field[Any], ...]":
    """Extract dataclass fields.

    Name-based filters (``include`` / ``exclude``) are applied before any field
    value is read, so fields filtered out by name are never accessed on ``dt``
    (e.g. an ``init=False`` field that was never assigned can safely be excluded).
    Each remaining field's value is read at most once.

    Args:
        dt: :class:`DataclassProtocol` instance.
        exclude_none: Whether to exclude fields whose value is ``None``.
        exclude_empty: Whether to exclude fields whose value is the :class:`Empty` sentinel.
        include: Set of field names to include. When empty or ``None``, every field is a candidate.
        exclude: Set of field names to exclude.

    Raises:
        ValueError: If a field name appears in both ``include`` and ``exclude``.

    Returns:
        A tuple of dataclass fields, in the order returned by :func:`dataclasses.fields`.
    """
    include = include or set()
    exclude = exclude or set()
    if common := (include & exclude):
        msg = f"Fields {common} are both included and excluded."
        raise ValueError(msg)

    selected: "list[Field[Any]]" = []
    for field in fields(dt):
        # Cheap name checks first: never touch attributes the caller filtered out.
        if include and field.name not in include:
            continue
        if field.name in exclude:
            continue
        if exclude_none or exclude_empty:
            # Single read per field, even when both value filters are active.
            value = getattr(dt, field.name)
            if exclude_none and value is None:
                continue
            if exclude_empty and value is Empty:
                continue
        selected.append(field)
    return tuple(selected)
def extract_dataclass_items(
    dt: "DataclassProtocol",
    exclude_none: bool = False,
    exclude_empty: bool = False,
    include: "Optional[AbstractSet[str]]" = None,
    exclude: "Optional[AbstractSet[str]]" = None,
) -> tuple[tuple[str, Any], ...]:
    """Return ``(name, value)`` pairs for the selected fields of a dataclass.

    Field selection is delegated to :func:`extract_dataclass_fields`. Unlike the
    standard library's :func:`dataclasses.asdict`, values are returned as-is
    rather than deep-copied.

    Args:
        dt: :class:`DataclassProtocol` instance.
        exclude_none: Whether to drop fields whose value is ``None``.
        exclude_empty: Whether to drop fields whose value is the ``Empty`` sentinel.
        include: Set of field names to keep.
        exclude: Set of field names to drop.

    Returns:
        A tuple of ``(field_name, value)`` pairs.
    """
    chosen = extract_dataclass_fields(
        dt,
        exclude_none=exclude_none,
        exclude_empty=exclude_empty,
        include=include,
        exclude=exclude,
    )
    pairs = [(fld.name, getattr(dt, fld.name)) for fld in chosen]
    return tuple(pairs)
def simple_asdict(
    obj: "DataclassProtocol",
    exclude_none: bool = False,
    exclude_empty: bool = False,
    convert_nested: bool = True,
    exclude: "Optional[AbstractSet[str]]" = None,
) -> "dict[str, Any]":
    """Convert a dataclass to a dictionary.

    This method has important differences to the standard library version:
    - it does not deepcopy values
    - it does not recurse into collections

    Args:
        obj: :class:`DataclassProtocol` instance.
        exclude_none: Whether to exclude None values (also applied to nested dataclasses).
        exclude_empty: Whether to exclude Empty values (also applied to nested dataclasses).
        convert_nested: Whether to recursively convert nested dataclasses.
        exclude: Set of top-level field names to exclude. It is not forwarded to
            nested dataclasses, whose field names belong to a different class.

    Returns:
        A dictionary of key/value pairs.
    """
    ret: "dict[str, Any]" = {}
    for field in extract_dataclass_fields(obj, exclude_none, exclude_empty, exclude=exclude):
        # Read each attribute exactly once so properties/descriptors are not re-evaluated.
        value = getattr(obj, field.name)
        # Check the flag first; it is cheaper than the type probe.
        if convert_nested and is_dataclass_instance(value):
            value = simple_asdict(value, exclude_none, exclude_empty, convert_nested=True)
        ret[field.name] = value
    return ret
def is_dataclass_instance(obj: Any) -> "TypeGuard[DataclassProtocol]":
    """Tell whether ``obj`` is an instance of a dataclass.

    The probe looks at ``type(obj)`` rather than ``obj``, so passing a dataclass
    *class* (whose type is its metaclass) yields ``False``.

    Args:
        obj: Any object.

    Returns:
        ``True`` when the object's type defines ``__dataclass_fields__``.
    """
    obj_type = type(obj)
    return hasattr(obj_type, "__dataclass_fields__")  # pyright: ignore[reportUnknownArgumentType]
def is_dataclass_class(annotation: Any) -> "TypeGuard[type[DataclassProtocol]]":
    """Type-guarded check that ``annotation`` is a dataclass *class*.

    Wraps :func:`dataclasses.is_dataclass`, but first rejects anything that is
    not a class, so dataclass instances return ``False``.

    Args:
        annotation: Object to test.

    Returns:
        ``True`` only for classes decorated with ``@dataclass``.
    """
    try:
        if not isclass(annotation):
            return False
        return is_dataclass(annotation)
    except TypeError:  # pragma: no cover
        return False
|