1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
|
from __future__ import annotations
import typing
from collections.abc import Hashable, Mapping
from typing import Any, Generic
import pandas as pd
from xarray.core import utils
from xarray.core.common import _contains_datetime_like_objects
from xarray.core.indexing import map_index_queries
from xarray.core.types import T_Dataset
from xarray.core.variable import IndexVariable, Variable
if typing.TYPE_CHECKING:
from xarray.core.dataset import Dataset
class _LocIndexer(Generic[T_Dataset]):
__slots__ = ("dataset",)
def __init__(self, dataset: T_Dataset):
self.dataset = dataset
def __getitem__(self, key: Mapping[Any, Any]) -> T_Dataset:
if not utils.is_dict_like(key):
raise TypeError("can only lookup dictionaries from Dataset.loc")
return self.dataset.sel(key)
def __setitem__(self, key, value) -> None:
if not utils.is_dict_like(key):
raise TypeError(
"can only set locations defined by dictionaries from Dataset.loc."
f" Got: {key}"
)
# set new values
dim_indexers = map_index_queries(self.dataset, key).dim_indexers
self.dataset[dim_indexers] = value
def as_dataset(obj: Any) -> Dataset:
"""Cast the given object to a Dataset.
Handles Datasets, DataArrays and dictionaries of variables. A new Dataset
object is only created if the provided object is not already one.
"""
from xarray.core.dataset import Dataset
if hasattr(obj, "to_dataset"):
obj = obj.to_dataset()
if not isinstance(obj, Dataset):
obj = Dataset(obj)
return obj
def _get_virtual_variable(
variables, key: Hashable, dim_sizes: Mapping | None = None
) -> tuple[Hashable, Hashable, Variable]:
"""Get a virtual variable (e.g., 'time.year') from a dict of xarray.Variable
objects (if possible)
"""
from xarray.core.dataarray import DataArray
if dim_sizes is None:
dim_sizes = {}
if key in dim_sizes:
data = pd.Index(range(dim_sizes[key]), name=key)
variable = IndexVariable((key,), data)
return key, key, variable
if not isinstance(key, str):
raise KeyError(key)
split_key = key.split(".", 1)
if len(split_key) != 2:
raise KeyError(key)
ref_name, var_name = split_key
ref_var = variables[ref_name]
if _contains_datetime_like_objects(ref_var):
ref_var = DataArray(ref_var)
data = getattr(ref_var.dt, var_name).data
else:
data = getattr(ref_var, var_name).data
virtual_var = Variable(ref_var.dims, data)
return ref_name, var_name, virtual_var
|