1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
|
from __future__ import annotations
from contextlib import AbstractContextManager, suppress
from typing import TYPE_CHECKING
import h5py
import numpy as np
import pandas as pd
import pytest
import zarr
import anndata as ad
from anndata._io.specs.registry import IORegistryError
from anndata._io.utils import report_read_key_on_error
from anndata.compat import _clean_uns
if TYPE_CHECKING:
from collections.abc import Callable
from pathlib import Path
@pytest.mark.parametrize(
    "group_fn",
    [
        pytest.param(lambda _: zarr.group(), id="zarr"),
        pytest.param(lambda p: h5py.File(p / "test.h5", mode="a"), id="h5py"),
    ],
)
@pytest.mark.parametrize("nested", [True, False], ids=["nested", "root"])
def test_key_error(
    *, tmp_path, group_fn: Callable[[Path], zarr.Group | h5py.Group], nested: bool
):
    """Check that ``report_read_key_on_error`` names the failing key and its parent.

    A reader that always raises ``NotImplementedError`` is applied to an array
    element and to a sub-group, both stored either at the root of the store or
    inside a ``"nested"`` group. The raised error must mention the element's
    key and end with the path of the group containing it.
    """

    @report_read_key_on_error
    def always_fails(_elem):
        raise NotImplementedError()

    store = group_fn(tmp_path)
    # Only some backends hand back a context manager; fall back to a no-op
    # manager (``suppress()`` with no exception types) for the others.
    scope = store if isinstance(store, AbstractContextManager) else suppress()
    with scope:
        parent = store.create_group("nested") if nested else store
        path = "/nested" if nested else "/"

        parent["X"] = np.array([1, 2, 3])
        parent.create_group("group")

        # Array element.
        expect_x = pytest.raises(
            NotImplementedError, match=rf"reading key 'X'.*from {path}$"
        )
        with expect_x:
            always_fails(parent["X"])

        # Group element.
        expect_group = pytest.raises(
            NotImplementedError, match=rf"reading key 'group'.*from {path}$"
        )
        with expect_group:
            always_fails(parent["group"])
def test_write_error_info(diskfmt, tmp_path):
    """Check that a failed write reports the offending key and its location.

    ``diskfmt`` selects which ``write_<diskfmt>`` method of ``AnnData`` is
    exercised; the output goes to a file under ``tmp_path``.
    """
    target = tmp_path / f"failed_write.{diskfmt}"

    # Assuming we don't define a writer for tuples
    adata = ad.AnnData(uns={"a": {"b": {"c": (1, 2, 3)}}})

    expected_message = r"Error raised while writing key 'c'.*to /uns/a/b"
    with pytest.raises(IORegistryError, match=expected_message):
        # Look up the format-specific writer by name and invoke it.
        getattr(adata, f"write_{diskfmt}")(target)
def test_clean_uns():
    """Check what ``_clean_uns`` does with a ``species_categories`` entry in ``uns``.

    After the call, the entry is gone from ``uns``, ``obs["species"]`` has
    become categorical with the labels applied, and ``var["species"]`` is
    still integer-typed.
    """
    obs_frame = pd.DataFrame({"species": [0, 1, 0]}, index=["a", "b", "c"])
    var_frame = pd.DataFrame({"species": [0, 1, 0, 2]}, index=["a", "b", "c", "d"])
    adata = ad.AnnData(
        uns={"species_categories": ["a", "b"]},
        obs=obs_frame,
        var=var_frame,
    )

    _clean_uns(adata)

    assert "species_categories" not in adata.uns

    obs_species = adata.obs["species"]
    assert isinstance(obs_species.dtype, pd.CategoricalDtype)
    assert obs_species.tolist() == ["a", "b", "a"]

    # var's categories were overwritten by obs's,
    # which we can detect here because var has too high codes
    assert pd.api.types.is_integer_dtype(adata.var["species"])
@pytest.mark.parametrize(
    "group_fn",
    [
        pytest.param(lambda _: zarr.group(), id="zarr"),
        pytest.param(lambda p: h5py.File(p / "test.h5", mode="a"), id="h5py"),
    ],
)
def test_only_child_key_reported_on_failure(tmp_path, group_fn):
    """Check that a nested read/write failure is not attributed to the parent key.

    A failure is provoked on element ``b`` inside mapping ``a``, once while
    writing and once while reading. In both cases the error text must not
    contain an "Error raised while ... key 'a'" message for the parent.
    """

    # Arbitrary local type; writing an instance is expected to raise
    # IORegistryError (see the first ``pytest.raises`` below).
    class Foo:
        pass

    group = group_fn(tmp_path)

    # This regex checks that the pattern inside the (?!...) group does not
    # exist anywhere in the string:
    #   (?!...) is a negative lookahead
    #   (?s) enables the dot to match newlines
    # See https://stackoverflow.com/a/406408/130164
    # Both "reading" and "writing" are covered so that the read check below is
    # not vacuous (a read failure never contains a "writing key" message).
    pattern = r"(?s)^((?!Error raised while (?:reading|writing) key '/?a').)*$"

    # Close the store when the backend supports it (e.g. the h5py file);
    # ``suppress()`` with no exception types serves as a no-op manager
    # otherwise, mirroring ``test_key_error``.
    with group if isinstance(group, AbstractContextManager) else suppress():
        with pytest.raises(IORegistryError, match=pattern):
            ad.io.write_elem(group, "/", {"a": {"b": Foo()}})

        # Write something valid, then corrupt the child's encoding metadata
        # so that reading it back fails.
        ad.io.write_elem(group, "/", {"a": {"b": [1, 2, 3]}})
        group["a/b"].attrs["encoding-type"] = "not a real encoding type"

        with pytest.raises(IORegistryError, match=pattern):
            ad.io.read_elem(group)
|