1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
|
import inspect
from xarray import Variable
from xarray.core.indexes import Index, PandasIndex
from pint_xarray import conversion
class PintIndex(Index):
    """Unit-aware xarray index that wraps another index.

    Lookups are delegated to the wrapped index, which works on unit-less
    values. This wrapper attaches units to the coordinate variables it
    creates, and converts/strips units from the labels passed to ``sel``.
    """

    def __init__(self, *, index, units):
        """Create a unit-aware index.

        Parameters
        ----------
        index : xarray.Index
            The wrapped index object.
        units : dict of hashable to unit-like
            The units of the indexed coordinates, keyed by coordinate name.

        Raises
        ------
        TypeError
            If ``units`` is not a ``dict``.
        """
        if not isinstance(units, dict):
            raise TypeError(
                "Index units have to be a dict of coordinate names to units."
            )

        self.index = index
        self.units = units

    def _replace(self, new_index, units=None):
        """Return a new PintIndex wrapping ``new_index``.

        ``units`` defaults to this index's units. Pass a new dict when the
        coordinate names change (as ``rename`` does).
        """
        if units is None:
            units = self.units
        return self.__class__(index=new_index, units=units)

    def create_variables(self, variables=None):
        """Create the wrapped index's coordinate variables with units attached.

        Each variable produced by the wrapped index has its data passed
        through ``conversion.array_attach_units`` with the unit recorded for
        that coordinate name. Dims, attrs and encoding are carried over.
        """
        index_vars = self.index.create_variables(variables)

        index_vars_units = {}
        for name, var in index_vars.items():
            data = conversion.array_attach_units(var.data, self.units[name])
            index_vars_units[name] = Variable(
                var.dims, data, attrs=var.attrs, encoding=var.encoding
            )

        return index_vars_units

    @classmethod
    def from_variables(cls, variables, options):
        """Build a PintIndex around a PandasIndex for a single coordinate.

        The optional ``"units"`` entry of ``options`` sets the coordinate's
        units (``None`` if absent). All other options are forwarded to
        ``PandasIndex.from_variables``. The caller's ``options`` mapping is
        not modified.

        Raises
        ------
        ValueError
            If ``variables`` does not contain exactly one variable.
        """
        if len(variables) != 1:
            raise ValueError("can only create a default index from single variables")

        # Copy so that popping "units" does not mutate the caller's mapping.
        options = dict(options) if options is not None else {}
        units = options.pop("units", None)

        index = PandasIndex.from_variables(variables, options=options)
        return cls(index=index, units={index.index.name: units})

    @classmethod
    def concat(cls, indexes, dim, positions):
        raise NotImplementedError()

    @classmethod
    def stack(cls, variables, dim):
        raise NotImplementedError()

    def unstack(self):
        raise NotImplementedError()

    def sel(self, labels, **options):
        """Select using unit-aware labels.

        Labels are converted to this index's units, stripped of units, and
        then forwarded to the wrapped index's ``sel``.
        """
        converted_labels = conversion.convert_indexer_units(labels, self.units)
        stripped_labels = conversion.strip_indexer_units(converted_labels)
        return self.index.sel(stripped_labels, **options)

    def isel(self, indexers):
        """Positionally subset the wrapped index, keeping the units.

        Returns ``None`` when the wrapped index's ``isel`` returns ``None``.
        """
        subset = self.index.isel(indexers)
        if subset is None:
            return None

        return self._replace(subset)

    def join(self, other, how="inner"):
        raise NotImplementedError()

    def reindex_like(self, other):
        raise NotImplementedError()

    def equals(self, other, *, exclude=None):
        """Return True if ``other`` is a PintIndex with identical units and
        an equal wrapped index."""
        if not isinstance(other, PintIndex):
            return False

        # for now we require exactly matching units to avoid the potentially
        # expensive conversion
        if self.units != other.units:
            return False

        # TODO:
        # - remove try-except once we can drop xarray<2025.06.0
        # - remove compat once we can require a version of xarray that completed
        #   the deprecation cycle
        try:
            from xarray.core.indexes import _wrap_index_equals
        except ImportError:  # pragma: no cover
            equals = self.index.equals
            signature = inspect.signature(equals)
            # older wrapped indexes may not accept ``exclude``
            if "exclude" in signature.parameters:
                kwargs = {"exclude": exclude}
            else:
                kwargs = {}
        else:
            equals = _wrap_index_equals(self.index)
            kwargs = {"exclude": exclude}

        # Last to avoid the potentially expensive comparison
        return equals(other.index, **kwargs)

    def roll(self, shifts):
        return self._replace(self.index.roll(shifts))

    def rename(self, name_dict, dims_dict):
        """Rename the wrapped index and keep the units attached to the new names.

        ``self.units`` is keyed by coordinate name, so its keys are remapped
        through ``name_dict`` as well. Otherwise ``create_variables`` would
        look up the units under the new name and raise ``KeyError``.
        """
        units = {name_dict.get(name, name): unit for name, unit in self.units.items()}
        return self._replace(self.index.rename(name_dict, dims_dict), units=units)

    def __getitem__(self, indexer):
        return self._replace(self.index[indexer])

    @staticmethod
    def _format_unit(unit):
        """Format a single unit for display.

        Uses pint's ``~P`` format spec when the value supports it. ``None``
        (what ``from_variables`` stores when no units are given) and values
        that reject that spec (e.g. plain strings) fall back to ``str``
        instead of raising.
        """
        if unit is None:
            return "None"
        try:
            return f"{unit:~P}"
        except (TypeError, ValueError):
            return str(unit)

    def _formatted_units(self):
        """Return ``self.units`` with every unit formatted as a display string."""
        return {name: self._format_unit(unit) for name, unit in self.units.items()}

    def _repr_inline_(self, max_width):
        name = self.__class__.__name__
        wrapped_name = self.index.__class__.__name__
        return f"{name}({wrapped_name}, units={self._formatted_units()})"

    def __repr__(self):
        summary = f"<{self.__class__.__name__} (units={self._formatted_units()})>"
        return "\n".join([summary, repr(self.index)])
|