import json
import numbers
from typing import Any
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from zarr.core.buffer import default_buffer_prototype
pytest.importorskip("hypothesis")
import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
from hypothesis import assume, given, settings
from zarr.abc.store import Store
from zarr.core.common import ZARR_JSON, ZARRAY_JSON, ZATTRS_JSON
from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
from zarr.core.sync import sync
from zarr.testing.strategies import (
array_metadata,
arrays,
basic_indices,
numpy_arrays,
orthogonal_indices,
simple_arrays,
stores,
zarr_formats,
)
def deep_equal(a: Any, b: Any) -> bool:
"""Deep equality check with handling of special cases for array metadata classes"""
if isinstance(a, (complex, np.complexfloating)) and isinstance(
b, (complex, np.complexfloating)
):
a_real, a_imag = float(a.real), float(a.imag)
b_real, b_imag = float(b.real), float(b.imag)
if np.isnan(a_real) and np.isnan(b_real):
real_eq = True
else:
real_eq = a_real == b_real
if np.isnan(a_imag) and np.isnan(b_imag):
imag_eq = True
else:
imag_eq = a_imag == b_imag
return real_eq and imag_eq
if isinstance(a, (float, np.floating)) and isinstance(b, (float, np.floating)):
if np.isnan(a) and np.isnan(b):
return True
return a == b
if isinstance(a, np.datetime64) and isinstance(b, np.datetime64):
if np.isnat(a) and np.isnat(b):
return True
return a == b
if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
if a.shape != b.shape:
return False
return all(deep_equal(x, y) for x, y in zip(a.flat, b.flat, strict=False))
if isinstance(a, dict) and isinstance(b, dict):
if set(a.keys()) != set(b.keys()):
return False
return all(deep_equal(a[k], b[k]) for k in a)
if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
if len(a) != len(b):
return False
return all(deep_equal(x, y) for x, y in zip(a, b, strict=False))
return a == b
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@given(data=st.data())
def test_array_roundtrip(data: st.DataObject) -> None:
    """A numpy array written into a zarr array must read back unchanged."""
    expected = data.draw(numpy_arrays())
    created = data.draw(arrays(arrays=st.just(expected)))
    assert_array_equal(expected, created[:])
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@given(array=arrays())
def test_array_creates_implicit_groups(array) -> None:
    """
    Creating an array at a nested path must also create metadata for every
    implicit parent group along that path.

    For each ancestor prefix of the array's path, assert that the store holds
    the version-appropriate group metadata document (``.zgroup`` for zarr v2,
    ``zarr.json`` for zarr v3).
    """
    path = array.path
    # all ancestor components, excluding the array's own name
    ancestry = path.split("/")[:-1]
    for i in range(len(ancestry)):
        # the i-th ancestor prefix: "a", then "a/b", then "a/b/c", ...
        parent = "/".join(ancestry[: i + 1])
        if array.metadata.zarr_format == 2:
            assert (
                sync(array.store.get(f"{parent}/.zgroup", prototype=default_buffer_prototype()))
                is not None
            )
        elif array.metadata.zarr_format == 3:
            assert (
                sync(array.store.get(f"{parent}/zarr.json", prototype=default_buffer_prototype()))
                is not None
            )
# this decorator removes timeout; not ideal but it should avoid intermittent CI failures
@pytest.mark.asyncio
@settings(deadline=None)
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@given(data=st.data())
async def test_basic_indexing(data: st.DataObject) -> None:
    """Basic indices must behave identically on zarr and numpy, for reads and writes."""
    zarray = data.draw(simple_arrays())
    reference = zarray[:]
    idx = data.draw(basic_indices(shape=reference.shape))

    # sync read must agree with numpy
    fetched = zarray[idx]
    assert_array_equal(reference[idx], fetched)

    # async read must agree as well
    assert_array_equal(reference[idx], await zarray._async_array.getitem(idx))

    # sync write: apply the same replacement to both and compare the whole array
    replacement = data.draw(numpy_arrays(shapes=st.just(fetched.shape), dtype=reference.dtype))
    zarray[idx] = replacement
    reference[idx] = replacement
    assert_array_equal(reference, zarray[:])

    # TODO test async setitem?
