File: conftest.py

package info (click to toggle)
python-rioxarray 0.19.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 7,304 kB
  • sloc: python: 7,893; makefile: 93
file content (123 lines) | stat: -rw-r--r-- 4,182 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
from functools import partial

import pytest
import rasterio
import xarray
from numpy.testing import assert_almost_equal, assert_array_equal
from packaging import version

import rioxarray
from rioxarray.raster_array import UNWANTED_RIO_ATTRS

# Ask xarray to warn when file handles are left unclosed during the test run.
xarray.set_options(warn_for_unclosed_files=True)

# Test fixture locations, resolved relative to this conftest file.
_TEST_ROOT_DIR = os.path.dirname(__file__)
TEST_DATA_DIR = os.path.join(_TEST_ROOT_DIR, "test_data")
TEST_INPUT_DATA_DIR = os.path.join(TEST_DATA_DIR, "input")
TEST_COMPARE_DATA_DIR = os.path.join(TEST_DATA_DIR, "compare")

# Flags comparing the GDAL version reported by rasterio against thresholds.
_GDAL_VERSION = version.parse(rasterio.__gdal_version__)
GDAL_GE_36 = _GDAL_VERSION >= version.parse("3.6.0")
GDAL_GE_361 = _GDAL_VERSION >= version.parse("3.6.1")
GDAL_GE_364 = _GDAL_VERSION >= version.parse("3.6.4")


# xarray.testing.assert_equal(input_xarray, compare_xarray)
def _assert_attrs_equal(input_xr, compare_xr, decimal_precision):
    """Check that the attributes that matter agree between two xarray objects.

    Every attribute of ``compare_xr`` is checked against ``input_xr``, except:
    ``_FillValue``, anything in ``UNWANTED_RIO_ATTRS``, ``creation_date``,
    ``grid_mapping``, ``coordinates``, ``crs``, and any attribute whose name
    contains ``#``.

    The ``transform`` attribute is special-cased: the reference value is
    compared (first 6 elements) with ``input_xr.rio._cached_transform()``
    rather than with ``input_xr.attrs``. Other attributes are compared with
    ``assert_almost_equal``; when that raises ``TypeError``/``ValueError``
    (e.g. non-numeric values) a plain ``==`` comparison is used instead.

    Parameters
    ----------
    input_xr : xarray.DataArray or xarray.Dataset
        Object under test.
    compare_xr : xarray.DataArray or xarray.Dataset
        Reference object; its attributes drive which keys are checked.
    decimal_precision : int
        Decimal places used for the numeric comparisons.

    Raises
    ------
    AssertionError
        If a checked attribute is missing from ``input_xr`` or differs.
    """
    # Built once instead of re-concatenating the tuple on every iteration.
    skipped_attrs = frozenset(UNWANTED_RIO_ATTRS).union(
        ("_FillValue", "creation_date", "grid_mapping", "coordinates", "crs")
    )
    for attr, expected in compare_xr.attrs.items():
        # Checked before the skip list so "transform" is always compared.
        if attr == "transform":
            assert_almost_equal(
                tuple(input_xr.rio._cached_transform())[:6],
                expected[:6],
                decimal=decimal_precision,
            )
            continue
        if attr in skipped_attrs or "#" in attr:
            continue
        # Report a missing attribute as a test failure rather than a KeyError,
        # so callers that catch AssertionError can add context.
        assert attr in input_xr.attrs, f"attribute {attr!r} missing from input"
        actual = input_xr.attrs[attr]
        try:
            assert_almost_equal(actual, expected, decimal=decimal_precision)
        except (TypeError, ValueError):
            assert actual == expected


