import collections.abc
import sys
import typing
from collections.abc import Sequence
from datetime import datetime, timedelta
from enum import Enum
from functools import partial
from inspect import isclass
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Iterable,
    Literal,
    Optional,
    Union,
    get_args,
    get_origin,
)

from cyclopts.annotations import is_annotated, is_nonetype, is_union, resolve
from cyclopts.exceptions import CoercionError, ValidationError
from cyclopts.field_info import get_field_infos
from cyclopts.utils import UNSET, default_name_transform, grouper, is_builtin

if sys.version_info >= (3, 12):  # pragma: no cover
    from typing import TypeAliasType
else:  # pragma: no cover
    TypeAliasType = None

if TYPE_CHECKING:
    from cyclopts.argument import Token


# Bare (unparameterized) iterable annotations, mapped to the parameterized type
# they are converted as. Elements default to ``str``.
_implicit_iterable_type_mapping: dict[type, type] = {
    Iterable: list[str],
    typing.Sequence: list[str],
    Sequence: list[str],
    frozenset: frozenset[str],
    list: list[str],
    set: set[str],
    tuple: tuple[str, ...],
}

# Every iterable container recognized by the converter; identical to the
# keys of the mapping above.
ITERABLE_TYPES = set(_implicit_iterable_type_mapping)

# Recursive alias: leaves are sequences of strings, inner nodes are dicts keyed by name.
NestedCliArgs = dict[str, Union[Sequence[str], "NestedCliArgs"]]


def _bool(s: str) -> bool:
    """Interpret ``s`` as a boolean, case-insensitively.

    Only an explicit set of spellings is accepted: ``yes``/``y``/``1``/``true``/``t``
    map to ``True`` and ``no``/``n``/``0``/``false``/``f`` map to ``False``.
    Cyclopts is deliberately conservative here; any other string raises
    ``CoercionError``.
    """
    lowered = s.lower()
    if lowered in ("yes", "y", "1", "true", "t"):
        return True
    if lowered in ("no", "n", "0", "false", "f"):
        return False
    raise CoercionError(target_type=bool)


def _int(s: str) -> int:
    """Parse an integer string.

    Supports:

    * ``0x``/``0o``/``0b`` prefixed literals (case-insensitive, optionally signed),
    * plain decimal integers of arbitrary size (parsed exactly),
    * float-like strings such as ``"30.0"`` or ``"1e3"``, which are rounded.

    Raises
    ------
    ValueError
        If ``s`` cannot be interpreted as an integer (including ``"inf"``/``"nan"``).
    """
    s = s.lower()
    unsigned = s.lstrip("+-")
    for prefix, base in (("0x", 16), ("0o", 8), ("0b", 2)):
        if unsigned.startswith(prefix):
            # ``int`` accepts an optional sign and the matching prefix for this base.
            return int(s, base)

    try:
        # Exact parse first: going through ``float`` would silently lose
        # precision for integers larger than 2**53.
        return int(s)
    except ValueError:
        pass

    try:
        # Casting to a float allows for things like "30.0" or "1e3".
        return int(round(float(s)))
    except OverflowError:
        # ``round(float("inf"))`` raises OverflowError; normalize it so callers
        # treat it like any other unparsable token.
        raise ValueError(f"Cannot convert {s!r} to an integer.") from None


def _bytes(s: str) -> bytes:
    """Encode the token string ``s`` as UTF-8 ``bytes``."""
    return str.encode(s, "utf8")


def _bytearray(s: str) -> bytearray:
    """Encode the token string ``s`` as a mutable UTF-8 ``bytearray``."""
    return bytearray(s, encoding="utf8")


