1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
|
from __future__ import annotations
import re
from typing import TYPE_CHECKING
import h5py
import zarr
import anndata as ad
from anndata._io.zarr import open_write_group
from anndata.compat import CSArray, CSMatrix, ZarrGroup, is_zarr_v2
from anndata.experimental import read_dispatched, write_dispatched
from anndata.tests.helpers import assert_equal, gen_adata
if TYPE_CHECKING:
from collections.abc import Callable
from pathlib import Path
def test_read_dispatched_w_regex(tmp_path: Path):
    """Read back only the ``obs``/``var`` axis dataframes via a filtering callback.

    The root AnnData group is always read. Any other element is read only when
    its name matches the obs/var pattern; otherwise the callback returns
    ``None``. The result is compared against an AnnData built from just the
    original ``obs`` and ``var``.
    """
    axis_pattern = re.compile(r"^/((obs)|(var))?(/.*)?$")

    def read_only_axis_dfs(func, elem_name: str, elem, iospec):
        # The root must always be materialized so an AnnData comes back.
        keep = (
            iospec.encoding_type == "anndata"
            or axis_pattern.match(elem_name) is not None
        )
        return func(elem) if keep else None

    adata = gen_adata((1000, 100))
    group = open_write_group(tmp_path)
    ad.io.write_elem(group, "/", adata)

    # TODO: see https://github.com/zarr-developers/zarr-python/issues/2716
    if not is_zarr_v2() and isinstance(group, ZarrGroup):
        group = zarr.open(group.store)

    expected = ad.AnnData(obs=adata.obs, var=adata.var)
    actual = read_dispatched(group, read_only_axis_dfs)
    assert_equal(expected, actual)
def test_read_dispatched_dask(tmp_path: Path):
    """Load dense arrays lazily as dask arrays through ``read_dispatched``.

    Dataframes, CSR/CSC matrices and awkward arrays are read eagerly as whole
    elements (the callback does not recurse into their members). Elements
    encoded as ``"array"`` are wrapped with ``da.from_zarr``. Everything else is
    delegated to the default reader. The lazy result must equal a plain read
    once brought into memory.
    """
    import dask.array as da

    # Encodings read in one go instead of recursing into their children.
    read_whole = frozenset({"dataframe", "csr_matrix", "csc_matrix", "awkward-array"})

    def read_as_dask_array(func, elem_name: str, elem, iospec):
        encoding = iospec.encoding_type
        if encoding in read_whole:
            return ad.io.read_elem(elem)
        if encoding == "array":
            return da.from_zarr(elem)
        return func(elem)

    adata = gen_adata((1000, 100))
    group = open_write_group(tmp_path)
    ad.io.write_elem(group, "/", adata)

    # TODO: see https://github.com/zarr-developers/zarr-python/issues/2716
    if not is_zarr_v2() and isinstance(group, ZarrGroup):
        group = zarr.open(group.store)

    dask_adata = read_dispatched(group, read_as_dask_array)

    layer = dask_adata.layers["array"]
    assert isinstance(layer, da.Array)
    obsm_entry = dask_adata.obsm["array"]
    assert isinstance(obsm_entry, da.Array)
    nested = dask_adata.uns["nested"]["nested_further"]["array"]
    assert isinstance(nested, da.Array)

    expected = ad.io.read_elem(group)
    actual = dask_adata.to_memory(copy=False)
    assert_equal(expected, actual)
def test_read_dispatched_null_case(tmp_path: Path):
    """A callback that ignores dispatch and just calls ``read_elem`` matches a plain read."""
    adata = gen_adata((100, 100))
    group = open_write_group(tmp_path)
    ad.io.write_elem(group, "/", adata)

    # TODO: see https://github.com/zarr-developers/zarr-python/issues/2716
    if not is_zarr_v2() and isinstance(group, ZarrGroup):
        group = zarr.open(group.store)

    def read_directly(_read_func, _name, node, **_rest):
        # Skip the supplied reader and read the whole element directly.
        return ad.io.read_elem(node)

    expected = ad.io.read_elem(group)
    actual = read_dispatched(group, read_directly)
    assert_equal(expected, actual)
