1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
|
from __future__ import annotations
import functools
import traceback
import pytest
h5py = pytest.importorskip("h5py")
from dask.utils import tmpfile
from distributed.protocol import deserialize, serialize
def silence_h5py_issue775(func):
    """Decorate ``func`` so the spurious error from h5py issue #775 is ignored.

    The wrapper calls ``func`` with whatever arguments it receives and returns
    its result.  A ``RuntimeError`` is swallowed (and ``None`` returned) only
    when both of the following hold:

    * its message is exactly ``"dictionary changed size during iteration"``;
    * the innermost traceback frame comes from a file whose path ends with
      ``h5py/_objects.pyx``.

    Every other exception, including any other ``RuntimeError``, propagates
    unchanged.

    Parameters
    ----------
    func : callable
        The function to protect.

    Returns
    -------
    callable
        A wrapper carrying ``func``'s metadata (via ``functools.wraps``).
    """
    expected_message = "dictionary changed size during iteration"
    pyx_suffix = "h5py/_objects.pyx"

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except RuntimeError as e:
            # https://github.com/h5py/h5py/issues/775
            if str(e) != expected_message:
                raise
            # The last entry is the frame in which the exception was raised.
            innermost = traceback.extract_tb(e.__traceback__)[-1]
            # Compare using forward slashes so that a backslash-separated
            # path to the same file is recognised as well.
            if not innermost.filename.replace("\\", "/").endswith(pyx_suffix):
                raise
            return None

    return wrapper
@silence_h5py_issue775
def test_serialize_deserialize_file():
    """Round-trip a read-only ``h5py.File`` through serialize/deserialize."""
    expected_shape = (2, 2)
    with tmpfile() as path:
        # Populate the file first; the handle under test is reopened below.
        with h5py.File(path, mode="a") as writable:
            writable.create_dataset("/x", shape=expected_shape, dtype="i4")

        with h5py.File(path, mode="r") as source:
            payload = serialize(source)
            restored = deserialize(*payload)

            # The copy must refer to the same file, with the same type and mode.
            assert source.filename == restored.filename
            assert isinstance(restored, h5py.File)
            assert source.mode == restored.mode

            # ... and its contents must be reachable through the copy.
            assert restored["x"].shape == expected_shape
@silence_h5py_issue775
def test_serialize_deserialize_group():
    """Round-trip a nested ``h5py.Group`` through serialize/deserialize."""
    expected_shape = (2, 2)
    with tmpfile() as path:
        # Creating the dataset also creates the intermediate groups used below.
        with h5py.File(path, mode="a") as writable:
            writable.create_dataset(
                "/group1/group2/x", shape=expected_shape, dtype="i4"
            )

        with h5py.File(path, mode="r") as readable:
            source_group = readable["/group1/group2"]
            payload = serialize(source_group)
            restored_group = deserialize(*payload)

            assert isinstance(restored_group, h5py.Group)
            assert source_group.file.filename == restored_group.file.filename

            # The dataset must be reachable relative to the restored group.
            assert restored_group["x"].shape == expected_shape
@silence_h5py_issue775
def test_serialize_deserialize_dataset():
    """Round-trip an ``h5py.Dataset`` through serialize/deserialize."""
    with tmpfile() as path:
        with h5py.File(path, mode="a") as writable:
            dataset = writable.create_dataset(
                "/group1/group2/x", shape=(2, 2), dtype="i4"
            )

        with h5py.File(path, mode="r") as readable:
            dataset = readable["group1/group2/x"]
            payload = serialize(dataset)
            restored = deserialize(*payload)

            assert isinstance(restored, h5py.Dataset)
            # Same location inside the same file ...
            assert dataset.name == restored.name
            assert dataset.file.filename == restored.file.filename
            # ... and the same element values.
            assert (dataset[:] == restored[:]).all()
@silence_h5py_issue775
def test_raise_error_on_serialize_write_permissions():
    """Expect ``TypeError`` when round-tripping objects of a mode="a" file."""
    with tmpfile() as path, h5py.File(path, mode="a") as writable:
        dataset = writable.create_dataset("/x", shape=(2, 2), dtype="i4")
        writable.flush()

        # Check the dataset first, then the file handle itself.
        for unserializable in (dataset, writable):
            with pytest.raises(TypeError):
                deserialize(*serialize(unserializable))
import dask.array as da
from distributed.utils_test import gen_cluster
@silence_h5py_issue775
@gen_cluster(client=True)
async def test_h5py_serialize(c, s, a, b):
    """Compute a dask array built on a read-only h5py dataset via the client.

    The array is created with a ``SerializableLock`` named ``"hdf5"`` and the
    dataset's own chunking; the computed result must equal the dataset.
    """
    from dask.utils import SerializableLock

    values = [1, 2, 3, 4]
    hdf5_lock = SerializableLock("hdf5")
    with tmpfile() as path:
        with h5py.File(path, mode="a") as writable:
            arr = writable.create_dataset(
                "/group/x", shape=(4,), dtype="i4", chunks=(2,)
            )
            arr[:] = values

        with h5py.File(path, mode="r") as readable:
            dset = readable["/group/x"]
            arr = da.from_array(dset, chunks=dset.chunks, lock=hdf5_lock)
            result = c.compute(arr)
            result = await result
            assert (result[:] == dset[:]).all()
@gen_cluster(client=True)
async def test_h5py_serialize_2(c, s, a, b):
    """Sum a dask array whose chunks (3) differ from the dataset's chunks (4)."""
    values = [1, 2, 3, 4] * 3
    with tmpfile() as path:
        with h5py.File(path, mode="a") as writable:
            arr = writable.create_dataset(
                "/group/x", shape=(12,), dtype="i4", chunks=(4,)
            )
            arr[:] = values

        with h5py.File(path, mode="r") as readable:
            dset = readable["/group/x"]
            arr = da.from_array(dset, chunks=(3,))
            total = c.compute(arr.sum())
            total = await total
            # The sum computed through the client must match the written data.
            assert total == sum(values)
|