from __future__ import annotations

import typing
from collections.abc import Hashable, Mapping
from typing import Any, Generic

import pandas as pd

from xarray.core import utils
from xarray.core.common import _contains_datetime_like_objects
from xarray.core.indexing import map_index_queries
from xarray.core.types import T_Dataset
from xarray.core.variable import IndexVariable, Variable

if typing.TYPE_CHECKING:
    from xarray.core.dataset import Dataset


class _LocIndexer(Generic[T_Dataset]):
    __slots__ = ("dataset",)

    def __init__(self, dataset: T_Dataset):
        self.dataset = dataset

    def __getitem__(self, key: Mapping[Any, Any]) -> T_Dataset:
        if not utils.is_dict_like(key):
            raise TypeError("can only lookup dictionaries from Dataset.loc")
        return self.dataset.sel(key)

    def __setitem__(self, key, value) -> None:
        if not utils.is_dict_like(key):
            raise TypeError(
                "can only set locations defined by dictionaries from Dataset.loc."
                f" Got: {key}"
            )

        # set new values
        dim_indexers = map_index_queries(self.dataset, key).dim_indexers
        self.dataset[dim_indexers] = value


def as_dataset(obj: Any) -> Dataset:
    """Cast the given object to a Dataset.

    Handles Datasets, DataArrays and dictionaries of variables. A new Dataset
    object is only created if the provided object is not already one.
    """
    from xarray.core.dataset import Dataset

    if hasattr(obj, "to_dataset"):
        obj = obj.to_dataset()
    if not isinstance(obj, Dataset):
        obj = Dataset(obj)
    return obj


def _get_virtual_variable(
    variables, key: Hashable, dim_sizes: Mapping | None = None
) -> tuple[Hashable, Hashable, Variable]:
    """Get a virtual variable (e.g., 'time.year') from a dict of xarray.Variable
    objects (if possible)

    """
    from xarray.core.dataarray import DataArray

    if dim_sizes is None:
        dim_sizes = {}

    if key in dim_sizes:
        data = pd.Index(range(dim_sizes[key]), name=key)
        variable = IndexVariable((key,), data)
        return key, key, variable

    if not isinstance(key, str):
        raise KeyError(key)

    split_key = key.split(".", 1)
    if len(split_key) != 2:
        raise KeyError(key)

    ref_name, var_name = split_key
    ref_var = variables[ref_name]

    if _contains_datetime_like_objects(ref_var):
        ref_var = DataArray(ref_var)
        data = getattr(ref_var.dt, var_name).data
    else:
        data = getattr(ref_var, var_name).data
    virtual_var = Variable(ref_var.dims, data)

    return ref_name, var_name, virtual_var
