from collections import defaultdict
from enum import Enum
from typing import (
    Any,
    Collection,
    Dict,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
)

from apischema.conversions.conversions import AnyConversion, DefaultConversion
from apischema.conversions.visitor import (
    ConversionsVisitor,
    DeserializationVisitor,
    SerializationVisitor,
)
from apischema.discriminators import (
    get_discriminated_parent,
    get_inherited_discriminator,
)
from apischema.json_schema.conversions_resolver import WithConversionsResolver
from apischema.metadata.keys import DISCRIMINATOR_METADATA
from apischema.objects import ObjectField
from apischema.objects.visitor import (
    DeserializationObjectVisitor,
    ObjectVisitor,
    SerializationObjectVisitor,
)
from apischema.type_names import TypeNameFactory, get_type_name
from apischema.types import AnyType
from apischema.utils import get_origin_or_type, is_hashable, replace_builtins
from apischema.visitor import Unsupported

try:
    from apischema.typing import Annotated, get_origin, is_union
except ImportError:
    Annotated = ...  # type: ignore

# Reference registry: maps a JSON schema reference name to a pair of
# (the type that owns the reference, how many times it has been encountered).
Refs = Dict[str, Tuple[AnyType, int]]


class Recursive(Exception):
    """Marker exception exported by this module.

    It is not raised by the definitions in this module; presumably callers use
    it to signal a recursive type (NOTE(review): confirm against its users).
    """


# Generic type variable. Not referenced by the definitions in this module;
# NOTE(review): possibly kept for external imports — confirm before removing.
T = TypeVar("T")


