File: test_backends_chunks.py

package info (click to toggle)
python-xarray 2025.11.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 11,740 kB
  • sloc: python: 118,074; makefile: 260; sh: 47
file content (116 lines) | stat: -rw-r--r-- 3,185 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import numpy as np
import pytest

import xarray as xr
from xarray.backends.chunks import align_nd_chunks, build_grid_chunks, grid_rechunk
from xarray.tests import requires_dask


@pytest.mark.parametrize(
    "size, chunk_size, region, expected_chunks",
    [
        (10, 3, slice(1, 11), (2, 3, 3, 2)),
        (10, 3, slice(None, None), (3, 3, 3, 1)),
        (10, 3, None, (3, 3, 3, 1)),
        (10, 3, slice(None, 10), (3, 3, 3, 1)),
        (10, 3, slice(0, None), (3, 3, 3, 1)),
        (2, 10, slice(0, 3), (2,)),
        (4, 10, slice(7, 10), (3, 1)),
    ],
)
def test_build_grid_chunks(size, chunk_size, region, expected_chunks):
    grid_chunks = build_grid_chunks(
        size,
        chunk_size=chunk_size,
        region=region,
    )
    assert grid_chunks == expected_chunks


@pytest.mark.parametrize(
    "nd_v_chunks, nd_backend_chunks, expected_chunks",
    [
        (((2, 2, 2, 2),), ((3, 3, 2),), ((3, 3, 2),)),
        # ND cases
        (((2, 4), (2, 3)), ((2, 2, 2), (3, 2)), ((2, 4), (3, 2))),
    ],
)
def test_align_nd_chunks(nd_v_chunks, nd_backend_chunks, expected_chunks):
    aligned_nd_chunks = align_nd_chunks(
        nd_v_chunks=nd_v_chunks,
        nd_backend_chunks=nd_backend_chunks,
    )
    assert aligned_nd_chunks == expected_chunks


@requires_dask
@pytest.mark.parametrize(
    "enc_chunks, region, nd_v_chunks, expected_chunks",
    [
        (
            (3,),
            (slice(2, 14),),
            ((6, 6),),
            (
                (
                    4,
                    6,
                    2,
                ),
            ),
        ),
        (
            (6,),
            (slice(0, 13),),
            ((6, 7),),
            (
                (
                    6,
                    7,
                ),
            ),
        ),
        ((6,), (slice(0, 13),), ((6, 6, 1),), ((6, 6, 1),)),
        ((3,), (slice(2, 14),), ((1, 3, 2, 6),), ((1, 3, 6, 2),)),
        ((3,), (slice(2, 14),), ((2, 2, 2, 6),), ((4, 6, 2),)),
        ((3,), (slice(2, 14),), ((3, 1, 3, 5),), ((4, 3, 5),)),
        ((4,), (slice(1, 13),), ((1, 1, 1, 4, 3, 2),), ((3, 4, 4, 1),)),
        ((5,), (slice(4, 16),), ((5, 7),), ((6, 6),)),
        # ND cases
        (
            (3, 6),
            (slice(2, 14), slice(0, 13)),
            ((6, 6), (6, 7)),
            (
                (
                    4,
                    6,
                    2,
                ),
                (
                    6,
                    7,
                ),
            ),
        ),
    ],
)
def test_grid_rechunk(enc_chunks, region, nd_v_chunks, expected_chunks):
    dims = [f"dim_{i}" for i in range(len(region))]
    coords = {
        dim: list(range(r.start, r.stop)) for dim, r in zip(dims, region, strict=False)
    }
    shape = tuple(r.stop - r.start for r in region)
    arr = xr.DataArray(
        np.arange(np.prod(shape)).reshape(shape),
        dims=dims,
        coords=coords,
    )
    arr = arr.chunk(dict(zip(dims, nd_v_chunks, strict=False)))

    result = grid_rechunk(
        arr.variable,
        enc_chunks=enc_chunks,
        region=region,
    )
    assert result.chunks == expected_chunks