def _datetime(s: str) -> datetime:
    """Parse a datetime string.

    Accepts an ISO-8601 style date (``YYYY-MM-DD``), optionally followed by a
    time (``HH:MM:SS``) separated from the date by either ``"T"`` or a single
    space. The time may carry fractional seconds and/or a UTC offset
    (e.g. ``+0000``), with either separator.

    Returns
    -------
    datetime.datetime

    Raises
    ------
    ValueError
        If ``s`` matches none of the supported formats.
    """
    formats = [
        # ISO 8601 formats (unambiguous internationally)
        "%Y-%m-%d",  # 1956-01-31
        "%Y-%m-%dT%H:%M:%S",  # 1956-01-31T10:00:00
        "%Y-%m-%dT%H:%M:%S%z",  # 1956-01-31T10:00:00+0000
        "%Y-%m-%dT%H:%M:%S.%f",  # 1956-01-31T10:00:00.123456
        "%Y-%m-%dT%H:%M:%S.%f%z",  # 1956-01-31T10:00:00.123456+0000
        # The same time formats with a space separating date and time.
        "%Y-%m-%d %H:%M:%S",  # 1956-01-31 10:00:00
        "%Y-%m-%d %H:%M:%S%z",  # 1956-01-31 10:00:00+0000
        "%Y-%m-%d %H:%M:%S.%f",  # 1956-01-31 10:00:00.123456
        "%Y-%m-%d %H:%M:%S.%f%z",  # 1956-01-31 10:00:00.123456+0000
    ]

    for fmt in formats:
        try:
            return datetime.strptime(s, fmt)
        except ValueError:
            continue

    raise ValueError(f"Could not parse datetime string: {s}")


def _timedelta(s: str) -> timedelta:
    """Parse a duration string into a :class:`~datetime.timedelta`.

    The string is one or more ``<number><unit>`` components such as ``"1h30m"``
    or ``"1.5d"``. Components may be separated by whitespace, and the whole
    string may be preceded by ``-`` for a negative duration.

    Supported units: ``s`` seconds, ``m`` minutes, ``h`` hours, ``d`` days,
    ``w`` weeks, ``M`` months (approximated as 30 days) and ``y`` years
    (approximated as 365 days).

    Raises
    ------
    ValueError
        If ``s`` is empty or contains anything other than valid components
        (e.g. ``"10ms"`` or ``"5x"``).
    """
    import re

    unit_seconds = {
        "s": 1,
        "m": 60,
        "h": 3600,
        "d": 86400,
        "w": 604800,
        "M": 2592000,  # Approximation: 1 month = 30 days
        "y": 31536000,  # Approximation: 1 year = 365 days
    }
    component = r"(\d+\.\d+|\d+)([smhdwMy])"

    # Strip first so a leading space cannot hide the "-" sign.
    s = s.strip()
    negative = s.startswith("-")
    if negative:
        s = s[1:]

    # Validate the *whole* string: ``re.findall`` alone skips unmatched text,
    # which made e.g. "10ms" parse as 10 minutes and "1hx" as 1 hour.
    if not re.fullmatch(rf"(?:\s*{component})+\s*", s):
        raise ValueError(f"Could not parse duration string: {s}")

    seconds = 0.0
    for value, unit in re.findall(component, s):
        seconds += float(value) * unit_seconds[unit]

    if negative:
        seconds = -seconds
    return timedelta(seconds=seconds)


# For types that need more logic than just invoking their type
# Conversion functions for types where simply calling ``type_(value)`` on the
# raw token string is not good enough. Lookups into this table fall back to
# invoking the type itself for any type not listed here.
_converters: dict[Any, Callable] = dict(
    [
        (bool, _bool),
        (int, _int),
        (bytes, _bytes),
        (bytearray, _bytearray),
        (datetime, _datetime),
        (timedelta, _timedelta),
    ]
)


