File: raster_writer.py

"""
This module contains a dataset writer for Dask.

Credits:

RasterioWriter dask write functionality was adapted from https://github.com/dymaxionlabs/dask-rasterio  # noqa: E501
Source file:
- https://github.com/dymaxionlabs/dask-rasterio/blob/8dd7fdece7ad094a41908c0ae6b4fe6ca49cf5e1/dask_rasterio/write.py  # noqa: E501

"""
import numpy
import rasterio
from rasterio.windows import Window
from xarray.conventions import encode_cf_variable

from rioxarray._io import FILL_VALUE_NAMES, UNWANTED_RIO_ATTRS, _get_unsigned_dtype
from rioxarray.exceptions import RioXarrayError

try:
    import dask.array
    from dask import is_dask_collection
except ImportError:

    def is_dask_collection(_) -> bool:  # type: ignore
        """
        Replacement method to check if it is a dask collection
        """
        # if you cannot import dask, then it cannot be a dask array
        return False


# Note: transform & crs are removed in write_transform/write_crs


def _write_tags(*, raster_handle, tags):
    """
    Write tags to raster dataset
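
    Illustrative example of the ``tags`` mapping this function expects (keys
    below are assumptions for this sketch, not taken from upstream tests)::

        {
            "source": "illustrative example",  # written as a dataset-level tag
            "long_name": ["red", "green"],     # non-string: skipped here and
                                               # used for band descriptions
            "band_tags": [                     # popped and written per band
                {"wavelength": "660 nm"},
                {"wavelength": "560 nm"},
            ],
        }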
    """
    # filter out attributes that should be written in a different location
    skip_tags = (
        UNWANTED_RIO_ATTRS
        + FILL_VALUE_NAMES
        + (
            "crs",
            "transform",
            "scales",
            "scale_factor",
            "add_offset",
            "offsets",
            "grid_mapping",
        )
    )
    # if 'long_name' holds multiple values (one per band), skip it here;
    # it is written to the band descriptions instead
    if not isinstance(tags.get("long_name"), str):
        skip_tags += ("long_name",)
    band_tags = tags.pop("band_tags", [])
    tags = {key: value for key, value in tags.items() if key not in skip_tags}
    raster_handle.update_tags(**tags)

    if isinstance(band_tags, list):
        for iii, band_tag in enumerate(band_tags):
            raster_handle.update_tags(iii + 1, **band_tag)


def _write_band_description(*, raster_handle, xarray_dataset):
    """
    Write band descriptions using the long name
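
    For example (values assumed for illustration), ``long_name=("red",
    "green", "blue")`` on a three-band raster sets the descriptions of
    bands 1-3 to "red", "green" and "blue", while a plain string (or the
    array's ``name``) is applied to every band.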
    """
    long_name = xarray_dataset.attrs.get("long_name")
    if isinstance(long_name, (tuple, list)):
        if len(long_name) != raster_handle.count:
            raise RioXarrayError(
                "Number of names in the 'long_name' attribute does not equal "
                "the number of bands."
            )
        for iii, band_description in enumerate(long_name):
            raster_handle.set_band_description(iii + 1, band_description)
    else:
        band_description = long_name or xarray_dataset.name
        if band_description:
            for iii in range(raster_handle.count):
                raster_handle.set_band_description(iii + 1, band_description)


def _write_metadata_to_raster(*, raster_handle, xarray_dataset, tags):
    """
    Write the metadata stored in the xarray object to raster metadata
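
    For example (values assumed for illustration), a scalar
    ``scale_factor=0.1`` found in the attrs or encoding is expanded to
    ``scales = (0.1, ..., 0.1)`` with one entry per band, while an
    explicit ``scales`` sequence takes precedence over ``scale_factor``.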
    """
    tags = (
        xarray_dataset.attrs.copy()
        if tags is None
        else {**xarray_dataset.attrs, **tags}
    )

    # write scales and offsets
    scales = tags.get("scales", xarray_dataset.encoding.get("scales"))
    if scales is None:
        scale_factor = tags.get(
            "scale_factor", xarray_dataset.encoding.get("scale_factor")
        )
        if scale_factor is not None:
            scales = (scale_factor,) * raster_handle.count
    if scales is not None:
        raster_handle.scales = scales

    offsets = tags.get("offsets", xarray_dataset.encoding.get("offsets"))
    if offsets is None:
        add_offset = tags.get("add_offset", xarray_dataset.encoding.get("add_offset"))
        if add_offset is not None:
            offsets = (add_offset,) * raster_handle.count
    if offsets is not None:
        raster_handle.offsets = offsets

    _write_tags(raster_handle=raster_handle, tags=tags)
    _write_band_description(raster_handle=raster_handle, xarray_dataset=xarray_dataset)


def _ensure_nodata_dtype(*, original_nodata, new_dtype):
    """
    Convert the nodata value to the new datatype and raise an
    OverflowError if the value cannot be represented after the cast.
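
    Example (illustrative doctest; values chosen for this sketch, not taken
    from upstream tests):

    >>> int(_ensure_nodata_dtype(original_nodata=-9999.0, new_dtype="int16"))
    -9999
    >>> float(_ensure_nodata_dtype(original_nodata=255, new_dtype="float32"))
    255.0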
    """
    # Complex-valued rasters can have real-valued nodata
    new_dtype = numpy.dtype(new_dtype)
    if numpy.issubdtype(new_dtype, numpy.complexfloating):
        nodata = original_nodata
    else:
        original_nodata = (
            float(original_nodata)
            if not numpy.issubdtype(type(original_nodata), numpy.integer)
            else original_nodata
        )
        failure_message = (
            f"Unable to convert nodata value ({original_nodata}) to "
            f"new dtype ({new_dtype})."
        )
        try:
            nodata = new_dtype.type(original_nodata)
        except OverflowError as error:
            raise OverflowError(failure_message) from error
        if not numpy.isnan(nodata) and original_nodata != nodata:
            raise OverflowError(failure_message)
    return nodata


def _get_dtypes(*, rasterio_dtype, encoded_rasterio_dtype, dataarray_dtype):
    """
    Determine the rasterio dtype and numpy dtype to use for writing, based
    on the requested rasterio dtype, the encoded rasterio dtype, and the
    data array dtype.

    Parameters
    ----------
    rasterio_dtype: Union[str, numpy.dtype]
        The rasterio dtype to write to.
    encoded_rasterio_dtype: Union[str, numpy.dtype, None]
        The value of the original rasterio dtype in the encoding.
    dataarray_dtype: Union[str, numpy.dtype]
        The value of the dtype of the data array.

    Returns
    -------
    tuple[Union[str, numpy.dtype], Union[str, numpy.dtype]]:
        The rasterio dtype and numpy dtype.
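
    Examples
    --------
    Illustrative doctests (inputs chosen for this sketch, not taken from
    upstream tests):

    >>> _get_dtypes(
    ...     rasterio_dtype=None,
    ...     encoded_rasterio_dtype=None,
    ...     dataarray_dtype="float32",
    ... )
    ('float32', 'float32')
    >>> _get_dtypes(
    ...     rasterio_dtype="complex_int16",
    ...     encoded_rasterio_dtype=None,
    ...     dataarray_dtype="complex64",
    ... )
    ('complex_int16', 'complex64')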
    """
    # SCENARIO 1: User wants to write to complex_int16
    if rasterio_dtype == "complex_int16":
        numpy_dtype = "complex64"
    # SCENARIO 2: File was originally complex_int16 and the dtype is unchanged
    elif (
        rasterio_dtype is None
        and encoded_rasterio_dtype == "complex_int16"
        and str(dataarray_dtype) == "complex64"
    ):
        numpy_dtype = "complex64"
        rasterio_dtype = "complex_int16"
    # SCENARIO 3: rasterio dtype not provided
    elif rasterio_dtype is None:
        numpy_dtype = dataarray_dtype
        rasterio_dtype = dataarray_dtype
    # SCENARIO 4: rasterio dtype provided; use it as the numpy dtype
    else:
        numpy_dtype = rasterio_dtype
    return rasterio_dtype, numpy_dtype


class RasterioWriter:
    """

    ..versionadded:: 0.2

    Rasterio wrapper to allow dask.array.store to do window saving or to
    save using the rasterio write method.
    """

    def __init__(self, raster_path):
        """
        raster_path: str
            The path to output the raster to.
        """
        # https://github.com/dymaxionlabs/dask-rasterio/issues/3#issuecomment-514781825
        # Rasterio datasets can't be pickled and can't be shared between
        # processes or threads. The work around is to distribute dataset
        # identifiers (paths or URIs) and then open them in new threads.
        # See mapbox/rasterio#1731.
        self.raster_path = raster_path

    def __setitem__(self, key, item):
        """Put the data chunk in the image"""
        if len(key) == 3:
            index_range, yyy, xxx = key
            indexes = list(
                range(
                    index_range.start + 1, index_range.stop + 1, index_range.step or 1
                )
            )
        else:
            indexes = 1
            yyy, xxx = key

        chy_off = yyy.start
        chy = yyy.stop - yyy.start
        chx_off = xxx.start
        chx = xxx.stop - xxx.start

        with rasterio.open(self.raster_path, "r+") as rds:
            rds.write(item, window=Window(chx_off, chy_off, chx, chy), indexes=indexes)

    def to_raster(self, *, xarray_dataarray, tags, windowed, lock, compute, **kwargs):
        """
        Write the data array to the raster file on disk.

        xarray_dataarray: xarray.DataArray
            The input data array to write to disk.
        tags: dict, optional
            A dictionary of tags to write to the raster.
        windowed: bool
            If True and the data array is not a dask array, it will write
            the data to disk using rasterio windows.
        lock: boolean or Lock, optional
            Lock to use to write data using dask.
            If not supplied, it will use a single process.
        compute: bool
            If True (default) and data is a dask array, then compute and save
            the data immediately. If False, return a dask Delayed object.
            Call ".compute()" on the Delayed object to compute the result
            later. Call ``dask.compute(delayed1, delayed2)`` to save
            multiple delayed files at once.
        dtype: numpy.dtype
            Numpy-compliant dtype used to save raster. If data is not already
            represented by this dtype in memory it is recast. dtype='complex_int16'
            is a special case to write in-memory numpy.complex64 to CInt16.
        **kwargs
            Keyword arguments to pass into writing the raster.
        """
        xarray_dataarray = xarray_dataarray.copy()
        kwargs["dtype"], numpy_dtype = _get_dtypes(
            rasterio_dtype=kwargs["dtype"],
            encoded_rasterio_dtype=xarray_dataarray.encoding.get("rasterio_dtype"),
            dataarray_dtype=xarray_dataarray.encoding.get(
                "dtype", str(xarray_dataarray.dtype)
            ),
        )
        # there is no equivalent for netCDF _Unsigned
        # across output GDAL formats. It is safest to convert beforehand.
        # https://github.com/OSGeo/gdal/issues/6352#issuecomment-1245981837
        if "_Unsigned" in xarray_dataarray.encoding:
            unsigned_dtype = _get_unsigned_dtype(
                unsigned=xarray_dataarray.encoding["_Unsigned"] == "true",
                dtype=numpy_dtype,
            )
            if unsigned_dtype is not None:
                numpy_dtype = unsigned_dtype
                kwargs["dtype"] = unsigned_dtype
                xarray_dataarray.encoding["rasterio_dtype"] = str(unsigned_dtype)
                xarray_dataarray.encoding["dtype"] = str(unsigned_dtype)

        if kwargs["nodata"] is not None:
            # Ensure the nodata value is representable in the output dtype.
            # This check happens here because the data itself is only cast
            # to that dtype right before writing.
            kwargs["nodata"] = _ensure_nodata_dtype(
                original_nodata=kwargs["nodata"], new_dtype=numpy_dtype
            )

        with rasterio.open(self.raster_path, "w", **kwargs) as rds:
            _write_metadata_to_raster(
                raster_handle=rds, xarray_dataset=xarray_dataarray, tags=tags
            )
            if not (lock and is_dask_collection(xarray_dataarray.data)):
                # write the data immediately when it is not a dask array
                if windowed:
                    window_iter = rds.block_windows(1)
                else:
                    window_iter = [(None, None)]
                for _, window in window_iter:
                    if window is not None:
                        out_data = xarray_dataarray.rio.isel_window(window)
                    else:
                        out_data = xarray_dataarray
                    data = encode_cf_variable(out_data.variable).values.astype(
                        numpy_dtype
                    )
                    if data.ndim == 2:
                        rds.write(data, 1, window=window)
                    else:
                        rds.write(data, window=window)

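        # For dask-backed data with a lock, hand the CF-encoded, dtype-cast
        # array to ``dask.array.store``, which writes each chunk through
        # ``RasterioWriter.__setitem__`` into the dataset created above.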
        if lock and is_dask_collection(xarray_dataarray.data):
            return dask.array.store(
                encode_cf_variable(xarray_dataarray.variable).data.astype(numpy_dtype),
                self,
                lock=lock,
                compute=compute,
            )
        return None
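

# Illustrative usage sketch (not part of the upstream module): in practice this
# writer is driven by ``DataArray.rio.to_raster`` rather than instantiated
# directly. The file names, chunk sizes and creation options below are
# assumptions for the example.
#
#     import threading
#
#     import rioxarray  # noqa: F401  (registers the ``.rio`` accessor)
#     import xarray
#
#     xds = xarray.open_dataarray("input.nc").rio.write_crs("EPSG:4326")
#     # chunking makes the data a dask array, so ``to_raster`` takes the
#     # ``dask.array.store`` path above when a lock is supplied
#     xds = xds.chunk({"x": 256, "y": 256})
#     xds.rio.to_raster("output.tif", tiled=True, lock=threading.Lock())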