@pytest.mark.asyncio
@given(data=st.data())
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
async def test_oindex(data: st.DataObject) -> None:
    """Orthogonal indexing must match numpy for reads, and round-trip writes."""
    # integer_array_indices can't handle 0-size dimensions.
    zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
    reference = zarray[:]
    zidx, npidx = data.draw(orthogonal_indices(shape=reference.shape))

    # sync read
    fetched = zarray.oindex[zidx]
    assert_array_equal(reference[npidx], fetched)

    # async read
    assert_array_equal(reference[npidx], await zarray._async_array.oindex.getitem(zidx))

    # sync write (async oindex setitem not yet implemented)
    assume(zarray.shards is None)  # GH2834
    # behaviour of setitem with repeated indices is not guaranteed in practice
    assume(
        all(
            not (isinstance(part, np.ndarray) and part.size != np.unique(part).size)
            for part in npidx
        )
    )
    replacement = data.draw(numpy_arrays(shapes=st.just(fetched.shape), dtype=reference.dtype))
    reference[npidx] = replacement
    zarray.oindex[zidx] = replacement
    assert_array_equal(reference, zarray[:])
@pytest.mark.asyncio
@given(data=st.data())
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
async def test_vindex(data: st.DataObject) -> None:
    """Vectorized (fancy) indexing must match numpy for sync and async reads."""
    # integer_array_indices can't handle 0-size dimensions.
    zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
    reference = zarray[:]
    idx = data.draw(
        npst.integer_array_indices(
            shape=reference.shape, result_shape=npst.array_shapes(min_side=1, max_dims=None)
        )
    )

    # sync read
    assert_array_equal(reference[idx], zarray.vindex[idx])

    # async read
    assert_array_equal(reference[idx], await zarray._async_array.vindex.getitem(idx))

    # sync set
    # FIXME!
    # when the indexer is such that a value gets overwritten multiple times,
    # I think the output depends on chunking.
    # new_data = data.draw(npst.arrays(shape=st.just(actual.shape), dtype=nparray.dtype))
    # nparray[indexer] = new_data
    # zarray.vindex[indexer] = new_data
    # assert_array_equal(nparray, zarray[:])

    # note: async vindex setitem not yet implemented
# Fix: this async test was the only one in the file missing `@pytest.mark.asyncio`;
# under pytest-asyncio's strict mode an unmarked coroutine test is not executed.
@pytest.mark.asyncio
@given(store=stores, meta=array_metadata())  # type: ignore[misc]
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
async def test_roundtrip_array_metadata_from_store(
    store: Store, meta: ArrayV2Metadata | ArrayV3Metadata
) -> None:
    """
    Verify that the I/O for metadata in a store are lossless.

    This test serializes an ArrayV2Metadata or ArrayV3Metadata object to a dict
    of buffers via `to_buffer_dict`, writes each buffer to a store under keys
    prefixed with "0/", and then reads them back. The test asserts that each
    retrieved buffer exactly matches the original buffer.
    """
    asdict = meta.to_buffer_dict(prototype=default_buffer_prototype())
    for key, expected in asdict.items():
        await store.set(f"0/{key}", expected)
        actual = await store.get(f"0/{key}", prototype=default_buffer_prototype())
        assert actual == expected
@given(data=st.data(), zarr_format=zarr_formats)
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
def test_roundtrip_array_metadata_from_json(data: st.DataObject, zarr_format: int) -> None:
    """
    Verify that JSON serialization and deserialization of metadata is lossless.

    For Zarr v2 the metadata lives in two JSON documents (array data and
    attributes), which must be merged before `from_dict`; for Zarr v3 a single
    `zarr.json` document holds everything. The original and round-tripped
    metadata are compared via `deep_equal` so that special values (NaN,
    Infinity, complex numbers, datetimes) compare correctly.
    """
    metadata = data.draw(array_metadata(zarr_formats=st.just(zarr_format)))
    buffers = metadata.to_buffer_dict(prototype=default_buffer_prototype())

    def _decode(key: str) -> dict:
        # each buffer is a UTF-8 encoded JSON document
        return json.loads(buffers[key].to_bytes().decode())

    if zarr_format == 2:
        # zattrs and zarray are separate in v2, we have to add attributes back prior to `from_dict`
        document = _decode(ZARRAY_JSON)
        document["attributes"] = _decode(ZATTRS_JSON)
        roundtripped = ArrayV2Metadata.from_dict(document)
    else:
        roundtripped = ArrayV3Metadata.from_dict(_decode(ZARR_JSON))

    expected = metadata.to_dict()
    observed = roundtripped.to_dict()
    assert deep_equal(expected, observed), (
        f"Roundtrip mismatch:\nOriginal: {expected}\nRoundtripped: {observed}"
    )