class RefsExtractor(ConversionsVisitor, ObjectVisitor, WithConversionsResolver):
    """Visitor that counts how often each JSON schema reference is reached.

    ``refs`` is updated in place: each reference name is associated with the
    type owning it and an occurrence counter (see ``_incr_ref``). References
    come from ``TypeNameFactory`` annotations, discriminated parents of object
    types and the types resolved from conversions.
    """

    def __init__(self, default_conversion: DefaultConversion, refs: Refs):
        super().__init__(default_conversion)
        self.refs = refs
        # Nesting depth per (type, active conversion), used to detect recursive
        # types lacking a reference that would break the cycle.
        self._rec_guard: Dict[
            Tuple[AnyType, Optional[AnyConversion]], int
        ] = defaultdict(int)

    def _incr_ref(self, ref: Optional[str], tp: AnyType) -> bool:
        """Add one occurrence of ``ref`` (owned by ``tp``) to ``self.refs``.

        Returns True if the counter of ``ref`` was already positive before this
        call, False otherwise; a ``None`` ref is ignored and returns False.

        Raises:
            ValueError: ``ref`` is already registered for a different type.
        """
        if ref is None:
            return False
        owner, previous = self.refs.get(ref, (tp, 0))
        if replace_builtins(owner) != replace_builtins(tp):
            raise ValueError(
                f"Types {tp} and {self.refs[ref][0]} share same reference '{ref}'"
            )
        self.refs[ref] = (owner, previous + 1)
        return previous > 0

    def annotated(self, tp: AnyType, annotations: Sequence[Any]):
        """Count references carried by the annotations of ``tp``.

        Annotations are scanned from the last one to the first. A
        ``TypeNameFactory`` with a string JSON schema name registers a
        reference for ``Annotated[tp, <annotations up to this one>]``; if that
        reference was already counted, the visit stops here without calling
        the parent implementation. Discriminator metadata attached to a union
        triggers one extra visit of the union.
        """
        for end in range(len(annotations), 0, -1):
            annotation = annotations[end - 1]
            if isinstance(annotation, TypeNameFactory):
                ref = annotation.to_type_name(tp).json_schema
                if not isinstance(ref, str):
                    continue
                # The named type keeps the annotations up to and including
                # the one carrying the name.
                named_tp = Annotated[(tp, *annotations[:end])]
                if self._incr_ref(ref, named_tp):
                    return
            if (
                isinstance(annotation, Mapping)
                and DISCRIMINATOR_METADATA in annotation
                and is_union(get_origin(tp))
            ):
                # A second visit of a discriminated union makes sure the
                # ref count of its alternatives goes above 1.
                self.visit(tp)
        return super().annotated(tp, annotations)

    def any(self):
        """``Any`` carries no reference: nothing to count."""

    def collection(self, cls: Type[Collection], value_type: AnyType):
        """Descend into the element type."""
        self.visit(value_type)

    def enum(self, cls: Type[Enum]):
        """Enums carry no nested reference: nothing to count."""

    def literal(self, values: Sequence[Any]):
        """Literals carry no nested reference: nothing to count."""

    def mapping(self, cls: Type[Mapping], key_type: AnyType, value_type: AnyType):
        """Descend into the key type, then into the value type."""
        self.visit(key_type)
        self.visit(value_type)

    def object(self, tp: AnyType, fields: Sequence[ObjectField]):
        """Count the discriminated parent's reference, then visit every field."""
        parent = get_discriminated_parent(get_origin_or_type(tp))
        if parent:
            self._incr_ref(get_type_name(parent).json_schema, parent)
        for field in fields:
            field_conversion = self._field_conversion(field)
            self.visit_with_conv(field.type, field_conversion)

    def primitive(self, cls: Type):
        """Primitives carry no nested reference: nothing to count."""

    def tuple(self, types: Sequence[AnyType]):
        """Descend into each element type in order."""
        for element_tp in types:
            self.visit(element_tp)

    def _visited_union(self, results: Sequence):
        """Union results are not combined: counting is a side effect only."""

    def union(self, types: Sequence[AnyType]):
        """Visit the union; a discriminated union is visited a second time."""
        super().union(types)
        if get_inherited_discriminator(types):
            # The extra pass guarantees a ref count above 1 for each member.
            super().union(types)

    def visit_conversion(
        self,
        tp: AnyType,
        conversion: Optional[Any],
        dynamic: bool,
        next_conversion: Optional[AnyConversion] = None,
    ):
        """Count the references of ``tp`` and visit it under a recursion guard.

        For non-dynamic conversions, every type resolved from ``tp`` has its
        reference counted; as soon as one was already counted, the visit stops.
        Unhashable types skip the recursion guard. If the parent visit raises
        ``Unsupported``, the references resolved here are removed from
        ``self.refs`` and the exception is not propagated.

        Raises:
            TypeError: ``tp`` is nested more than allowed without a reference.
        """
        counted_types = []
        if not dynamic:
            for resolved_tp in self.resolve_conversion(tp):
                counted_types.append(resolved_tp)
                resolved_ref = get_type_name(resolved_tp).json_schema
                if self._incr_ref(resolved_ref, resolved_tp):
                    # Already counted: do not descend into it again.
                    return
        if not is_hashable(tp):
            # The guard is a dict keyed by the type, so it cannot track this one.
            return super().visit_conversion(tp, conversion, dynamic, next_conversion)
        # 2 because the first type encountered of the recursive cycle can have no ref
        # (see test_recursive_by_conversion_schema)
        if self._rec_guard[(tp, self._conversion)] > 2:
            raise TypeError(
                f"Recursive type {tp} needs a ref. "
                "You can supply one using the type_name() decorator."
            )
        self._rec_guard[(tp, self._conversion)] += 1
        try:
            super().visit_conversion(tp, conversion, dynamic, next_conversion)
        except Unsupported:
            for resolved_tp in counted_types:
                self.refs.pop(get_type_name(resolved_tp).json_schema, ...)  # type: ignore
        finally:
            self._rec_guard[(tp, self._conversion)] -= 1


class DeserializationRefsExtractor(
    RefsExtractor, DeserializationVisitor, DeserializationObjectVisitor
):
    """Reference counter that walks types the way deserialization sees them.

    All behavior comes from ``RefsExtractor`` combined with the
    deserialization visitor bases.
    """


class SerializationRefsExtractor(
    RefsExtractor, SerializationVisitor, SerializationObjectVisitor
):
    """Reference counter that walks types the way serialization sees them.

    All behavior comes from ``RefsExtractor`` combined with the
    serialization visitor bases.
    """
