1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
|
import numpy as np
import pytest
import xarray as xr
from xarray.backends.chunks import align_nd_chunks, build_grid_chunks, grid_rechunk
from xarray.tests import requires_dask
@pytest.mark.parametrize(
"size, chunk_size, region, expected_chunks",
[
(10, 3, slice(1, 11), (2, 3, 3, 2)),
(10, 3, slice(None, None), (3, 3, 3, 1)),
(10, 3, None, (3, 3, 3, 1)),
(10, 3, slice(None, 10), (3, 3, 3, 1)),
(10, 3, slice(0, None), (3, 3, 3, 1)),
],
)
def test_build_grid_chunks(size, chunk_size, region, expected_chunks):
grid_chunks = build_grid_chunks(
size,
chunk_size=chunk_size,
region=region,
)
assert grid_chunks == expected_chunks
@pytest.mark.parametrize(
"nd_var_chunks, nd_backend_chunks, expected_chunks",
[
(((2, 2, 2, 2),), ((3, 3, 2),), ((3, 3, 2),)),
# ND cases
(((2, 4), (2, 3)), ((2, 2, 2), (3, 2)), ((2, 4), (3, 2))),
],
)
def test_align_nd_chunks(nd_var_chunks, nd_backend_chunks, expected_chunks):
aligned_nd_chunks = align_nd_chunks(
nd_var_chunks=nd_var_chunks,
nd_backend_chunks=nd_backend_chunks,
)
assert aligned_nd_chunks == expected_chunks
@requires_dask
@pytest.mark.parametrize(
"enc_chunks, region, nd_var_chunks, expected_chunks",
[
(
(3,),
(slice(2, 14),),
((6, 6),),
(
(
4,
6,
2,
),
),
),
(
(6,),
(slice(0, 13),),
((6, 7),),
(
(
6,
7,
),
),
),
((6,), (slice(0, 13),), ((6, 6, 1),), ((6, 6, 1),)),
((3,), (slice(2, 14),), ((1, 3, 2, 6),), ((1, 3, 6, 2),)),
((3,), (slice(2, 14),), ((2, 2, 2, 6),), ((4, 6, 2),)),
((3,), (slice(2, 14),), ((3, 1, 3, 5),), ((4, 3, 5),)),
((4,), (slice(1, 13),), ((1, 1, 1, 4, 3, 2),), ((3, 4, 4, 1),)),
((5,), (slice(4, 16),), ((5, 7),), ((6, 6),)),
# ND cases
(
(3, 6),
(slice(2, 14), slice(0, 13)),
((6, 6), (6, 7)),
(
(
4,
6,
2,
),
(
6,
7,
),
),
),
],
)
def test_grid_rechunk(enc_chunks, region, nd_var_chunks, expected_chunks):
dims = [f"dim_{i}" for i in range(len(region))]
coords = {
dim: list(range(r.start, r.stop)) for dim, r in zip(dims, region, strict=False)
}
shape = tuple(r.stop - r.start for r in region)
arr = xr.DataArray(
np.arange(np.prod(shape)).reshape(shape),
dims=dims,
coords=coords,
)
arr = arr.chunk(dict(zip(dims, nd_var_chunks, strict=False)))
result = grid_rechunk(
arr.variable,
enc_chunks=enc_chunks,
region=region,
)
assert result.chunks == expected_chunks
|