File: test_backends_chunks.py

package info (click to toggle)
python-xarray 2025.08.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 11,796 kB
  • sloc: python: 115,416; makefile: 258; sh: 47
file content (114 lines) | stat: -rw-r--r-- 3,124 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import numpy as np
import pytest

import xarray as xr
from xarray.backends.chunks import align_nd_chunks, build_grid_chunks, grid_rechunk
from xarray.tests import requires_dask


@pytest.mark.parametrize(
    "size, chunk_size, region, expected_chunks",
    [
        (10, 3, slice(1, 11), (2, 3, 3, 2)),
        (10, 3, slice(None, None), (3, 3, 3, 1)),
        (10, 3, None, (3, 3, 3, 1)),
        (10, 3, slice(None, 10), (3, 3, 3, 1)),
        (10, 3, slice(0, None), (3, 3, 3, 1)),
    ],
)
def test_build_grid_chunks(size, chunk_size, region, expected_chunks):
    grid_chunks = build_grid_chunks(
        size,
        chunk_size=chunk_size,
        region=region,
    )
    assert grid_chunks == expected_chunks


@pytest.mark.parametrize(
    "nd_var_chunks, nd_backend_chunks, expected_chunks",
    [
        (((2, 2, 2, 2),), ((3, 3, 2),), ((3, 3, 2),)),
        # ND cases
        (((2, 4), (2, 3)), ((2, 2, 2), (3, 2)), ((2, 4), (3, 2))),
    ],
)
def test_align_nd_chunks(nd_var_chunks, nd_backend_chunks, expected_chunks):
    aligned_nd_chunks = align_nd_chunks(
        nd_var_chunks=nd_var_chunks,
        nd_backend_chunks=nd_backend_chunks,
    )
    assert aligned_nd_chunks == expected_chunks


@requires_dask
@pytest.mark.parametrize(
    "enc_chunks, region, nd_var_chunks, expected_chunks",
    [
        (
            (3,),
            (slice(2, 14),),
            ((6, 6),),
            (
                (
                    4,
                    6,
                    2,
                ),
            ),
        ),
        (
            (6,),
            (slice(0, 13),),
            ((6, 7),),
            (
                (
                    6,
                    7,
                ),
            ),
        ),
        ((6,), (slice(0, 13),), ((6, 6, 1),), ((6, 6, 1),)),
        ((3,), (slice(2, 14),), ((1, 3, 2, 6),), ((1, 3, 6, 2),)),
        ((3,), (slice(2, 14),), ((2, 2, 2, 6),), ((4, 6, 2),)),
        ((3,), (slice(2, 14),), ((3, 1, 3, 5),), ((4, 3, 5),)),
        ((4,), (slice(1, 13),), ((1, 1, 1, 4, 3, 2),), ((3, 4, 4, 1),)),
        ((5,), (slice(4, 16),), ((5, 7),), ((6, 6),)),
        # ND cases
        (
            (3, 6),
            (slice(2, 14), slice(0, 13)),
            ((6, 6), (6, 7)),
            (
                (
                    4,
                    6,
                    2,
                ),
                (
                    6,
                    7,
                ),
            ),
        ),
    ],
)
def test_grid_rechunk(enc_chunks, region, nd_var_chunks, expected_chunks):
    dims = [f"dim_{i}" for i in range(len(region))]
    coords = {
        dim: list(range(r.start, r.stop)) for dim, r in zip(dims, region, strict=False)
    }
    shape = tuple(r.stop - r.start for r in region)
    arr = xr.DataArray(
        np.arange(np.prod(shape)).reshape(shape),
        dims=dims,
        coords=coords,
    )
    arr = arr.chunk(dict(zip(dims, nd_var_chunks, strict=False)))

    result = grid_rechunk(
        arr.variable,
        enc_chunks=enc_chunks,
        region=region,
    )
    assert result.chunks == expected_chunks