1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
|
import decimal
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Optional, Type, Union
from typing_extensions import Literal
from mashumaro.core.const import Sentinel
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "SerializableType",
    "GenericSerializableType",
    "SerializationStrategy",
    "RoundedDecimal",
    "Discriminator",
    "Alias",
]
class SerializableType:
    """Base class for types that implement their own (de)serialization hooks.

    Subclasses override ``_serialize`` (instance -> serialized value) and
    ``_deserialize`` (classmethod, serialized value -> instance); the base
    versions raise ``NotImplementedError``.

    The optional ``use_annotations`` class keyword, e.g.
    ``class Foo(SerializableType, use_annotations=True)``, sets
    ``__use_annotations__`` on the new subclass. When it is omitted, the
    subclass keeps the value inherited from its parent (``False`` here).
    """

    __slots__ = ()

    # Class-level flag, overridable per subclass via the class keyword.
    __use_annotations__ = False

    def __init_subclass__(
        cls,
        use_annotations: Union[
            bool, Literal[Sentinel.MISSING]
        ] = Sentinel.MISSING,
        **kwargs: Any,
    ):
        # The MISSING sentinel distinguishes "keyword not given" from an
        # explicit False, so an omitted keyword leaves the inherited value.
        if use_annotations is not Sentinel.MISSING:
            cls.__use_annotations__ = use_annotations
        # Forward the remaining class keywords so other bases in the MRO
        # (e.g. cooperating mixins listed after this class) still get their
        # __init_subclass__ run; the flag is set first so they observe it.
        super().__init_subclass__(**kwargs)

    def _serialize(self) -> Any:
        """Return the serialized representation of this instance."""
        raise NotImplementedError

    @classmethod
    def _deserialize(cls, value: Any) -> Any:
        """Build an instance from its serialized ``value``."""
        raise NotImplementedError
class GenericSerializableType:
    """Base class for self-serializing types whose hooks receive type args.

    Both hooks take an extra ``types`` list alongside the usual arguments.
    Subclasses must override them; the defaults below only raise
    ``NotImplementedError``.
    """

    __slots__ = ()

    def _serialize(self, types: list[Type]) -> Any:
        """Serialize this instance, given the accompanying ``types`` list."""
        raise NotImplementedError()

    @classmethod
    def _deserialize(cls, value: Any, types: list[Type]) -> Any:
        """Rebuild an instance from ``value``, given the ``types`` list."""
        raise NotImplementedError()
class SerializationStrategy:
    """Pluggable converter between a value and its serialized form.

    Subclasses override ``serialize`` and ``deserialize``; the base versions
    raise ``NotImplementedError``.

    The optional ``use_annotations`` class keyword, e.g.
    ``class S(SerializationStrategy, use_annotations=True)``, sets
    ``__use_annotations__`` on the new subclass. When it is omitted, the
    subclass keeps the value inherited from its parent (``False`` here).
    """

    # Class-level flag, overridable per subclass via the class keyword.
    __use_annotations__ = False

    def __init_subclass__(
        cls,
        use_annotations: Union[
            bool, Literal[Sentinel.MISSING]
        ] = Sentinel.MISSING,
        **kwargs: Any,
    ):
        # The MISSING sentinel distinguishes "keyword not given" from an
        # explicit False, so an omitted keyword leaves the inherited value.
        if use_annotations is not Sentinel.MISSING:
            cls.__use_annotations__ = use_annotations
        # Forward the remaining class keywords so other bases in the MRO
        # still get their __init_subclass__ run; the flag is set first so
        # they observe the final value.
        super().__init_subclass__(**kwargs)

    def serialize(self, value: Any) -> Any:
        """Convert ``value`` to its serialized representation."""
        raise NotImplementedError

    def deserialize(self, value: Any) -> Any:
        """Convert a serialized ``value`` back to its original form."""
        raise NotImplementedError
class RoundedDecimal(SerializationStrategy):
    """Strategy that writes ``decimal.Decimal`` values as strings.

    :param places: if given, values are quantized to this many digits after
        the decimal point before being stringified (``places=2`` quantizes
        to ``Decimal("0.01")``); ``None`` leaves values untouched.
    :param rounding: ``decimal`` rounding mode forwarded to
        ``Decimal.quantize`` (e.g. ``decimal.ROUND_HALF_UP``); when falsy,
        ``quantize`` uses the current decimal context's rounding.
    """

    def __init__(
        self, places: Optional[int] = None, rounding: Optional[str] = None
    ):
        if places is None:
            self.exp = None  # type: ignore
        else:
            # Quantum 10 ** -places, built exactly from the Decimal tuple form
            # (sign=0, digits=(1,), exponent=-places).
            self.exp = decimal.Decimal((0, (1,), -places))
        self.rounding = rounding

    def serialize(self, value: decimal.Decimal) -> str:
        """Return ``value`` as a string, quantized when ``places`` was set."""
        if not self.exp:
            return str(value)
        # Only pass ``rounding`` when one was configured.
        quantized = (
            value.quantize(self.exp, rounding=self.rounding)
            if self.rounding
            else value.quantize(self.exp)
        )
        return str(quantized)

    def deserialize(self, value: str) -> decimal.Decimal:
        """Parse ``value`` into a ``decimal.Decimal``.

        The input goes through ``str()`` first, so a float such as ``0.1``
        becomes ``Decimal("0.1")`` rather than its exact binary expansion.
        """
        return decimal.Decimal(str(value))
@dataclass(unsafe_hash=True)
class Discriminator:
    """Hashable settings record for discriminator-based type selection.

    Holds an optional ``field`` name, two inclusion flags and an optional
    ``variant_tagger_fn`` callable. At least one of ``include_supertypes``
    or ``include_subtypes`` must be true; otherwise construction raises
    ``ValueError``.

    NOTE(review): ``field`` presumably names the key that carries the
    variant tag in serialized data; confirm against the consuming code.
    """

    # Name of the discriminating field, if any.
    field: Optional[str] = None
    # Whether the configured type itself/its supertypes take part.
    include_supertypes: bool = False
    # Whether subclasses of the configured type take part.
    include_subtypes: bool = False
    # Optional callable mapping a variant to its tag value(s).
    variant_tagger_fn: Optional[Callable[[Any], Any]] = None

    def __post_init__(self) -> None:
        # Valid as soon as either inclusion direction is enabled.
        if self.include_supertypes or self.include_subtypes:
            return
        raise ValueError(
            "Either 'include_supertypes' or 'include_subtypes' "
            "must be enabled"
        )
class Alias:
    """Small hashable value object wrapping a single ``name``.

    Two ``Alias`` instances are equal, and hash equally, exactly when their
    names are equal. ``name`` is positional-only: ``Alias("x")``.
    """

    def __init__(self, name: str, /):
        self.name = name

    def __repr__(self) -> str:
        # !r quotes and escapes the name correctly even when it contains
        # quote or backslash characters; a plain name renders as
        # Alias(name='x').
        return f"Alias(name={self.name!r})"

    def __eq__(self, other: Any) -> bool:
        # Return NotImplemented (not False) for foreign types so Python can
        # try ``other.__eq__``; if that is unsupported too, ``==`` still
        # evaluates to False.
        if not isinstance(other, Alias):
            return NotImplemented
        return self.name == other.name

    def __hash__(self) -> int:
        # Consistent with __eq__: equal names give equal hashes.
        return hash(self.name)
|