def _convert_tuple(
    type_: type[Any],
    *tokens: "Token",
    converter: Optional[Callable[[type, str], Any]],
    name_transform: Callable[[str], str],
) -> tuple:
    """Convert ``tokens`` into a ``tuple`` described by ``type_``.

    Two shapes are supported:

    * Variable-length tuples (``tuple[T, ...]`` or a bare ``tuple``). The tokens
      are split into equally-sized chunks and each chunk is converted into ``T``
      (``str`` when unparameterized).
    * Fixed-length tuples (``tuple[A, B, ...]`` without ``...``). Each member
      type receives as many consecutive tokens as its own token count.

    Raises
    ------
    CoercionError
        If the number of tokens does not fit the tuple's shape.
    ValueError
        If a variable-length tuple declares more than one inner type.
    """
    convert = partial(_convert, converter=converter, name_transform=name_transform)
    member_types = tuple(t for t in get_args(type_) if t is not ...)
    chunk_size, variadic = token_count(type_)
    n_tokens = len(tokens)

    if not variadic:
        # Fixed-length tuple: hand every member type its own slice of tokens.
        if n_tokens != chunk_size:
            raise CoercionError(
                msg=f"Incorrect number of arguments: expected {chunk_size} but got {n_tokens}."
            )
        widths = [token_count(t)[0] for t in member_types]
        remaining = iter(tokens)
        groups = []
        for width in widths:
            group = [next(remaining) for _ in range(width)]
            # Single-token members receive the bare token, not a 1-element list.
            groups.append(group[0] if len(group) == 1 else group)
        return tuple(convert(member_type, group) for member_type, group in zip(member_types, groups))

    # Variable-length tuple: behaves like a list of ``chunk_size``-token elements.
    if n_tokens % chunk_size:
        raise CoercionError(
            msg=f"Incorrect number of arguments: expected multiple of {chunk_size} but got {n_tokens}."
        )
    if not member_types:
        element_type = str
    elif len(member_types) == 1:
        element_type = member_types[0]
    else:
        raise ValueError("A tuple must have 0 or 1 inner-types.")
    return tuple(
        convert(element_type, chunk if chunk_size != 1 else chunk[0])
        for chunk in grouper(tokens, chunk_size)
    )