# @st.composite
# def advanced_indices(draw, *, shape):
# basic_idxr = draw(
# basic_indices(
# shape=shape, min_dims=len(shape), max_dims=len(shape), allow_ellipsis=False
# ).filter(lambda x: isinstance(x, tuple))
# )
# int_idxr = draw(
# npst.integer_array_indices(shape=shape, result_shape=npst.array_shapes(max_dims=1))
# )
# args = tuple(
# st.sampled_from((l, r)) for l, r in zip_longest(basic_idxr, int_idxr, fillvalue=slice(None))
# )
# return draw(st.tuples(*args))
# @given(st.data())
# def test_roundtrip_object_array(data):
# nparray = data.draw(np_arrays)
# zarray = data.draw(arrays(arrays=st.just(nparray)))
# assert_array_equal(nparray, zarray[:])
def serialized_complex_float_is_valid(
    serialized: tuple[numbers.Real | str, numbers.Real | str],
) -> bool:
    """
    Validate that the serialized representation of a complex float conforms to the spec.

    The specification requires that a complex float be serialized as a
    two-element sequence ``(real, imag)``, where each component must be either:
    - A JSON number, or
    - One of the strings "NaN", "Infinity", or "-Infinity".

    Args:
        serialized: The value produced by JSON serialization for a complex floating point number.

    Returns:
        bool: True if the serialized value is valid according to the spec, False otherwise.
    """
    return (
        isinstance(serialized, tuple)
        and len(serialized) == 2
        and all(serialized_float_is_valid(x) for x in serialized)
    )
def serialized_float_is_valid(serialized: numbers.Real | str) -> bool:
"""
Validate that the serialized representation of a float conforms to the spec.
The specification requires that a serialized float must be either:
- A JSON number, or
- One of the strings "NaN", "Infinity", or "-Infinity".
Args:
serialized: The value produced by JSON serialization for a floating point number.
Returns:
bool: True if the serialized value is valid according to the spec, False otherwise.
"""
if isinstance(serialized, numbers.Real):
return True
return serialized in ("NaN", "Infinity", "-Infinity")
@given(meta=array_metadata())  # type: ignore[misc]
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
def test_array_metadata_meets_spec(meta: ArrayV2Metadata | ArrayV3Metadata) -> None:
    """
    Validate that array metadata produced by the library conforms to the relevant spec.

    For ArrayV2Metadata: 'zarr_format' must be 2, and 'filters' must be None or
    a non-empty tuple. For ArrayV3Metadata: 'zarr_format' must be 3.

    For both versions, fill-value serialization is checked against the spec:
    floats must serialize to a JSON number or one of "NaN"/"Infinity"/"-Infinity";
    complex fill values must have both components satisfy the float rules; and
    datetime/timedelta `NaT` must serialize to the int64 sentinel.

    Note:
        This test validates spec-compliance for array metadata serialization.
        It is a work-in-progress and should be expanded as further edge cases are identified.
    """
    serialized = meta.to_dict()

    # version-specific validations
    if isinstance(meta, ArrayV2Metadata):
        filters = serialized["filters"]
        assert filters != ()
        assert filters is None or isinstance(filters, tuple)
        assert serialized["zarr_format"] == 2
    else:
        assert serialized["zarr_format"] == 3

    # version-agnostic validations, keyed on the native dtype's kind code
    kind = meta.dtype.to_native_dtype().kind
    if kind == "f":
        assert serialized_float_is_valid(serialized["fill_value"])
    elif kind == "c":
        # fill_value should be a two-element array [real, imag].
        assert serialized_complex_float_is_valid(serialized["fill_value"])
    elif kind in ("M", "m") and np.isnat(meta.fill_value):
        # NaT serializes to the int64 minimum sentinel value
        assert serialized["fill_value"] == -9223372036854775808