1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
|
import os
from functools import partial
import pytest
import rasterio
import xarray
from numpy.testing import assert_almost_equal, assert_array_equal
from packaging import version
import rioxarray
from rioxarray.raster_array import UNWANTED_RIO_ATTRS
# Warn when unclosed files are deallocated so leaked file handles surface
# during the test run.
xarray.set_options(warn_for_unclosed_files=True)

# Test data lives next to this file: an "input" and a "compare" subdirectory.
TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "test_data")
TEST_INPUT_DATA_DIR, TEST_COMPARE_DATA_DIR = (
    os.path.join(TEST_DATA_DIR, subdir) for subdir in ("input", "compare")
)

# Version gates based on the GDAL version reported by rasterio.
GDAL_GE_36, GDAL_GE_361, GDAL_GE_364 = (
    version.parse(rasterio.__gdal_version__) >= version.parse(minimum)
    for minimum in ("3.6.0", "3.6.1", "3.6.4")
)
# xarray.testing.assert_equal(input_xarray, compare_xarray)
def _assert_attrs_equal(input_xr, compare_xr, decimal_precision):
    """Check the attributes that matter between two xarray objects.

    Every attribute present on ``compare_xr`` is checked against
    ``input_xr``, except ``_FillValue``, the entries of
    ``UNWANTED_RIO_ATTRS``, ``creation_date``, ``grid_mapping``,
    ``coordinates``, ``crs``, and any attribute whose name contains ``"#"``.

    ``transform`` is compared using the first six elements of
    ``input_xr.rio._cached_transform()``. Other attributes are compared
    with ``assert_almost_equal``, falling back to ``==`` when that raises
    ``TypeError``/``ValueError``.

    Parameters
    ----------
    input_xr : xarray.Dataset or xarray.DataArray
        Object under test.
    compare_xr : xarray.Dataset or xarray.DataArray
        Reference object; its ``attrs`` decide which keys are checked.
    decimal_precision : int
        ``decimal`` argument passed to ``assert_almost_equal``.

    Raises
    ------
    AssertionError
        If a checked attribute is missing from ``input_xr`` or differs.
    """
    # Built once instead of on every loop iteration.
    skipped_attrs = (
        ("_FillValue",)
        + UNWANTED_RIO_ATTRS
        + ("creation_date", "grid_mapping", "coordinates", "crs")
    )
    for attr, expected in compare_xr.attrs.items():
        if attr == "transform":
            assert_almost_equal(
                tuple(input_xr.rio._cached_transform())[:6],
                expected[:6],
                decimal=decimal_precision,
            )
            continue
        if attr in skipped_attrs or "#" in attr:
            continue
        # Report a missing attribute as an assertion failure rather than a
        # bare KeyError, so callers catching AssertionError can add context.
        assert attr in input_xr.attrs, f"attribute {attr!r} missing from input"
        actual = input_xr.attrs[attr]
        try:
            assert_almost_equal(actual, expected, decimal=decimal_precision)
        except (TypeError, ValueError):
            # Values assert_almost_equal cannot handle (presumably
            # non-numeric, e.g. strings) are compared for exact equality.
            assert actual == expected
def _assert_xarrays_equal(
    input_xarray, compare_xarray, precision=7, skip_xy_check=False
):
    """Assert that ``input_xarray`` matches the reference ``compare_xarray``.

    Objects with a ``variables`` attribute (Datasets) have their coordinates
    compared and every variable in ``input_xarray.rio.vars`` checked
    recursively. Other objects (DataArrays) have their ``values`` compared to
    ``precision`` decimals.

    In every case:
    - the attributes must match (see ``_assert_attrs_equal``);
    - the ``_FillValue`` (taken from ``attrs`` or else ``encoding``) must match;
    - the rio grid mapping must match;
    - none of ``UNWANTED_RIO_ATTRS``, ``crs`` or ``transform`` may remain in
      ``input_xarray.attrs``.

    Parameters
    ----------
    input_xarray : xarray.Dataset or xarray.DataArray
        Object under test.
    compare_xarray : xarray.Dataset or xarray.DataArray
        Reference object.
    precision : int, default 7
        Decimal places used for approximate comparisons.
    skip_xy_check : bool, default False
        Skip comparing the ``x``/``y`` coordinate values of a Dataset.
    """
    _assert_attrs_equal(input_xarray, compare_xarray, precision)
    if hasattr(input_xarray, "variables"):
        # check coordinates
        for coord in input_xarray.coords:
            # Exact membership: ``coord in "xy"`` is a substring test that
            # would also match names like "xy" or "".
            if coord in ("x", "y"):
                if not skip_xy_check:
                    assert_almost_equal(
                        input_xarray[coord].values,
                        compare_xarray[coord].values,
                        decimal=precision,
                    )
            else:
                assert (
                    input_xarray[coord].values == compare_xarray[coord].values
                ).all()
        for var in input_xarray.rio.vars:
            try:
                _assert_xarrays_equal(
                    input_xarray[var],
                    compare_xarray[var],
                    precision=precision,
                    skip_xy_check=skip_xy_check,
                )
            except AssertionError:
                print(f"Error with variable {var}")
                raise
    else:
        input_values = input_xarray.values
        compare_values = compare_xarray.values
        try:
            assert_almost_equal(input_values, compare_values, decimal=precision)
        except AssertionError:
            # An elementwise diff only makes sense for equal shapes; otherwise
            # ``!=`` may raise and hide the original assertion failure.
            if input_values.shape == compare_values.shape:
                where_diff = input_values != compare_values
                print(input_values[where_diff])
                print(compare_values[where_diff])
            raise
    compare_fill_value = compare_xarray.attrs.get(
        "_FillValue", compare_xarray.encoding.get("_FillValue")
    )
    input_fill_value = input_xarray.attrs.get(
        "_FillValue", input_xarray.encoding.get("_FillValue")
    )
    # assert_array_equal treats NaN fill values as equal, unlike ``==``.
    assert_array_equal([input_fill_value], [compare_fill_value])
    assert input_xarray.rio.grid_mapping == compare_xarray.rio.grid_mapping
    for unwanted_attr in UNWANTED_RIO_ATTRS + ("crs", "transform"):
        assert unwanted_attr not in input_xarray.attrs
# ``xarray.open_dataset`` pinned to the "rasterio" backend engine, so tests can
# exercise the xarray entry point alongside ``rioxarray.open_rasterio``.
open_rasterio_engine = partial(xarray.open_dataset, engine="rasterio")


@pytest.fixture(params=[rioxarray.open_rasterio, open_rasterio_engine])
def open_rasterio(request):
    """Provide each raster-opening callable in turn.

    A test requesting this fixture runs once with
    ``rioxarray.open_rasterio`` and once with ``open_rasterio_engine``.
    """
    opener = request.param
    return opener
def _ensure_dataset(rds):
    """Return ``rds`` as a Dataset when required, otherwise unchanged.

    The conversion (``DataArray.to_dataset()``) happens only when
    ``GDAL_GE_364`` is true and ``rds`` is an ``xarray.DataArray``.
    Workaround for https://github.com/OSGeo/gdal/issues/7695.
    """
    needs_dataset = GDAL_GE_364 and isinstance(rds, xarray.DataArray)
    return rds.to_dataset() if needs_dataset else rds
|