def _convert(
    type_,
    token: Union["Token", Sequence["Token"]],
    *,
    converter: Optional[Callable[[Any, str], Any]],
    name_transform: Callable[[str], str],
):
    """Inner recursive conversion function for public ``convert``.

    Converts a single ``Token`` into ``type_``. For types that consume several
    tokens, it converts a sequence of ``Token`` instead. It recurses into
    union, literal, tuple, iterable and user-class inner types.

    Parameters
    ----------
    type_
        Type hint to convert into. It may be ``Annotated``. In that case the
        attached ``Parameter``'s ``converter`` and ``name_transform``, when set,
        replace the ones passed in, and its validators run on the result.
    token: Union[Token, Sequence[Token]]
        Token(s) to convert.
    converter: Callable
        Optional ``converter(type_, value)`` used for leaf conversion instead of
        the built-in coercion rules.
    name_transform: Callable
        Normalizes ``Enum`` member names and token values before they are compared.

    Raises
    ------
    CoercionError
        If a token cannot be coerced into the target type.
    ValidationError
        If a ``Parameter`` validator raises ``AssertionError``/``ValueError``/``TypeError``.
    ValueError
        If a sequence of tokens is given where a single token is required,
        or if every union member fails for a token sequence.
    """
    from cyclopts.argument import Token
    from cyclopts.parameter import Parameter

    # Set when an ``Annotated`` Parameter supplies its own converter. The builtin
    # branch below then passes the whole ``Token`` instead of ``token.value``.
    converter_needs_token = False
    if is_annotated(type_):
        # NOTE(review): redundant with the import at the top of this function.
        from cyclopts.parameter import Parameter

        type_, cparam = Parameter.from_annotation(type_)
        if cparam.converter:
            converter_needs_token = True

            # Adapt the Parameter's converter to the ``converter(type_, value)`` shape
            # used in this module; the Parameter's converter receives ``(value,)``.
            # NOTE(review): ``converter_needs_token`` is local to this call. Nested
            # ``convert(...)`` calls and the Enum branch below therefore invoke this
            # wrapper with ``token.value`` (a str) rather than a ``Token`` — confirm
            # that is intended.
            def converter_with_token(t_, value):
                assert cparam.converter
                return cparam.converter(t_, (value,))

            converter = converter_with_token

        if cparam.name_transform:
            name_transform = cparam.name_transform
    else:
        cparam = None

    # Recursive helpers bound to the (possibly overridden) converter/name_transform.
    convert = partial(_convert, converter=converter, name_transform=name_transform)
    convert_tuple = partial(_convert_tuple, converter=converter, name_transform=name_transform)

    origin_type = get_origin(type_)
    # Inner types **may** be ``Annotated``
    inner_types = get_args(type_)

    if type_ is dict:
        # Bare ``dict`` is converted as ``dict[str, str]``.
        out = convert(dict[str, str], token)
    elif type_ in _implicit_iterable_type_mapping:
        # Bare iterable (e.g. ``list``) -> its parameterized form (e.g. ``list[str]``).
        out = convert(_implicit_iterable_type_mapping[type_], token)
    elif origin_type in (collections.abc.Iterable, collections.abc.Sequence):
        # Abstract ``Iterable[T]``/``Sequence[T]`` are materialized as ``list[T]``.
        assert len(inner_types) == 1
        out = convert(list[inner_types[0]], token)  # pyright: ignore[reportGeneralTypeIssues]
    elif TypeAliasType is not None and isinstance(type_, TypeAliasType):
        # ``type X = ...`` alias (Python 3.12+): convert into the aliased type.
        out = convert(type_.__value__, token)
    elif is_union(origin_type):
        # Try each non-None member left-to-right; the first successful conversion wins.
        for t in inner_types:
            if is_nonetype(t):
                continue
            try:
                out = convert(t, token)
                break
            except Exception:
                pass
        else:
            # Every member failed.
            if isinstance(token, Sequence):
                raise ValueError  # noqa: TRY004
            raise CoercionError(token=token, target_type=type_)
    elif origin_type is Literal:
        # Try coercing the token into each allowed Literal value (left-to-right).
        last_coercion_error = None
        for choice in get_args(type_):
            try:
                res = convert(type(choice), token)
            except CoercionError as e:
                last_coercion_error = e
                continue
            if res == choice:
                out = res
                break
        else:
            # No choice matched: re-raise the last coercion failure retargeted at
            # the Literal type, or a fresh CoercionError if every coercion succeeded
            # but none compared equal.
            if last_coercion_error:
                last_coercion_error.target_type = type_
                raise last_coercion_error
            else:
                raise CoercionError(token=token[0] if isinstance(token, Sequence) else token, target_type=type_)
    elif origin_type is tuple:
        if isinstance(token, Token):
            # E.g. Tuple[str] (Annotation: tuple containing a single string)
            out = convert_tuple(type_, token, converter=converter)
        else:
            out = convert_tuple(type_, *token, converter=converter)
    elif origin_type in ITERABLE_TYPES:
        # NOT including tuple; handled in ``origin_type is tuple`` body above.
        # ``count`` tokens make up one element of the container.
        count, _ = token_count(inner_types[0])
        if not isinstance(token, Sequence):
            raise ValueError
        if count > 1:
            # Yields consecutive ``count``-sized tuples of tokens.
            gen = zip(*[iter(token)] * count)
        else:
            gen = token
        out = origin_type(convert(inner_types[0], e) for e in gen)  # pyright: ignore[reportOptionalCall]
    elif isclass(type_) and issubclass(type_, Enum):
        if isinstance(token, Sequence):
            raise ValueError

        if converter is None:
            # Match on the transformed member *name*, not the member value.
            element_transformed = name_transform(token.value)
            for member in type_:
                if name_transform(member.name) == element_transformed:
                    out = member
                    break
            else:
                raise CoercionError(token=token, target_type=type_)
        else:
            out = converter(type_, token.value)
    elif is_builtin(type_):
        # Builtins are converted from exactly one token.
        assert isinstance(token, Token)
        try:
            # Order: an implicit value set on the token wins; otherwise use the
            # ``_converters`` table (falling back to ``type_`` itself) when no custom
            # converter is given; otherwise call the custom converter.
            if token.implicit_value is not UNSET:
                out = token.implicit_value
            elif converter is None:
                out = _converters.get(type_, type_)(token.value)
            elif converter_needs_token:
                out = converter(type_, token)  # pyright: ignore[reportArgumentType]
            else:
                out = converter(type_, token.value)
        except CoercionError as e:
            # Fill in context the raiser did not know about.
            if e.target_type is None:
                e.target_type = type_
            if e.token is None:
                e.token = token
            raise
        except ValueError:
            raise CoercionError(token=token, target_type=type_) from None
    else:
        # Convert it into a user-supplied class.
        # Positional constructor arguments are built field-by-field, each field
        # consuming as many tokens as its ``token_count``.
        if not isinstance(token, Sequence):
            token = [token]
        i = 0
        pos_values = []
        # NOTE(review): this initial value is overwritten in the loop and never read.
        hint = type_
        for field_info in get_field_infos(type_).values():
            hint = field_info.hint
            if isclass(hint) and issubclass(hint, str):  # Avoids infinite recursion
                pos_values.append(token[i].value)
                i += 1
            else:
                tokens_per_element, consume_all = token_count(hint)
                if tokens_per_element == 1:
                    pos_values.append(convert(hint, token[i]))
                    i += 1
                else:
                    pos_values.append(convert(hint, token[i : i + tokens_per_element]))
                    i += tokens_per_element
                if consume_all:
                    break
            # Stop once all tokens are used; remaining fields keep their defaults.
            if i == len(token):
                break
        assert i == len(token)
        out = type_(*pos_values)

    if cparam:
        # An inner type may have an independent Parameter annotation;
        # e.g.:
        #    Uint8 = Annotated[int, ...]
        #    rgb: tuple[Uint8, Uint8, Uint8]
        try:
            for validator in cparam.validator:  # pyright: ignore
                validator(type_, out)
        except (AssertionError, ValueError, TypeError) as e:
            raise ValidationError(exception_message=e.args[0] if e.args else "", value=out) from e

    return out


