from typing import Any

from litestar.plugins import SerializationPlugin
from litestar.typing import FieldDefinition
from sqlalchemy.orm import DeclarativeBase

from advanced_alchemy.extensions.litestar.dto import SQLAlchemyDTO
from advanced_alchemy.extensions.litestar.plugins import _slots_base


class SQLAlchemySerializationPlugin(SerializationPlugin, _slots_base.SlotsBase):
    """Litestar ``SerializationPlugin`` for SQLAlchemy declarative models.

    Builds one ``SQLAlchemyDTO`` type per model class and memoizes it on the
    instance. Repeated requests for the same model therefore return the same
    DTO type.
    """

    def __init__(self) -> None:
        # Lazily populated memo: model class -> generated SQLAlchemyDTO subclass.
        self._type_dto_map: dict[type[DeclarativeBase], type[SQLAlchemyDTO[Any]]] = {}

    def supports_type(self, field_definition: FieldDefinition) -> bool:
        """Tell whether this plugin can handle ``field_definition``.

        Args:
            field_definition: The annotated type being inspected.

        Returns:
            Truthy when the type is a collection containing a ``DeclarativeBase``
            subclass, or is itself a ``DeclarativeBase`` subclass.
        """
        # Collections are checked first so evaluation order matches the original expression.
        if field_definition.is_collection:
            contains_model = field_definition.has_inner_subclass_of(DeclarativeBase)
            if contains_model:
                return contains_model
        return field_definition.is_subclass_of(DeclarativeBase)

    def create_dto_for_type(self, field_definition: FieldDefinition) -> type[SQLAlchemyDTO[Any]]:
        """Return the memoized ``SQLAlchemyDTO`` type for ``field_definition``.

        The model class is the first inner type that subclasses
        ``DeclarativeBase``. If no inner type qualifies (for example, a bare
        model annotation with no inner types), the field's own annotation is
        used instead.

        Args:
            field_definition: An annotation accepted by ``supports_type``.

        Returns:
            The DTO type for the resolved model class. It is created on first
            use and reused afterwards.
        """
        # Fall back to the outer annotation when no inner type is a SQLAlchemy model.
        model = field_definition.annotation
        for inner in field_definition.inner_types:
            if inner.is_subclass_of(DeclarativeBase):
                model = inner.annotation
                break

        cached = self._type_dto_map.get(model)
        if cached is not None:
            return cached

        dto_type = SQLAlchemyDTO[model]  # type:ignore[valid-type]
        self._type_dto_map[model] = dto_type
        return dto_type
