1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
|
from __future__ import annotations
import warnings
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
class AccessorRegistrationWarning(Warning):
    """Raised as a warning when a newly registered accessor shadows an existing attribute."""
class _CachedAccessor:
"""Custom property-like object (descriptor) for caching accessors."""
def __init__(self, name, accessor):
self._name = name
self._accessor = accessor
def __get__(self, obj, cls):
if obj is None:
# we're accessing the attribute of the class, i.e., Dataset.geo
return self._accessor
# Use the same dict as @pandas.util.cache_readonly.
# It must be explicitly declared in obj.__slots__.
try:
cache = obj._cache
except AttributeError:
cache = obj._cache = {}
try:
return cache[self._name]
except KeyError:
pass
try:
accessor_obj = self._accessor(obj)
except AttributeError as err:
# __getattr__ on data object will swallow any AttributeErrors
# raised when initializing the accessor, so we need to raise as
# something else (GH933):
raise RuntimeError(f"error initializing {self._name!r} accessor.") from err
cache[self._name] = accessor_obj
return accessor_obj
def _register_accessor(name, cls):
    """Create a class decorator that installs an accessor on ``cls``.

    Parameters
    ----------
    name : str
        Attribute name under which the accessor will be reachable.
    cls : type
        Class that receives the accessor attribute.

    Returns
    -------
    callable
        Decorator that sets ``cls.<name>`` to a ``_CachedAccessor`` wrapping the
        decorated class, warns if that attribute already existed, and returns
        the decorated class unchanged.
    """

    def decorator(accessor):
        already_defined = hasattr(cls, name)
        if already_defined:
            message = (
                f"registration of accessor {accessor!r} under name {name!r} for type {cls!r} is "
                "overriding a preexisting attribute with the same name."
            )
            # stacklevel=2 points the warning at the code applying the decorator.
            warnings.warn(message, AccessorRegistrationWarning, stacklevel=2)
        descriptor = _CachedAccessor(name, accessor)
        setattr(cls, name, descriptor)
        return accessor

    return decorator
def register_dataarray_accessor(name):
    """Register a custom accessor on xarray.DataArray objects.

    Parameters
    ----------
    name : str
        Name under which the accessor should be registered. An
        ``AccessorRegistrationWarning`` is issued if this name conflicts with
        a preexisting attribute.

    Returns
    -------
    callable
        Class decorator that attaches the decorated class as an accessor and
        returns the class unchanged.

    See Also
    --------
    register_dataset_accessor
    """
    target_cls = DataArray
    return _register_accessor(name, target_cls)
def register_dataset_accessor(name):
    """Register a custom accessor on xarray.Dataset objects.

    Parameters
    ----------
    name : str
        Name under which the accessor should be registered. An
        ``AccessorRegistrationWarning`` is issued if this name conflicts with
        a preexisting attribute.

    Returns
    -------
    callable
        Class decorator that attaches the decorated class as an accessor and
        returns the class unchanged.

    Examples
    --------
    In your library code:

    >>> @xr.register_dataset_accessor("geo")
    ... class GeoAccessor:
    ...     def __init__(self, xarray_obj):
    ...         self._obj = xarray_obj
    ...
    ...     @property
    ...     def center(self):
    ...         # return the geographic center point (lon, lat) of this dataset
    ...         lon = self._obj.longitude
    ...         lat = self._obj.latitude
    ...         return (float(lon.mean()), float(lat.mean()))
    ...
    ...     def plot(self):
    ...         # plot this dataset's data on a map, e.g., using Cartopy
    ...         pass
    ...

    Back in an interactive IPython session:

    >>> ds = xr.Dataset(
    ...     {"longitude": np.linspace(0, 10), "latitude": np.linspace(0, 20)}
    ... )
    >>> ds.geo.center
    (5.0, 10.0)
    >>> ds.geo.plot()  # plots data on a map

    See Also
    --------
    register_dataarray_accessor
    """
    return _register_accessor(name, Dataset)
def register_datatree_accessor(name):
    """Register a custom accessor on DataTree objects.

    Parameters
    ----------
    name : str
        Name under which the accessor should be registered. An
        ``AccessorRegistrationWarning`` is issued if this name conflicts with
        a preexisting attribute.

    Returns
    -------
    callable
        Class decorator that attaches the decorated class as an accessor and
        returns the class unchanged.

    See Also
    --------
    xarray.register_dataarray_accessor
    xarray.register_dataset_accessor
    """
    target_cls = DataTree
    return _register_accessor(name, target_cls)
|