from collections import ChainMap
from dataclasses import MISSING, Field, InitVar, dataclass, field
from enum import Enum, auto
from types import FunctionType
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Iterable,
    Mapping,
    MutableMapping,
    NoReturn,
    Optional,
    Pattern,
    Sequence,
    Union,
    cast,
)

from apischema.cache import CacheAwareDict
from apischema.conversions.conversions import AnyConversion
from apischema.metadata.implem import (
    ConversionMetadata,
    SkipMetadata,
    ValidatorsMetadata,
)
from apischema.metadata.keys import (
    ALIAS_METADATA,
    ALIAS_NO_OVERRIDE_METADATA,
    CONVERSION_METADATA,
    DEFAULT_AS_SET_METADATA,
    FALL_BACK_ON_DEFAULT_METADATA,
    FLATTEN_METADATA,
    NONE_AS_UNDEFINED_METADATA,
    ORDERING_METADATA,
    POST_INIT_METADATA,
    PROPERTIES_METADATA,
    REQUIRED_METADATA,
    SCHEMA_METADATA,
    SKIP_METADATA,
    VALIDATORS_METADATA,
)
from apischema.types import AnyType, NoneType, UndefinedType
from apischema.typing import get_args, is_annotated
from apischema.utils import (
    LazyValue,
    empty_dict,
    get_args2,
    is_union_of,
    keep_annotations,
)

if TYPE_CHECKING:
    from apischema.ordering import Ordering
    from apischema.schemas import Schema
    from apischema.validation.validators import Validator


class FieldKind(Enum):
    """Direction(s) in which an ``ObjectField`` is used.

    Member values are 1, 2 and 3, in declaration order.
    """

    NORMAL = 1
    READ_ONLY = 2
    WRITE_ONLY = 3


# Sentinel meaning "no default was passed" for ObjectField's `default` InitVar.
# dataclasses.MISSING cannot be reused as that InitVar's default value, because
# dataclass would then interpret the InitVar as having no default at all.
MISSING_DEFAULT = object()


