1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
|
"""Implement a tiny subset of dataclasses_json for config."""
from collections.abc import Mapping, Sequence
from dataclasses import asdict, fields, is_dataclass
from typing import Any, Dict, Type
class DataClassJsonMixin:
    """Mixin adding ``from_dict`` / ``to_dict`` to a dataclass.

    Nested dataclass fields are expected to inherit from this mixin as well,
    so they can be decoded recursively.
    """

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> Any:
        """Build an instance of ``cls`` from a plain dict, decoding nested types.

        * Keys in *data* that are not init-able fields of ``cls`` are ignored.
        * Each known value is decoded against its field's annotated type
          (nested mixin dataclasses, ``Optional[...]``, lists, dicts).
        * A field absent from *data* keeps its declared ``default`` /
          ``default_factory``; if it has neither and its type is Optional,
          it is set to ``None``.
        * Any other missing field makes ``cls(...)`` raise ``TypeError``.
        """
        # Local import keeps the module-level import block untouched.
        from dataclasses import MISSING

        # Only fields accepted by __init__ may be passed as keyword arguments;
        # init=False fields would make cls(**kwargs) raise TypeError.
        cls_fields = {
            fld.name: fld
            for fld in fields(cls)  # type: ignore[arg-type]
            if fld.init
        }
        kwargs: Dict[str, Any] = {
            # _decode also handles an explicit None for a nested dataclass field.
            key: _decode(value, cls_fields[key].type)  # type: ignore[arg-type]
            for key, value in data.items()
            if key in cls_fields  # skip unknown fields
        }
        for fld in cls_fields.values():
            if fld.name in kwargs:
                continue
            # Let the dataclass apply its own default instead of clobbering it
            # with None.
            has_default = (
                fld.default is not MISSING or fld.default_factory is not MISSING
            )
            if not has_default and _is_optional(fld.type):  # type: ignore[arg-type]
                kwargs[fld.name] = None
        return cls(**kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """Return this instance as a nested dict (alias for ``dataclasses.asdict``)."""
        return asdict(self)  # type: ignore[call-overload]
def _decode(value: Any, target_type: Type) -> Any:
    """Decode *value* into the shape described by *target_type*.

    Handled targets:

    * dataclasses deriving from ``DataClassJsonMixin`` (``None`` passes through);
    * unions containing ``None`` such as ``Optional[T]``: ``None`` passes
      through, otherwise the value is decoded as the first non-None member;
    * parameterized containers (``List[T]``, ``Dict[K, V]``, ...): elements are
      decoded recursively and returned as a plain ``list`` / ``dict``.

    Anything else (plain classes, ``Literal[...]``, unions without ``None``)
    is returned unchanged. Strings and bytes are never treated as sequences.

    Raises:
        TypeError: if *target_type* is a dataclass not using the mixin.
    """
    if is_dataclass(target_type):
        # Raise instead of assert so the check survives ``python -O``.
        if not issubclass(target_type, DataClassJsonMixin):  # type: ignore[arg-type]
            raise TypeError(
                f"{target_type!r} must inherit from DataClassJsonMixin"
            )
        return target_type.from_dict(value) if value is not None else None  # type: ignore[union-attr]

    type_args = getattr(target_type, "__args__", None)
    if not type_args:
        return value

    none_type = type(None)
    if none_type in type_args:
        # Optional[T] / Union[T, None] / Union[None, T]
        if value is None:
            return None
        inner = next((arg for arg in type_args if arg is not none_type), none_type)
        return _decode(value, inner)

    # Only recurse for real container generics. Union/Literal have special-form
    # origins (not classes); recursing into them would split a str value into
    # a list of characters.
    origin = getattr(target_type, "__origin__", None)
    if not isinstance(origin, type):
        return value

    origin_is_mapping = issubclass(origin, Mapping)
    # List[T], Sequence[T], Set[T], ...
    if (
        not origin_is_mapping
        and isinstance(value, Sequence)
        and not isinstance(value, (str, bytes, bytearray))
    ):
        item_type = type_args[0]
        return [_decode(item, item_type) for item in value]
    # Dict[K, V], Mapping[K, V], ...
    if origin_is_mapping and isinstance(value, Mapping) and len(type_args) > 1:
        value_type = type_args[1]
        return {
            map_key: _decode(map_value, value_type)
            for map_key, map_value in value.items()
        }
    return value
def _is_optional(target_type: Type) -> bool:
    """Report whether *target_type* is a parameterized type admitting ``None``.

    ``Optional[T]`` (i.e. ``Union[T, None]``) yields True; types without
    ``__args__`` (plain classes) yield False.
    """
    try:
        type_args = target_type.__args__
    except AttributeError:
        return False
    return type(None) in type_args
|