"""Functions for converting to and from xarray objects
"""
from collections import Counter

import numpy as np
import pandas as pd

from .coding.times import CFDatetimeCoder, CFTimedeltaCoder
from .conventions import decode_cf
from .core import duck_array_ops
from .core.dataarray import DataArray
from .core.dtypes import get_fill_value
from .core.pycompat import dask_array_type

cdms2_ignored_attrs = {"name", "tileIndex"}
iris_forbidden_keys = {
    "standard_name",
    "long_name",
    "units",
    "bounds",
    "axis",
    "calendar",
    "leap_month",
    "leap_year",
    "month_lengths",
    "coordinates",
    "grid_mapping",
    "climatology",
    "cell_methods",
    "formula_terms",
    "compress",
    "missing_value",
    "add_offset",
    "scale_factor",
    "valid_max",
    "valid_min",
    "valid_range",
    "_FillValue",
}
cell_methods_strings = {
    "point",
    "sum",
    "maximum",
    "median",
    "mid_range",
    "minimum",
    "mean",
    "mode",
    "standard_deviation",
    "variance",
}


def encode(var):
    return CFTimedeltaCoder().encode(CFDatetimeCoder().encode(var.variable))
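# Note (illustrative, not executed): `encode` converts datetime64/timedelta64
# data back to the numeric CF representation that cdms2 and Iris expect, e.g. a
# coordinate of np.datetime64 values becomes numbers with "units" such as
# "days since 2000-01-01" plus a "calendar" attribute; the reference date used
# in the units string is chosen by the coder.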


def _filter_attrs(attrs, ignored_attrs):
    """Return attrs that are not in ignored_attrs"""
    return {k: v for k, v in attrs.items() if k not in ignored_attrs}


def from_cdms2(variable):
    """Convert a cdms2 variable into a DataArray"""
    values = np.asarray(variable)
    name = variable.id
    dims = variable.getAxisIds()
    coords = {}
    for axis in variable.getAxisList():
        coords[axis.id] = DataArray(
            np.asarray(axis),
            dims=[axis.id],
            attrs=_filter_attrs(axis.attributes, cdms2_ignored_attrs),
        )

    grid = variable.getGrid()
    if grid is not None:
        ids = [a.id for a in grid.getAxisList()]
        for axis in grid.getLongitude(), grid.getLatitude():
            if axis.id not in variable.getAxisIds():
                coords[axis.id] = DataArray(
                    np.asarray(axis[:]),
                    dims=ids,
                    attrs=_filter_attrs(axis.attributes, cdms2_ignored_attrs),
                )

    attrs = _filter_attrs(variable.attributes, cdms2_ignored_attrs)
    dataarray = DataArray(values, dims=dims, coords=coords, name=name, attrs=attrs)
    return decode_cf(dataarray.to_dataset())[dataarray.name]
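# Illustrative usage (a sketch, not executed here; assumes cdms2 is installed
# and "sample.nc" / "tas" are hypothetical file and variable names):
#
#     import cdms2
#     f = cdms2.open("sample.nc")
#     da = from_cdms2(f("tas"))   # xarray.DataArray with CF-decoded coordinates
#     f.close()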


def to_cdms2(dataarray, copy=True):
    """Convert a DataArray into a cdms2 variable"""
    # we don't want cdms2 to be a hard dependency
    import cdms2

    def set_cdms2_attrs(var, attrs):
        for k, v in attrs.items():
            setattr(var, k, v)

    # 1D axes
    axes = []
    for dim in dataarray.dims:
        coord = encode(dataarray.coords[dim])
        axis = cdms2.createAxis(coord.values, id=dim)
        set_cdms2_attrs(axis, coord.attrs)
        axes.append(axis)

    # Data
    var = encode(dataarray)
    cdms2_var = cdms2.createVariable(
        var.values, axes=axes, id=dataarray.name, mask=pd.isnull(var.values), copy=copy
    )

    # Attributes
    set_cdms2_attrs(cdms2_var, var.attrs)

    # Curvilinear and unstructured grids
    if dataarray.name not in dataarray.coords:
        cdms2_axes = {}
        for coord_name in set(dataarray.coords.keys()) - set(dataarray.dims):
            coord_array = dataarray.coords[coord_name].to_cdms2()

            cdms2_axis_cls = (
                cdms2.coord.TransientAxis2D
                if coord_array.ndim
                else cdms2.auxcoord.TransientAuxAxis1D
            )
            cdms2_axis = cdms2_axis_cls(coord_array)
            if cdms2_axis.isLongitude():
                cdms2_axes["lon"] = cdms2_axis
            elif cdms2_axis.isLatitude():
                cdms2_axes["lat"] = cdms2_axis

        if "lon" in cdms2_axes and "lat" in cdms2_axes:
            if len(cdms2_axes["lon"].shape) == 2:
                cdms2_grid = cdms2.hgrid.TransientCurveGrid(
                    cdms2_axes["lat"], cdms2_axes["lon"]
                )
            else:
                cdms2_grid = cdms2.gengrid.AbstractGenericGrid(
                    cdms2_axes["lat"], cdms2_axes["lon"]
                )
            for axis in cdms2_grid.getAxisList():
                cdms2_var.setAxis(cdms2_var.getAxisIds().index(axis.id), axis)
            cdms2_var.setGrid(cdms2_grid)

    return cdms2_var
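# Illustrative round trip (a sketch; `da` is a hypothetical DataArray with 1D
# dimension coordinates, so only the "1D axes" branch above is exercised):
#
#     tvar = to_cdms2(da)      # cdms2 TransientVariable with matching axes
#     da2 = from_cdms2(tvar)   # back to an xarray.DataArray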


def _pick_attrs(attrs, keys):
    """Return attrs with keys in keys list"""
    return {k: v for k, v in attrs.items() if k in keys}


def _get_iris_args(attrs):
    """Converts the xarray attrs into args that can be passed into Iris"""
    # iris.unit is deprecated in Iris v1.9
    import cf_units

    args = {"attributes": _filter_attrs(attrs, iris_forbidden_keys)}
    args.update(_pick_attrs(attrs, ("standard_name", "long_name")))
    unit_args = _pick_attrs(attrs, ("calendar",))
    if "units" in attrs:
        args["units"] = cf_units.Unit(attrs["units"], **unit_args)
    return args
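# For illustration (values are made up): attrs such as
#     {"units": "K", "standard_name": "air_temperature", "source": "model"}
# become roughly
#     {"attributes": {"source": "model"},
#      "standard_name": "air_temperature",
#      "units": cf_units.Unit("K")}
# i.e. CF-reserved keys move to dedicated Iris constructor arguments while the
# rest stays in the generic "attributes" dict.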


# TODO: Add converting bounds from xarray to Iris and back
def to_iris(dataarray):
    """Convert a DataArray into an Iris Cube"""
    # Iris not a hard dependency
    import iris
    from iris.fileformats.netcdf import parse_cell_methods

    dim_coords = []
    aux_coords = []

    for coord_name in dataarray.coords:
        coord = encode(dataarray.coords[coord_name])
        coord_args = _get_iris_args(coord.attrs)
        coord_args["var_name"] = coord_name
        axis = None
        if coord.dims:
            axis = dataarray.get_axis_num(coord.dims)
        if coord_name in dataarray.dims:
            try:
                iris_coord = iris.coords.DimCoord(coord.values, **coord_args)
                dim_coords.append((iris_coord, axis))
            except ValueError:
                iris_coord = iris.coords.AuxCoord(coord.values, **coord_args)
                aux_coords.append((iris_coord, axis))
        else:
            iris_coord = iris.coords.AuxCoord(coord.values, **coord_args)
            aux_coords.append((iris_coord, axis))

    args = _get_iris_args(dataarray.attrs)
    args["var_name"] = dataarray.name
    args["dim_coords_and_dims"] = dim_coords
    args["aux_coords_and_dims"] = aux_coords
    if "cell_methods" in dataarray.attrs:
        args["cell_methods"] = parse_cell_methods(dataarray.attrs["cell_methods"])

    masked_data = duck_array_ops.masked_invalid(dataarray.data)
    cube = iris.cube.Cube(masked_data, **args)

    return cube
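# Illustrative usage (a sketch; `da` is a hypothetical DataArray whose attrs and
# coordinate attrs follow CF conventions):
#
#     cube = to_iris(da)                   # iris.cube.Cube
#     print(cube.summary(shorten=True))    # dimension coords become DimCoords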


def _iris_obj_to_attrs(obj):
    """Return a dictionary of attrs when given an Iris object"""
    attrs = {"standard_name": obj.standard_name, "long_name": obj.long_name}
    if obj.units.calendar:
        attrs["calendar"] = obj.units.calendar
    if obj.units.origin != "1" and not obj.units.is_unknown():
        attrs["units"] = obj.units.origin
    attrs.update(obj.attributes)
    return {k: v for k, v in attrs.items() if v is not None}


def _iris_cell_methods_to_str(cell_methods_obj):
    """Converts Iris cell methods into a string"""
    cell_methods = []
    for cell_method in cell_methods_obj:
        names = "".join(f"{n}: " for n in cell_method.coord_names)
        intervals = " ".join(
            f"interval: {interval}" for interval in cell_method.intervals
        )
        comments = " ".join(f"comment: {comment}" for comment in cell_method.comments)
        extra = " ".join([intervals, comments]).strip()
        if extra:
            extra = f" ({extra})"
        cell_methods.append(names + cell_method.method + extra)
    return " ".join(cell_methods)


def _name(iris_obj, default="unknown"):
    """Mimics `iris_obj.name()` but with different name resolution order.

    Similar to iris_obj.name() method, but using iris_obj.var_name first to
    enable roundtripping.
    """
    return iris_obj.var_name or iris_obj.standard_name or iris_obj.long_name or default
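# For example, a coordinate with var_name="t" and standard_name="time" resolves
# to "t" here, whereas Iris's own .name() would prefer "time".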


def from_iris(cube):
    """Convert an Iris cube into a DataArray"""
    import iris.exceptions

    name = _name(cube)
    if name == "unknown":
        name = None
    dims = []
    for i in range(cube.ndim):
        try:
            dim_coord = cube.coord(dim_coords=True, dimensions=(i,))
            dims.append(_name(dim_coord))
        except iris.exceptions.CoordinateNotFoundError:
            dims.append(f"dim_{i}")

    if len(set(dims)) != len(dims):
        duplicates = [k for k, v in Counter(dims).items() if v > 1]
        raise ValueError(f"Duplicate coordinate name {duplicates}.")

    coords = {}

    for coord in cube.coords():
        coord_attrs = _iris_obj_to_attrs(coord)
        coord_dims = [dims[i] for i in cube.coord_dims(coord)]
        if coord_dims:
            coords[_name(coord)] = (coord_dims, coord.points, coord_attrs)
        else:
            coords[_name(coord)] = ((), coord.points.item(), coord_attrs)

    array_attrs = _iris_obj_to_attrs(cube)
    cell_methods = _iris_cell_methods_to_str(cube.cell_methods)
    if cell_methods:
        array_attrs["cell_methods"] = cell_methods

    # Deal with iris 1.* and 2.*
    cube_data = cube.core_data() if hasattr(cube, "core_data") else cube.data

    # Deal with dask and numpy masked arrays
    if isinstance(cube_data, dask_array_type):
        from dask.array import ma as dask_ma

        filled_data = dask_ma.filled(cube_data, get_fill_value(cube.dtype))
    elif isinstance(cube_data, np.ma.MaskedArray):
        filled_data = np.ma.filled(cube_data, get_fill_value(cube.dtype))
    else:
        filled_data = cube_data

    dataarray = DataArray(
        filled_data, coords=coords, name=name, attrs=array_attrs, dims=dims
    )
    decoded_ds = decode_cf(dataarray._to_temp_dataset())
    return dataarray._from_temp_dataset(decoded_ds)
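# Illustrative round trip (a sketch; "sample.nc" is a hypothetical file name and
# masked points become NaN on the xarray side):
#
#     import iris
#     cube = iris.load_cube("sample.nc")
#     da = from_iris(cube)      # xarray.DataArray
#     cube2 = to_iris(da)       # back to an iris.cube.Cube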