from __future__ import annotations

from collections.abc import Callable, Iterable, Sequence
from typing import TYPE_CHECKING, Any

import numpy as np

from xarray.core.indexing import ImplicitToExplicitIndexingAdapter
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint, T_ChunkedArray
from xarray.namedarray.utils import is_duck_dask_array, module_available

if TYPE_CHECKING:
    from xarray.namedarray._typing import (
        T_Chunks,
        _DType_co,
        _NormalizedChunks,
        duckarray,
    )

    try:
        from dask.array import Array as DaskArray
    except ImportError:
        DaskArray = np.ndarray[Any, Any]


dask_available = module_available("dask")


class DaskManager(ChunkManagerEntrypoint["DaskArray"]):
    array_cls: type[DaskArray]
    available: bool = dask_available

    def __init__(self) -> None:
        # TODO can we replace this with a class attribute instead?

        from dask.array import Array

        self.array_cls = Array

    def is_chunked_array(self, data: duckarray[Any, Any]) -> bool:
        return is_duck_dask_array(data)

    def chunks(self, data: Any) -> _NormalizedChunks:
        return data.chunks  # type: ignore[no-any-return]

    def normalize_chunks(
        self,
        chunks: T_Chunks | _NormalizedChunks,
        shape: tuple[int, ...] | None = None,
        limit: int | None = None,
        dtype: _DType_co | None = None,
        previous_chunks: _NormalizedChunks | None = None,
    ) -> Any:
        """Called by open_dataset"""
        from dask.array.core import normalize_chunks

        return normalize_chunks(
            chunks,
            shape=shape,
            limit=limit,
            dtype=dtype,
            previous_chunks=previous_chunks,
        )  # type: ignore[no-untyped-call]

    def from_array(
        self, data: Any, chunks: T_Chunks | _NormalizedChunks, **kwargs: Any
    ) -> DaskArray | Any:
        import dask.array as da

        if isinstance(data, ImplicitToExplicitIndexingAdapter):
            # lazily loaded backend array classes should use NumPy array operations.
            kwargs["meta"] = np.ndarray

        return da.from_array(
            data,
            chunks,
            **kwargs,
        )  # type: ignore[no-untyped-call]

    def compute(
        self, *data: Any, **kwargs: Any
    ) -> tuple[np.ndarray[Any, _DType_co], ...]:
        from dask.array import compute

        return compute(*data, **kwargs)  # type: ignore[no-untyped-call, no-any-return]

    def persist(self, *data: Any, **kwargs: Any) -> tuple[DaskArray | Any, ...]:
        from dask import persist

        return persist(*data, **kwargs)  # type: ignore[no-untyped-call, no-any-return]

    @property
    def array_api(self) -> Any:
        from dask import array as da

        return da

    def reduction(
        self,
        arr: T_ChunkedArray,
        func: Callable[..., Any],
        combine_func: Callable[..., Any] | None = None,
        aggregate_func: Callable[..., Any] | None = None,
        axis: int | Sequence[int] | None = None,
        dtype: _DType_co | None = None,
        keepdims: bool = False,
    ) -> DaskArray | Any:
        from dask.array import reduction

        return reduction(
            arr,
            chunk=func,
            combine=combine_func,
            aggregate=aggregate_func,
            axis=axis,
            dtype=dtype,
            keepdims=keepdims,
        )  # type: ignore[no-untyped-call]

    def scan(
        self,
        func: Callable[..., Any],
        binop: Callable[..., Any],
        ident: float,
        arr: T_ChunkedArray,
        axis: int | None = None,
        dtype: _DType_co | None = None,
        **kwargs: Any,
    ) -> DaskArray | Any:
        from dask.array.reductions import cumreduction

        return cumreduction(
            func,
            binop,
            ident,
            arr,
            axis=axis,
            dtype=dtype,
            **kwargs,
        )  # type: ignore[no-untyped-call]

    def apply_gufunc(
        self,
        func: Callable[..., Any],
        signature: str,
        *args: Any,
        axes: Sequence[tuple[int, ...]] | None = None,
        axis: int | None = None,
        keepdims: bool = False,
        output_dtypes: Sequence[_DType_co] | None = None,
        output_sizes: dict[str, int] | None = None,
        vectorize: bool | None = None,
        allow_rechunk: bool = False,
        meta: tuple[np.ndarray[Any, _DType_co], ...] | None = None,
        **kwargs: Any,
    ) -> Any:
        from dask.array.gufunc import apply_gufunc

        return apply_gufunc(
            func,
            signature,
            *args,
            axes=axes,
            axis=axis,
            keepdims=keepdims,
            output_dtypes=output_dtypes,
            output_sizes=output_sizes,
            vectorize=vectorize,
            allow_rechunk=allow_rechunk,
            meta=meta,
            **kwargs,
        )  # type: ignore[no-untyped-call]

    def map_blocks(
        self,
        func: Callable[..., Any],
        *args: Any,
        dtype: _DType_co | None = None,
        chunks: tuple[int, ...] | None = None,
        drop_axis: int | Sequence[int] | None = None,
        new_axis: int | Sequence[int] | None = None,
        **kwargs: Any,
    ) -> Any:
        from dask.array import map_blocks

        # pass through name, meta, token as kwargs
        return map_blocks(
            func,
            *args,
            dtype=dtype,
            chunks=chunks,
            drop_axis=drop_axis,
            new_axis=new_axis,
            **kwargs,
        )  # type: ignore[no-untyped-call]

    def blockwise(
        self,
        func: Callable[..., Any],
        out_ind: Iterable[Any],
        *args: Any,
        # can't type this as mypy assumes args are all same type, but dask blockwise args alternate types
        name: str | None = None,
        token: Any | None = None,
        dtype: _DType_co | None = None,
        adjust_chunks: dict[Any, Callable[..., Any]] | None = None,
        new_axes: dict[Any, int] | None = None,
        align_arrays: bool = True,
        concatenate: bool | None = None,
        meta: tuple[np.ndarray[Any, _DType_co], ...] | None = None,
        **kwargs: Any,
    ) -> DaskArray | Any:
        from dask.array import blockwise

        return blockwise(
            func,
            out_ind,
            *args,
            name=name,
            token=token,
            dtype=dtype,
            adjust_chunks=adjust_chunks,
            new_axes=new_axes,
            align_arrays=align_arrays,
            concatenate=concatenate,
            meta=meta,
            **kwargs,
        )  # type: ignore[no-untyped-call]

    def unify_chunks(
        self,
        *args: Any,  # can't type this as mypy assumes args are all same type, but dask unify_chunks args alternate types
        **kwargs: Any,
    ) -> tuple[dict[str, _NormalizedChunks], list[DaskArray]]:
        from dask.array.core import unify_chunks

        return unify_chunks(*args, **kwargs)  # type: ignore[no-any-return, no-untyped-call]

    def store(
        self,
        sources: Any | Sequence[Any],
        targets: Any,
        **kwargs: Any,
    ) -> Any:
        from dask.array import store

        return store(
            sources=sources,
            targets=targets,
            **kwargs,
        )

    def shuffle(
        self, x: DaskArray, indexer: list[list[int]], axis: int, chunks: T_Chunks
    ) -> DaskArray:
        import dask.array

        if not module_available("dask", minversion="2024.08.1"):
            raise ValueError(
                "This method is very inefficient on dask<2024.08.1. Please upgrade."
            )
        if chunks is None:
            chunks = "auto"
        if chunks != "auto":
            raise NotImplementedError("Only chunks='auto' is supported at present.")
        return dask.array.shuffle(x, indexer, axis, chunks="auto")