def convert(
    type_: Any,
    tokens: Union[Sequence[str], Sequence["Token"], NestedCliArgs],
    converter: Optional[Callable[[type, str], Any]] = None,
    name_transform: Optional[Callable[[str], str]] = None,
):
    """Coerce variables into a specified type.

    Internally used for coercing string CLI tokens into python builtin types.
    Externally, may be useful in a custom converter.
    See Cyclopts' automatic coercion rules :doc:`/rules`.

    If ``type_`` **is not** iterable, then each element of ``tokens`` will be converted independently.
    If there is more than one element, then the return type will be a ``list`` of ``type_``.
    If there is a single element, then the return type will be ``type_``.
    Types that consume several tokens per element (e.g. a class with multiple fields)
    are handled the same way, one group of tokens per element.

    If ``type_`` **is** iterable, then all elements of ``tokens`` will be collated.

    Parameters
    ----------
    type_: Type
        A type hint/annotation to coerce ``tokens`` into.
    tokens: Union[Sequence[str], Sequence[Token], NestedCliArgs]
        Tokens to coerce.
        Generally, either a list of strings (or ``Token``), or a dictionary of list of strings (recursive).
        Each leaf in the dictionary tree should be a list of strings.
    converter: Optional[Callable[[Type, str], Any]]
        An optional function to convert tokens to the inner-most types.
        The converter should have signature:

        .. code-block:: python

            def converter(type_: type, value: str) -> Any:
                "Perform conversion of string token."

        This allows to use the :func:`convert` function to handle the difficult task
        of traversing lists/tuples/unions/etc, while leaving the final conversion logic to
        the caller.
    name_transform: Optional[Callable[[str], str]]
        Currently only used for ``Enum`` type hints.
        A function that transforms enum names and CLI values into a normalized format.

        The function should have signature:

        .. code-block:: python

            def name_transform(s: str) -> str:
                "Perform name transform."

        where the returned value is the name to be used on the CLI.

        If ``None``, defaults to ``cyclopts.default_name_transform``.

    Returns
    -------
    Any
        Coerced version of input ``tokens``.

    Raises
    ------
    ValueError
        If ``tokens`` is empty, or if a dictionary of tokens is combined with a
        non-``dict`` ``type_`` (or vice versa).
    CoercionError
        If a token cannot be converted, or if the number of tokens is not a
        multiple of the number of tokens one ``type_`` element consumes.
    """
    from cyclopts.argument import Token

    if not tokens:
        raise ValueError

    if not isinstance(tokens, dict) and isinstance(tokens[0], str):
        tokens = tuple(Token(value=str(x)) for x in tokens)

    if name_transform is None:
        name_transform = default_name_transform

    convert_priv = partial(_convert, converter=converter, name_transform=name_transform)
    convert_tuple = partial(_convert_tuple, converter=converter, name_transform=name_transform)
    type_ = resolve(type_)

    if type_ is Any:
        type_ = str

    type_ = _implicit_iterable_type_mapping.get(type_, type_)

    origin_type = get_origin(type_)
    maybe_origin_type = origin_type or type_

    if origin_type is tuple:
        return convert_tuple(type_, *tokens)  # pyright: ignore
    elif maybe_origin_type in ITERABLE_TYPES or origin_type is collections.abc.Iterable:
        return convert_priv(type_, tokens)  # pyright: ignore
    elif maybe_origin_type is dict:
        if not isinstance(tokens, dict):
            raise ValueError  # Programming error
        try:
            value_type = get_args(type_)[1]
        except IndexError:
            value_type = str
        dict_converted = {
            k: convert(value_type, v, converter=converter, name_transform=name_transform) for k, v in tokens.items()
        }
        return _converters.get(maybe_origin_type, maybe_origin_type)(**dict_converted)  # pyright: ignore
    elif isinstance(tokens, dict):
        raise ValueError(f"Dictionary of tokens provided for unknown {type_!r}.")  # Programming error
    else:
        if len(tokens) == 1:
            return convert_priv(type_, tokens[0])  # pyright: ignore
        tokens_per_element, _ = token_count(type_)
        if tokens_per_element <= 1:
            # One token per element. ``bool`` reports 0 tokens-per-element (a flag
            # consumes nothing after its keyword), but an explicit value such as
            # "true" is still exactly one token.
            return [convert_priv(type_, item) for item in tokens]  # pyright: ignore
        elif len(tokens) == tokens_per_element:
            return convert_priv(type_, tokens)  # pyright: ignore
        elif len(tokens) % tokens_per_element == 0:
            # Several multi-token elements: convert each consecutive group.
            return [
                convert_priv(type_, tokens[i : i + tokens_per_element])  # pyright: ignore
                for i in range(0, len(tokens), tokens_per_element)
            ]
        else:
            raise CoercionError(
                msg=f"Incorrect number of arguments: expected multiple of {tokens_per_element} but got {len(tokens)}."
            )