def test_write_dispatched_chunks(tmp_path: Path):
    """Set per-element chunk shapes through a ``write_dispatched`` callback.

    Array-like elements (excluding sparse matrices and AnnData) get chunks
    chosen from their absolute path:

    - under ``/X`` or ``/layers``: ``(M, N)``
    - under ``/obsp``: ``(M, M)``
    - under ``/obs*``: ``(M,)``
    - under ``/varp``: ``(N, N)``
    - under ``/var*``: ``(N,)``

    Other elements keep any chunks already present in ``dataset_kwargs``. Axes
    without a specified chunk size use the full extent of that axis. The
    written store is then walked to check that arrays under obs*/var* have a
    leading chunk size of ``M``/``N``.
    """
    from itertools import chain, repeat
    from pathlib import PurePosixPath

    M, N = 13, 42

    def determine_chunks(elem_shape, specified_chunks):
        # Pad the requested chunks with ``None`` so unspecified trailing axes
        # fall back to their full length.
        chunk_iterator = chain(specified_chunks, repeat(None))
        return tuple(e if c is None else c for e, c in zip(elem_shape, chunk_iterator))

    adata = gen_adata((1000, 100))

    def write_chunked(func, store, k, elem, dataset_kwargs, iospec):
        # ``k`` may be absolute or relative to ``store`` depending on the zarr
        # version. Joining onto "/" handles both: an absolute ``k`` replaces
        # the prefix, and empty parts (root ``store.path == ""``) are dropped.
        # Plain string concatenation produced e.g. "//X", so the ``X`` branch
        # below could never match.
        path = str(PurePosixPath("/", store.path, k))
        if hasattr(elem, "shape") and not isinstance(
            elem, CSMatrix | CSArray | ad.AnnData
        ):
            if re.match(r"^/((X)|(layers)).*", path):
                chunks = (M, N)
            elif path.startswith("/obsp"):
                chunks = (M, M)
            elif path.startswith("/obs"):
                chunks = (M,)
            elif path.startswith("/varp"):
                chunks = (N, N)
            elif path.startswith("/var"):
                chunks = (N,)
            else:
                chunks = dataset_kwargs.get("chunks", ())
            # Copy rather than mutate: the incoming mapping may be shared/read-only.
            dataset_kwargs = {
                **dataset_kwargs,
                "chunks": determine_chunks(elem.shape, chunks),
            }
        func(store, k, elem, dataset_kwargs=dataset_kwargs)

    z = open_write_group(tmp_path)

    write_dispatched(z, "/", adata, callback=write_chunked)

    # Paths of arrays whose chunking was actually asserted; guards against the
    # walk silently checking nothing.
    checked: list[str] = []

    def check_chunking(k: str, v: ZarrGroup | zarr.Array):
        # Only non-scalar arrays are checked; sparse members are skipped.
        if (
            not isinstance(v, zarr.Array)
            or v.shape == ()
            or k.endswith(("data", "indices", "indptr"))
        ):
            return
        if re.match(r"obs[mp]?/\w+", k):
            assert v.chunks[0] == M
            checked.append(k)
        elif re.match(r"var[mp]?/\w+", k):
            assert v.chunks[0] == N
            checked.append(k)

    if is_zarr_v2():
        z.visititems(check_chunking)
    else:

        def visititems(
            group: ZarrGroup,
            visitor: Callable[[str, ZarrGroup | zarr.Array], None],
            prefix: str = "",
        ) -> None:
            # Like zarr v2's ``Group.visititems``, pass each array's path
            # relative to the starting group (e.g. "obs/_index") rather than
            # only its leaf name, so the obs/var patterns above can match.
            for key in group:
                member = group[key]
                member_path = f"{prefix}{key}"
                if isinstance(member, ZarrGroup):
                    visititems(member, visitor, f"{member_path}/")
                else:
                    visitor(member_path, member)

        visititems(z, check_chunking)

    assert checked, "no obs/var arrays were visited; chunking was not verified"
def test_io_dispatched_keys(tmp_path: Path):
    """Dispatched callbacks must see the same element keys for h5ad and zarr.

    Keys are recorded on write and on read for both backends and compared as
    sorted lists. When ``is_zarr_v2()`` is false, keys are normalized before
    recording: slashes are stripped, and zarr write keys are prefixed with the
    parent group's name. The members of ``X`` (``data``/``indices``/``indptr``)
    must never reach a callback.
    """
    h5ad_write_keys = []
    zarr_write_keys = []
    h5ad_read_keys = []
    zarr_read_keys = []

    h5ad_path = tmp_path / "test.h5ad"
    zarr_path = tmp_path / "test.zarr"

    def h5ad_writer(func, store, k, elem, dataset_kwargs, iospec):
        h5ad_write_keys.append(k if is_zarr_v2() else k.strip("/"))
        func(store, k, elem, dataset_kwargs=dataset_kwargs)

    def zarr_writer(func, store, k, elem, dataset_kwargs, iospec):
        zarr_write_keys.append(
            k if is_zarr_v2() else f"{store.name.strip('/')}/{k.strip('/')}".strip("/")
        )
        func(store, k, elem, dataset_kwargs=dataset_kwargs)

    def h5ad_reader(func, elem_name: str, elem, iospec):
        h5ad_read_keys.append(elem_name if is_zarr_v2() else elem_name.strip("/"))
        return func(elem)

    def zarr_reader(func, elem_name: str, elem, iospec):
        zarr_read_keys.append(elem_name if is_zarr_v2() else elem_name.strip("/"))
        return func(elem)

    adata = gen_adata((50, 100))

    with h5py.File(h5ad_path, "w") as f:
        write_dispatched(f, "/", adata, callback=h5ad_writer)
        _ = read_dispatched(f, h5ad_reader)

    z = open_write_group(zarr_path)
    write_dispatched(z, "/", adata, callback=zarr_writer)
    _ = read_dispatched(z, zarr_reader)

    # Guard against a vacuous pass: the callbacks must actually have fired.
    assert h5ad_write_keys
    assert h5ad_read_keys

    assert sorted(h5ad_read_keys) == sorted(zarr_read_keys)
    assert sorted(h5ad_write_keys) == sorted(zarr_write_keys)

    # Compare slash-agnostically. Without zarr v2 the recorded keys are
    # stripped ("X/data"), so testing only "/X/data" could never fail.
    h5ad_read_paths = {key.strip("/") for key in h5ad_read_keys}
    h5ad_write_paths = {key.strip("/") for key in h5ad_write_keys}
    for sub_sparse_key in ("data", "indices", "indptr"):
        assert f"X/{sub_sparse_key}" not in h5ad_read_paths
        assert f"X/{sub_sparse_key}" not in h5ad_write_paths
|