@dataclass(frozen=True)
class ObjectField:
    """Description of one field of an object type.

    Besides the constructor arguments, most flags are read from metadata;
    see ``full_metadata``. ``__post_init__`` normalizes the instance:

    * ``REQUIRED_METADATA`` forces ``required`` to ``True``;
    * a ``dataclasses.MISSING`` ``default_factory`` is replaced by ``None``;
    * a non-required field without ``default_factory`` must receive a
      ``default``, which is wrapped in ``LazyValue`` and stored as
      ``default_factory``; otherwise ``ValueError`` is raised;
    * with ``NONE_AS_UNDEFINED_METADATA``, ``NoneType`` is dropped from
      ``type`` when ``is_union_of(type, NoneType)`` holds (annotations are
      carried over through ``keep_annotations``).
    """

    name: str
    type: AnyType
    required: bool = True
    metadata: Mapping[str, Any] = field(default_factory=lambda: empty_dict)
    # Init-only value, only handed to __post_init__.
    default: InitVar[Any] = MISSING_DEFAULT
    default_factory: Optional[Callable[[], Any]] = None
    kind: FieldKind = FieldKind.NORMAL

    def __post_init__(self, default: Any):
        # The dataclass is frozen, so normalization bypasses __setattr__.
        assign = object.__setattr__
        if REQUIRED_METADATA in self.full_metadata:
            assign(self, "required", True)
        if self.default_factory is MISSING:
            # presumably reached when values come from a dataclasses.Field
            assign(self, "default_factory", None)
        if self.default_factory is None and not self.required:
            if default is MISSING or default is MISSING_DEFAULT:
                raise ValueError("Missing default for non-required ObjectField")
            assign(self, "default_factory", LazyValue(default))
        if self.none_as_undefined and is_union_of(self.type, NoneType):
            remaining = tuple(
                arg for arg in get_args2(self.type) if arg != NoneType
            )
            stripped = Union[remaining]  # type: ignore
            assign(self, "type", keep_annotations(stripped, self.type))

    @property
    def full_metadata(self) -> Mapping[str, Any]:
        """``metadata`` layered over the mapping arguments of an Annotated type.

        Lookup order: ``self.metadata`` first, then the Mapping arguments of
        ``Annotated[...]`` from last to first. Non-annotated types yield
        ``self.metadata`` unchanged.
        """
        if not is_annotated(self.type):
            return self.metadata
        annotation_args = get_args(self.type)[1:]
        layers = [
            cast(MutableMapping, annotation)
            for annotation in reversed(annotation_args)
            if isinstance(annotation, Mapping)
        ]
        return ChainMap(cast(MutableMapping, self.metadata), *layers)

    @property
    def additional_properties(self) -> bool:
        # True only when PROPERTIES_METADATA is present with the value None.
        properties = self.full_metadata.get(PROPERTIES_METADATA, ...)
        return properties is None

    @property
    def alias(self) -> str:
        # Falls back to the field name when no alias metadata is set.
        fallback = self.name
        return self.full_metadata.get(ALIAS_METADATA, fallback)

    @property
    def override_alias(self) -> bool:
        return ALIAS_NO_OVERRIDE_METADATA not in self.full_metadata

    @property
    def _conversion(self) -> Optional[ConversionMetadata]:
        # Reads the field's own metadata only, not Annotated metadata.
        return self.metadata.get(CONVERSION_METADATA)

    @property
    def default_as_set(self) -> bool:
        return DEFAULT_AS_SET_METADATA in self.full_metadata

    @property
    def deserialization(self) -> Optional[AnyConversion]:
        conversion = self._conversion
        if conversion is None:
            return None
        return conversion.deserialization

    @property
    def fall_back_on_default(self) -> bool:
        # Only meaningful when there actually is a default to fall back on.
        has_default = self.default_factory is not None
        return has_default and FALL_BACK_ON_DEFAULT_METADATA in self.full_metadata

    @property
    def flattened(self) -> bool:
        return FLATTEN_METADATA in self.full_metadata

    def get_default(self) -> Any:
        """Return a default value by calling ``default_factory``.

        :raises RuntimeError: if the field is required.
        """
        if not self.required:
            assert self.default_factory is not None
            return self.default_factory()
        raise RuntimeError("Field is required")

    @property
    def is_aggregate(self) -> bool:
        # Flattened fields and (additional/pattern) properties fields.
        if self.flattened or self.additional_properties:
            return True
        return self.pattern_properties is not None

    @property
    def none_as_undefined(self) -> bool:
        return NONE_AS_UNDEFINED_METADATA in self.full_metadata

    @property
    def ordering(self) -> Optional["Ordering"]:
        return self.full_metadata.get(ORDERING_METADATA)

    @property
    def post_init(self) -> bool:
        return POST_INIT_METADATA in self.full_metadata

    @property
    def pattern_properties(self) -> Union[Pattern, "ellipsis", None]:  # noqa: F821
        # None both when the key is absent and when it maps to None.
        return self.full_metadata.get(PROPERTIES_METADATA)

    @property
    def schema(self) -> Optional["Schema"]:
        # Reads the field's own metadata only, not Annotated metadata.
        return self.metadata.get(SCHEMA_METADATA)

    @property
    def serialization(self) -> Optional[AnyConversion]:
        conversion = self._conversion
        if conversion is None:
            return None
        return conversion.serialization

    @property
    def skip(self) -> SkipMetadata:
        # A fresh SkipMetadata() is returned when none is registered.
        return self.metadata.get(SKIP_METADATA, SkipMetadata())

    def skippable(self, default: bool, none: bool) -> bool:
        """Tell whether serialization may omit this field.

        :param default: treat fields having a default as skippable.
        :param none: treat fields whose type is a union with ``NoneType`` as
            skippable.
        """
        skip = self.skip
        has_default = self.default_factory is not None
        return bool(
            skip.serialization_if
            or self.undefined
            or (has_default and (skip.serialization_default or default))
            or self.none_as_undefined
            or (none and is_union_of(self.type, NoneType))
        )

    @property
    def undefined(self) -> bool:
        return is_union_of(self.type, UndefinedType)

    @property
    def validators(self) -> Sequence["Validator"]:
        # Reads the field's own metadata only; empty tuple when absent.
        if VALIDATORS_METADATA not in self.metadata:
            return ()
        registered = cast(ValidatorsMetadata, self.metadata[VALIDATORS_METADATA])
        return registered.validators


