"""
This module contains a dataset writer for Dask.
Credits:
RasterioWriter dask write functionality was adopted from https://github.com/dymaxionlabs/dask-rasterio # noqa: E501
Source file:
- https://github.com/dymaxionlabs/dask-rasterio/blob/8dd7fdece7ad094a41908c0ae6b4fe6ca49cf5e1/dask_rasterio/write.py # noqa: E501
"""
import numpy
import rasterio
from rasterio.windows import Window
from xarray.conventions import encode_cf_variable
from rioxarray._io import FILL_VALUE_NAMES, UNWANTED_RIO_ATTRS, _get_unsigned_dtype
from rioxarray.exceptions import RioXarrayError
try:
import dask.array
from dask import is_dask_collection
except ImportError:
def is_dask_collection(_) -> bool: # type: ignore
"""
        Fallback used when dask cannot be imported.
"""
# if you cannot import dask, then it cannot be a dask array
return False
# Note: transform & crs are removed in write_transform/write_crs
def _write_tags(*, raster_handle, tags):
"""
Write tags to raster dataset
"""
# filter out attributes that should be written in a different location
skip_tags = (
UNWANTED_RIO_ATTRS
+ FILL_VALUE_NAMES
+ (
"crs",
"transform",
"scales",
"scale_factor",
"add_offset",
"offsets",
"grid_mapping",
)
)
    # if 'long_name' is not a single string (e.g. one name per band),
    # skip it here; it is written to the band descriptions instead
if not isinstance(tags.get("long_name"), str):
skip_tags += ("long_name",)
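    # "band_tags" is expected to be a list of per-band tag dicts (illustrative
    # example): [{"units": "m"}, {"units": "ft"}] writes tags to bands 1 and 2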
band_tags = tags.pop("band_tags", [])
tags = {key: value for key, value in tags.items() if key not in skip_tags}
raster_handle.update_tags(**tags)
if isinstance(band_tags, list):
for iii, band_tag in enumerate(band_tags):
raster_handle.update_tags(iii + 1, **band_tag)
def _write_band_description(*, raster_handle, xarray_dataset):
"""
Write band descriptions using the long name
"""
long_name = xarray_dataset.attrs.get("long_name")
if isinstance(long_name, (tuple, list)):
if len(long_name) != raster_handle.count:
raise RioXarrayError(
"Number of names in the 'long_name' attribute does not equal "
"the number of bands."
)
for iii, band_description in enumerate(long_name):
raster_handle.set_band_description(iii + 1, band_description)
else:
band_description = long_name or xarray_dataset.name
if band_description:
for iii in range(raster_handle.count):
raster_handle.set_band_description(iii + 1, band_description)
def _write_metadata_to_raster(*, raster_handle, xarray_dataset, tags):
"""
Write the metadata stored in the xarray object to raster metadata
"""
tags = (
xarray_dataset.attrs.copy()
if tags is None
else {**xarray_dataset.attrs, **tags}
)
# write scales and offsets
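    # e.g. a scalar scale_factor of 0.1 on a 3-band raster is expanded to
    # scales = (0.1, 0.1, 0.1), one value per band (illustrative values)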
scales = tags.get("scales", xarray_dataset.encoding.get("scales"))
if scales is None:
scale_factor = tags.get(
"scale_factor", xarray_dataset.encoding.get("scale_factor")
)
if scale_factor is not None:
scales = (scale_factor,) * raster_handle.count
if scales is not None:
raster_handle.scales = scales
offsets = tags.get("offsets", xarray_dataset.encoding.get("offsets"))
if offsets is None:
add_offset = tags.get("add_offset", xarray_dataset.encoding.get("add_offset"))
if add_offset is not None:
offsets = (add_offset,) * raster_handle.count
if offsets is not None:
raster_handle.offsets = offsets
_write_tags(raster_handle=raster_handle, tags=tags)
_write_band_description(raster_handle=raster_handle, xarray_dataset=xarray_dataset)
def _ensure_nodata_dtype(*, original_nodata, new_dtype):
"""
    Convert the nodata value to the new datatype and raise an
    OverflowError if the nodata value cannot be represented in it.
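
    For example (illustrative values): a nodata value of ``255.0`` converted
    to ``uint8`` is returned as ``255``, while ``-9999.0`` converted to
    ``uint8`` raises an ``OverflowError``.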
"""
# Complex-valued rasters can have real-valued nodata
new_dtype = numpy.dtype(new_dtype)
if numpy.issubdtype(new_dtype, numpy.complexfloating):
nodata = original_nodata
else:
original_nodata = (
float(original_nodata)
if not numpy.issubdtype(type(original_nodata), numpy.integer)
else original_nodata
)
failure_message = (
f"Unable to convert nodata value ({original_nodata}) to "
f"new dtype ({new_dtype})."
)
try:
nodata = new_dtype.type(original_nodata)
except OverflowError as error:
raise OverflowError(failure_message) from error
if not numpy.isnan(nodata) and original_nodata != nodata:
raise OverflowError(failure_message)
return nodata
def _get_dtypes(*, rasterio_dtype, encoded_rasterio_dtype, dataarray_dtype):
"""
    Determine the rasterio dtype and numpy dtype to write with, based on the
    requested rasterio dtype, the encoded rasterio dtype, and the data array dtype.
Parameters
----------
rasterio_dtype: Union[str, numpy.dtype]
The rasterio dtype to write to.
encoded_rasterio_dtype: Union[str, numpy.dtype, None]
The value of the original rasterio dtype in the encoding.
dataarray_dtype: Union[str, numpy.dtype]
The value of the dtype of the data array.
Returns
-------
tuple[Union[str, numpy.dtype], Union[str, numpy.dtype]]:
The rasterio dtype and numpy dtype.
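
    Examples
    --------
    Illustrative values only (they follow the scenarios handled below):

    >>> _get_dtypes(
    ...     rasterio_dtype=None,
    ...     encoded_rasterio_dtype=None,
    ...     dataarray_dtype="float32",
    ... )
    ('float32', 'float32')
    >>> _get_dtypes(
    ...     rasterio_dtype="complex_int16",
    ...     encoded_rasterio_dtype=None,
    ...     dataarray_dtype="complex64",
    ... )
    ('complex_int16', 'complex64')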
"""
# SCENARIO 1: User wants to write to complex_int16
if rasterio_dtype == "complex_int16":
numpy_dtype = "complex64"
    # SCENARIO 2: File originally in complex_int16 and dtype unchanged
elif (
rasterio_dtype is None
and encoded_rasterio_dtype == "complex_int16"
and str(dataarray_dtype) == "complex64"
):
numpy_dtype = "complex64"
rasterio_dtype = "complex_int16"
# SCENARIO 3: rasterio dtype not provided
elif rasterio_dtype is None:
numpy_dtype = dataarray_dtype
rasterio_dtype = dataarray_dtype
    # SCENARIO 4: rasterio dtype provided; use it as the numpy dtype as well
else:
numpy_dtype = rasterio_dtype
return rasterio_dtype, numpy_dtype
class RasterioWriter:
"""
    .. versionadded:: 0.2
Rasterio wrapper to allow dask.array.store to do window saving or to
save using the rasterio write method.
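
    A minimal, illustrative sketch of how this class is used internally
    (the keyword arguments are assumptions for a single-band 2D DataArray
    ``xds``; the real values are assembled by ``DataArray.rio.to_raster``)::

        RasterioWriter(raster_path="output.tif").to_raster(
            xarray_dataarray=xds,
            tags=None,
            windowed=False,
            lock=False,
            compute=True,
            driver="GTiff",
            height=xds.rio.height,
            width=xds.rio.width,
            count=1,
            dtype=None,
            crs=xds.rio.crs,
            transform=xds.rio.transform(),
            nodata=xds.rio.nodata,
        )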
"""
def __init__(self, raster_path):
"""
raster_path: str
The path to output the raster to.
"""
# https://github.com/dymaxionlabs/dask-rasterio/issues/3#issuecomment-514781825
# Rasterio datasets can't be pickled and can't be shared between
# processes or threads. The work around is to distribute dataset
# identifiers (paths or URIs) and then open them in new threads.
# See mapbox/rasterio#1731.
self.raster_path = raster_path
def __setitem__(self, key, item):
"""Put the data chunk in the image"""
if len(key) == 3:
index_range, yyy, xxx = key
indexes = list(
range(
index_range.start + 1, index_range.stop + 1, index_range.step or 1
)
)
else:
indexes = 1
yyy, xxx = key
chy_off = yyy.start
chy = yyy.stop - yyy.start
chx_off = xxx.start
chx = xxx.stop - xxx.start
with rasterio.open(self.raster_path, "r+") as rds:
rds.write(item, window=Window(chx_off, chy_off, chx, chy), indexes=indexes)
def to_raster(self, *, xarray_dataarray, tags, windowed, lock, compute, **kwargs):
"""
This method writes to the raster on disk.
xarray_dataarray: xarray.DataArray
The input data array to write to disk.
tags: dict, optional
A dictionary of tags to write to the raster.
windowed: bool
If True and the data array is not a dask array, it will write
the data to disk using rasterio windows.
lock: boolean or Lock, optional
Lock to use to write data using dask.
If not supplied, it will use a single process.
compute: bool
If True (default) and data is a dask array, then compute and save
the data immediately. If False, return a dask Delayed object.
Call ".compute()" on the Delayed object to compute the result
later. Call ``dask.compute(delayed1, delayed2)`` to save
multiple delayed files at once.
dtype: numpy.dtype
Numpy-compliant dtype used to save raster. If data is not already
represented by this dtype in memory it is recast. dtype='complex_int16'
is a special case to write in-memory numpy.complex64 to CInt16.
**kwargs
Keyword arguments to pass into writing the raster.
"""
xarray_dataarray = xarray_dataarray.copy()
kwargs["dtype"], numpy_dtype = _get_dtypes(
rasterio_dtype=kwargs["dtype"],
encoded_rasterio_dtype=xarray_dataarray.encoding.get("rasterio_dtype"),
dataarray_dtype=xarray_dataarray.encoding.get(
"dtype", str(xarray_dataarray.dtype)
),
)
# there is no equivalent for netCDF _Unsigned
# across output GDAL formats. It is safest to convert beforehand.
# https://github.com/OSGeo/gdal/issues/6352#issuecomment-1245981837
if "_Unsigned" in xarray_dataarray.encoding:
unsigned_dtype = _get_unsigned_dtype(
unsigned=xarray_dataarray.encoding["_Unsigned"] == "true",
dtype=numpy_dtype,
)
if unsigned_dtype is not None:
numpy_dtype = unsigned_dtype
kwargs["dtype"] = unsigned_dtype
xarray_dataarray.encoding["rasterio_dtype"] = str(unsigned_dtype)
xarray_dataarray.encoding["dtype"] = str(unsigned_dtype)
if kwargs["nodata"] is not None:
# Ensure dtype of output data matches the expected dtype.
# This check is added here as the dtype of the data is
# converted right before writing.
kwargs["nodata"] = _ensure_nodata_dtype(
original_nodata=kwargs["nodata"], new_dtype=numpy_dtype
)
with rasterio.open(self.raster_path, "w", **kwargs) as rds:
            _write_metadata_to_raster(
raster_handle=rds, xarray_dataset=xarray_dataarray, tags=tags
)
if not (lock and is_dask_collection(xarray_dataarray.data)):
                # write data to raster immediately if not dask array
if windowed:
window_iter = rds.block_windows(1)
else:
window_iter = [(None, None)]
for _, window in window_iter:
if window is not None:
out_data = xarray_dataarray.rio.isel_window(window)
else:
out_data = xarray_dataarray
data = encode_cf_variable(out_data.variable).values.astype(
numpy_dtype
)
if data.ndim == 2:
rds.write(data, 1, window=window)
else:
rds.write(data, window=window)
if lock and is_dask_collection(xarray_dataarray.data):
return dask.array.store(
encode_cf_variable(xarray_dataarray.variable).data.astype(numpy_dtype),
self,
lock=lock,
compute=compute,
)
return None