1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
|
import pint
import pytest
import xarray as xr
from pint_xarray import testing
# Module-level registry used by the parametrized cases below to build
# pint quantities (``unit_registry.K``, ``unit_registry.dimensionless``).
# NOTE(review): ``force_ndarray_like=True`` is presumably required so the
# quantities can be wrapped by xarray objects -- confirm against the pint docs.
unit_registry = pint.UnitRegistry(force_ndarray_like=True)
@pytest.mark.parametrize(
    ("a", "b", "error"),
    [
        # Units carried only in ``attrs``, identical on both sides.
        pytest.param(
            xr.DataArray(attrs={"units": "K"}),
            xr.DataArray(attrs={"units": "K"}),
            None,
            id="equal attrs",
        ),
        # Units carried only in ``attrs``, differing between the two objects.
        pytest.param(
            xr.DataArray(attrs={"units": "m"}),
            xr.DataArray(attrs={"units": "K"}),
            AssertionError,
            id="different attrs",
        ),
        # Quantity-backed arrays: same unit, different magnitudes.
        pytest.param(
            xr.DataArray([10, 20] * unit_registry.K),
            xr.DataArray([50, 80] * unit_registry.K),
            None,
            id="equal units",
        ),
        # Quantity-backed arrays: kelvin on one side, dimensionless on the other.
        pytest.param(
            xr.DataArray([10, 20] * unit_registry.K),
            xr.DataArray([50, 80] * unit_registry.dimensionless),
            AssertionError,
            id="different units",
        ),
        # Datasets with the same variable name and the same ``units`` attr.
        pytest.param(
            xr.Dataset({"a": ("x", [0, 10], {"units": "K"})}),
            xr.Dataset({"a": ("x", [20, 40], {"units": "K"})}),
            None,
            id="matching variables",
        ),
        # Datasets whose variable names differ ("a" versus "b").
        pytest.param(
            xr.Dataset({"a": ("x", [0, 10], {"units": "K"})}),
            xr.Dataset({"b": ("x", [20, 40], {"units": "K"})}),
            AssertionError,
            id="mismatching variables",
        ),
    ],
)
def test_assert_units_equal(a, b, error):
    """Exercise ``testing.assert_units_equal`` on a pair of xarray objects.

    Parameters
    ----------
    a, b
        The two objects handed to ``testing.assert_units_equal``.
    error
        Exception type the call is expected to raise, or ``None`` when the
        call is expected to complete without raising.
    """
    if error is None:
        # Passing case: the call itself is the assertion.
        testing.assert_units_equal(a, b)
    else:
        # Failing case: the comparison must raise the declared exception.
        with pytest.raises(error):
            testing.assert_units_equal(a, b)
|