File: test_coordinate_transform.py

package info (click to toggle)
python-xarray 2026.01.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 13,676 kB
  • sloc: python: 120,278; makefile: 269
file content (135 lines) | stat: -rw-r--r-- 4,865 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""Property tests comparing CoordinateTransformIndex to PandasIndex."""

import functools
import operator
from collections.abc import Hashable
from typing import Any

import numpy as np
import pytest

pytest.importorskip("hypothesis")

import hypothesis.strategies as st
from hypothesis import given

import xarray as xr
import xarray.testing.strategies as xrst
from xarray.core.coordinate_transform import CoordinateTransform
from xarray.core.indexes import CoordinateTransformIndex
from xarray.testing import assert_equal

# Name of the data variable (and of the resulting DataArray) produced by both
# ``create_transform_da`` and ``create_pandas_da``, so the two helpers build
# DataArrays that differ only in the kind of index attached to each dimension.
DATA_VAR_NAME = "_test_data_"


class IdentityTransform(CoordinateTransform):
    """Coordinate transform whose coordinate labels are the dimension positions.

    This is the simplest possible ``CoordinateTransform``: ``forward`` and
    ``reverse`` hand their input back unchanged. That is only meaningful when
    each coordinate name coincides with its dimension name, which is how
    ``create_transform_da`` constructs it (``coord_names == (dim,)``).
    """

    def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]:
        """Map dimension positions to coordinate labels (identity)."""
        return dim_positions

    def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]:
        """Map coordinate labels back to dimension positions (identity)."""
        return coord_labels

    def equals(
        self, other: CoordinateTransform, exclude: frozenset[Hashable] | None = None
    ) -> bool:
        """Return True if ``other`` is an IdentityTransform with equal ``dim_size``.

        ``exclude`` is accepted to match the base-class signature but is not
        used by this comparison.
        """
        return (
            isinstance(other, IdentityTransform)
            and self.dim_size == other.dim_size
        )


def create_transform_da(sizes: dict[str, int]) -> xr.DataArray:
    """Create a DataArray whose dimensions are indexed by IdentityTransform.

    Each dimension in ``sizes`` gets its own single-dimension
    ``CoordinateTransformIndex`` wrapping an ``IdentityTransform`` (int64
    dtype), so the coordinate labels along each dimension equal its integer
    positions ``0 .. size - 1``. The data values are
    ``np.arange(prod(shape)).reshape(shape)``, the same construction used by
    ``create_pandas_da``.

    Parameters
    ----------
    sizes : dict[str, int]
        Mapping of dimension name to length, in dimension order. Must be
        non-empty (the per-dimension coordinates are merged with ``reduce``).

    Returns
    -------
    xr.DataArray
        The DataArray named ``DATA_VAR_NAME``.
    """
    dims = list(sizes)
    shape = tuple(sizes.values())
    data = np.arange(np.prod(shape)).reshape(shape)

    ds = xr.Dataset({DATA_VAR_NAME: (dims, data)})
    # Build one independent CoordinateTransformIndex per dimension and merge
    # the resulting Coordinates objects with ``|`` so they attach together.
    coords = functools.reduce(
        operator.or_,
        [
            xr.Coordinates.from_xindex(
                CoordinateTransformIndex(
                    IdentityTransform((dim,), {dim: size}, dtype=np.dtype(np.int64))
                )
            )
            for dim, size in sizes.items()
        ],
    )
    # Use [] rather than .get(): .get() would silently return None if the
    # variable were missing (and is typed ``DataArray | None``).
    return ds.assign_coords(coords)[DATA_VAR_NAME]


def create_pandas_da(sizes: dict[str, int]) -> xr.DataArray:
    """Build the reference DataArray backed by default (PandasIndex) coordinates.

    Each dimension ``name`` receives the coordinate ``np.arange(sizes[name])``,
    so labels coincide with integer positions. The data values are
    ``np.arange(prod(shape)).reshape(shape)``.
    """
    data_shape = tuple(sizes.values())
    values = np.arange(np.prod(data_shape)).reshape(data_shape)
    coordinate_arrays = {}
    for name, length in sizes.items():
        coordinate_arrays[name] = np.arange(length)
    dim_names = list(sizes.keys())
    return xr.DataArray(
        values,
        dims=dim_names,
        coords=coordinate_arrays,
        name=DATA_VAR_NAME,
    )


@given(
    st.data(),
    xrst.dimension_sizes(min_dims=1, max_dims=3, min_side=1, max_side=5),
)
def test_basic_indexing(data, sizes):
    """Basic (integer/slice) ``isel`` agrees between transform and pandas indexes."""
    reference = create_pandas_da(sizes)
    candidate = create_transform_da(sizes)
    indexer = data.draw(xrst.basic_indexers(sizes=sizes))

    expected = reference.isel(indexer)
    actual = candidate.isel(indexer)
    # TODO: every dimension that remains indexed in ``expected`` should also
    # be indexed in ``actual`` by a CoordinateTransformIndex. That needs
    # CoordinateTransformIndex.isel to return a new CoordinateTransformIndex;
    # once it does, also check:
    #     for dim in expected.xindexes:
    #         assert isinstance(actual.xindexes[dim], CoordinateTransformIndex)
    assert_equal(expected, actual)

    # TODO: label-based ``.sel`` with these basic indexers is not supported
    # for the transform-backed DataArray today; once it is, compare
    # ``reference.sel(indexer)`` against ``candidate.sel(indexer)`` with
    # assert_identical.


@given(
    st.data(),
    xrst.dimension_sizes(min_dims=1, max_dims=3, min_side=1, max_side=5),
)
def test_outer_indexing(data, sizes):
    """Outer (orthogonal) indexing agrees between transform and pandas indexes.

    First compares positional ``isel``, then converts the same positions into
    labels and compares label-based ``sel``.
    """
    pandas_da = create_pandas_da(sizes)
    transform_da = create_transform_da(sizes)
    idxr = data.draw(xrst.outer_array_indexers(sizes=sizes, min_dims=1))

    pandas_result = pandas_da.isel(idxr)
    transform_result = transform_da.isel(idxr)
    assert_equal(pandas_result, transform_result)

    # Labels along each dimension are 0..size-1, so positions become labels by
    # indexing an arange (which also maps any negative positions to the
    # non-negative labels they refer to). ``np.asarray`` accepts the indexer
    # whether it is a plain ndarray or an xarray Variable; ``ndarray.data``
    # would instead be a raw memoryview buffer.
    label_idxr = {
        dim: np.arange(pandas_da.sizes[dim])[np.asarray(ind)]
        for dim, ind in idxr.items()
    }
    pandas_result = pandas_da.sel(label_idxr)
    # NOTE(review): the transform-backed index is queried with
    # method="nearest" (it appears to require it — confirm); with exact
    # integer labels this selects the same positions as an exact match.
    transform_result = transform_da.sel(label_idxr, method="nearest")
    assert_equal(pandas_result, transform_result)


@given(
    st.data(),
    xrst.dimension_sizes(min_dims=2, max_dims=3, min_side=1, max_side=5),
)
def test_vectorized_indexing(data, sizes):
    """Vectorized (pointwise) indexing agrees between transform and pandas indexes.

    The positional indexers are also translated into label indexers (labels
    equal positions for both DataArrays) and compared through ``sel``.
    """
    pandas_da = create_pandas_da(sizes)
    transform_da = create_transform_da(sizes)
    indexers = data.draw(xrst.vectorized_indexers(sizes=sizes))

    assert_equal(pandas_da.isel(indexers), transform_da.isel(indexers))

    # ``copy(data=...)`` keeps each indexer's own dims, so ``sel`` broadcasts
    # the label arrays the same way ``isel`` broadcast the positions.
    labels = {}
    for dim, positions in indexers.items():
        dim_labels = np.arange(pandas_da.sizes[dim])
        labels[dim] = positions.copy(data=dim_labels[positions.data])
    assert_equal(
        pandas_da.sel(labels, method="nearest"),
        transform_da.sel(labels, method="nearest"),
    )