def _assert_xarrays_equal(
    input_xarray, compare_xarray, precision=7, skip_xy_check=False
):
    """Assert that ``input_xarray`` matches the reference ``compare_xarray``.

    Attributes are always compared via ``_assert_attrs_equal``.

    For Dataset-like inputs (objects with a ``variables`` attribute):
    ``x``/``y`` coordinates are compared approximately, unless
    ``skip_xy_check`` is true. Every other coordinate is compared exactly.
    The function then recurses into each variable listed by
    ``input_xarray.rio.vars``.

    For DataArray-like inputs, the data values are compared approximately.
    The fill values (from ``attrs`` or ``encoding``) and ``rio.grid_mapping``
    must match. None of ``UNWANTED_RIO_ATTRS``, ``crs`` or ``transform`` may
    remain in ``input_xarray.attrs``.

    Parameters
    ----------
    input_xarray : xarray.DataArray or xarray.Dataset
        Object under test.
    compare_xarray : xarray.DataArray or xarray.Dataset
        Reference object.
    precision : int, default 7
        Decimal places for approximate comparisons.
    skip_xy_check : bool, default False
        Skip comparing the ``x``/``y`` coordinate values of a Dataset.

    Raises
    ------
    AssertionError
        On the first mismatch found.
    """
    _assert_attrs_equal(input_xarray, compare_xarray, precision)
    if hasattr(input_xarray, "variables"):
        # check coordinates
        for coord in input_xarray.coords:
            # Exact membership test; `coord in "xy"` was a substring check.
            if coord in ("x", "y"):
                if not skip_xy_check:
                    assert_almost_equal(
                        input_xarray[coord].values,
                        compare_xarray[coord].values,
                        decimal=precision,
                    )
            else:
                # Reports shape mismatches as assertion failures instead of
                # erroring inside an element-wise `==`.
                assert_array_equal(
                    input_xarray[coord].values, compare_xarray[coord].values
                )

        for var in input_xarray.rio.vars:
            try:
                _assert_xarrays_equal(
                    input_xarray[var], compare_xarray[var], precision=precision
                )
            except AssertionError:
                print(f"Error with variable {var}")
                raise
        return

    input_values = input_xarray.values
    compare_values = compare_xarray.values
    try:
        assert_almost_equal(input_values, compare_values, decimal=precision)
    except AssertionError:
        # Print the differing elements only when an element-wise comparison is
        # possible; otherwise the diagnostic itself would raise and mask the
        # original assertion failure.
        if input_values.shape == compare_values.shape:
            where_diff = input_values != compare_values
            print(input_values[where_diff])
            print(compare_values[where_diff])
        raise

    compare_fill_value = compare_xarray.attrs.get(
        "_FillValue", compare_xarray.encoding.get("_FillValue")
    )
    input_fill_value = input_xarray.attrs.get(
        "_FillValue", input_xarray.encoding.get("_FillValue")
    )
    # Wrapped in lists so that NaN fill values compare equal.
    assert_array_equal([input_fill_value], [compare_fill_value])
    assert input_xarray.rio.grid_mapping == compare_xarray.rio.grid_mapping
    for unwanted_attr in UNWANTED_RIO_ATTRS + ("crs", "transform"):
        assert unwanted_attr not in input_xarray.attrs


# `xarray.open_dataset` pre-bound to the rasterio backend engine.
open_rasterio_engine = partial(xarray.open_dataset, engine="rasterio")


@pytest.fixture(params=(rioxarray.open_rasterio, open_rasterio_engine))
def open_rasterio(request):
    """Yield each supported raster-opening callable in turn.

    Tests requesting this fixture run once with ``rioxarray.open_rasterio``
    and once with ``xarray.open_dataset`` bound to ``engine="rasterio"``.
    """
    opener = request.param
    return opener


def _ensure_dataset(rds):
    """Convert a DataArray to a Dataset when ``GDAL_GE_364`` is true.

    Any other input, or any input when ``GDAL_GE_364`` is false, is returned
    unchanged. Workaround tracked in
    https://github.com/OSGeo/gdal/issues/7695.
    """
    needs_conversion = GDAL_GE_364 and isinstance(rds, xarray.DataArray)
    return rds.to_dataset() if needs_conversion else rds