def token_count(type_: Any) -> tuple[int, bool]:
    """Infer how many CLI tokens a value of ``type_`` consumes after a keyword.

    Parameters
    ----------
    type_: Type
        A type hint/annotation to infer token_count from if not explicitly specified.

    Returns
    -------
    int
        Number of tokens making up one element of ``type_``.
    bool
        If ``True`` and positional, consume all remaining tokens; the returned
        number of tokens then constitutes a single element of the iterable-to-be-parsed.

    Raises
    ------
    ValueError
        If ``type_`` is a union whose members consume differing numbers of tokens.
    """
    type_ = resolve(type_)
    origin_type = get_origin(type_)
    base_type = origin_type or type_
    args = get_args(type_)

    if base_type is tuple:
        if not args:
            # Bare ``tuple``: one token per element, consume all.
            return 1, True
        # Members' counts add up; an ``...`` argument makes the tuple variadic.
        return sum(token_count(arg)[0] for arg in args if arg is not ...), ... in args

    if base_type is bool:
        # Flags take no value after the keyword.
        return 0, False

    if type_ in ITERABLE_TYPES or (origin_type in ITERABLE_TYPES and not args):
        # Unparameterized iterable: one token per element, consume all.
        return 1, True

    if (origin_type in ITERABLE_TYPES or origin_type is collections.abc.Iterable) and args:
        # Parameterized iterable: the element size comes from the inner type.
        return token_count(args[0])[0], True

    if is_union(type_):
        reference = token_count(args[0])
        for member in args[1:]:
            if token_count(member) != reference:
                raise ValueError(
                    f"Cannot Union types that consume different numbers of tokens: {args[0]} {member}"
                )
        return reference

    if is_builtin(type_):
        # Many builtins actually take in VAR_POSITIONAL when we really just want 1 argument.
        return 1, False

    # This is usually/always a custom user-defined class. Sum the counts of its
    # required fields; a ``*args`` field is always counted and makes the type variadic.
    total = 0
    variadic = False
    for info in get_field_infos(type_).values():
        if info.kind is info.VAR_POSITIONAL:
            variadic = True
        elif not info.required:
            continue
        field_tokens, field_variadic = token_count(info.hint)
        total += field_tokens
        variadic |= field_variadic

    # Classes like ``Enum`` can slip through here with a 0 count.
    if total == 0:
        return 1, False
    return total, variadic