# Ways of designating a field: its name, an apischema ObjectField, or a
# dataclasses.Field.
FieldOrName = Union[str, ObjectField, Field]


def _bad_field(obj: Any, methods: bool) -> NoReturn:
    method_types = "property/types.FunctionType" if methods else ""
    raise TypeError(
        f"Expected dataclasses.Field/apischema.ObjectField/str{method_types}, found {obj}"
    )


def check_field_or_name(field_or_name: Any, *, methods: bool = False):
    """Validate that *field_or_name* can designate a field.

    Accepted: ``str``, ``ObjectField`` and ``dataclasses.Field`` instances;
    with ``methods=True``, also ``property`` objects and plain functions.

    :raises TypeError: (via ``_bad_field``) for any other object.
    """
    allowed: tuple = (str, ObjectField, Field)
    if methods:
        allowed = allowed + (property, FunctionType)
    if isinstance(field_or_name, allowed):
        return
    _bad_field(field_or_name, methods)


def get_field_name(field_or_name: Any, *, methods: bool = False) -> str:
    """Return the field name designated by *field_or_name*.

    ``Field``/``ObjectField`` give their ``name``, a ``str`` is returned as
    is. With ``methods=True``, a plain function gives its ``__name__`` and a
    property gives its getter's ``__name__`` (when it has a getter).

    :raises TypeError: (via ``_bad_field``) for anything else.
    """
    if isinstance(field_or_name, (Field, ObjectField)):
        return field_or_name.name
    if isinstance(field_or_name, str):
        return field_or_name
    if methods:
        if isinstance(field_or_name, FunctionType):
            return field_or_name.__name__
        if isinstance(field_or_name, property):
            getter = field_or_name.fget
            if getter is not None:
                return getter.__name__
    _bad_field(field_or_name, methods)


# Registry of object fields registered per class through set_object_fields and
# read back by default_object_fields. Each value is a zero-argument callable
# returning the fields; user-supplied callables are stored as-is, so they are
# only evaluated on lookup.
# NOTE(review): CacheAwareDict presumably ties mutations to apischema's cache
# invalidation — confirm in apischema.cache.
_class_fields: MutableMapping[
    type, Callable[[], Sequence[ObjectField]]
] = CacheAwareDict({})


def set_object_fields(
    cls: type,
    fields: Union[Iterable[ObjectField], Callable[[], Sequence[ObjectField]], None],
):
    """Register the fields that ``default_object_fields(cls)`` will return.

    :param cls: the class whose fields are registered.
    :param fields: ``None`` removes any registration (no error if absent); a
        callable is stored unchanged and called at each lookup; any other
        iterable is snapshotted into a tuple immediately.
    """
    if fields is None:
        _class_fields.pop(cls, ...)
        return
    if callable(fields):
        _class_fields[cls] = fields
        return
    snapshot = tuple(fields)

    def _registered_fields() -> Sequence[ObjectField]:
        return snapshot

    _class_fields[cls] = _registered_fields


def default_object_fields(cls: type) -> Optional[Sequence[ObjectField]]:
    """Return the fields registered for *cls*, or ``None`` if none are.

    The registered callable is invoked on every call.
    """
    if cls not in _class_fields:
        return None
    fields_factory = _class_fields[cls]
    return fields_factory()
