1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224
|
from __future__ import annotations
from importlib.util import find_spec
from typing import TYPE_CHECKING
import numpy as np
import pandas as pd
import pytest
import zarr
from anndata import AnnData
from anndata.compat import DaskArray
from anndata.experimental import read_elem_lazy, read_lazy
from anndata.io import write_elem
from anndata.tests.helpers import (
GEN_ADATA_NO_XARRAY_ARGS,
AccessTrackingStore,
assert_equal,
gen_adata,
gen_typed_df,
)
from .conftest import ANNDATA_ELEMS
if TYPE_CHECKING:
from collections.abc import Callable
from pathlib import Path
from anndata._types import AnnDataElem
# Module-wide marker: every test in this file is skipped when `find_spec`
# cannot locate the `xarray` package.
pytestmark = pytest.mark.skipif(not find_spec("xarray"), reason="xarray not installed")
@pytest.mark.parametrize(
    ("elem_key", "sub_key"),
    [
        ("raw", "X"),
        ("obs", "cat"),
        ("obs", "int64"),
        *((elem_name, None) for elem_name in ANNDATA_ELEMS),
    ],
)
def test_access_count_elem_access(
    remote_store_tall_skinny: AccessTrackingStore,
    adata_remote_tall_skinny: AnnData,
    elem_key: AnnDataElem,
    sub_key: str | None,
    simple_subset_func: Callable[[AnnData], AnnData],
):
    """Check that reaching an element of a subset does not read any data.

    ``sub_key`` is ``None`` for the cases generated from ``ANNDATA_ELEMS``;
    then only the top-level element ``elem_key`` is accessed. Otherwise the
    sub-element ``elem_key/sub_key`` is accessed as well. In both cases the
    tracked path and ``X`` must show zero accesses afterwards.
    """
    full_path = elem_key if sub_key is None else f"{elem_key}/{sub_key}"
    remote_store_tall_skinny.initialize_key_trackers({full_path, "X"})
    # a series of accesses that should __not__ read in any data
    elem = getattr(simple_subset_func(adata_remote_tall_skinny), elem_key)
    if sub_key is not None:
        if elem_key in {"obs", "var"}:
            # obs/var sub-elements are reached by subscripting
            elem[sub_key]
        else:
            # any other element exposes its sub-element as an attribute
            getattr(elem, sub_key)
    remote_store_tall_skinny.assert_access_count(full_path, 0)
    remote_store_tall_skinny.assert_access_count("X", 0)
def test_access_count_subset(
    remote_store_tall_skinny: AccessTrackingStore,
    adata_remote_tall_skinny: AnnData,
):
    """Subsetting by the ``cat`` obs column reads its codes and no non-obs element."""
    # Materialise the names as a list: they are consumed twice (once to
    # register the trackers, once in the assertion loop below). A lazy
    # ``filter`` object would be exhausted by the first use, so the loop would
    # run zero times and the zero-access assertions would be skipped silently.
    non_obs_elem_names = [name for name in ANNDATA_ELEMS if name != "obs"]
    remote_store_tall_skinny.initialize_key_trackers([
        "obs/cat/codes",
        *non_obs_elem_names,
    ])
    adata_remote_tall_skinny[adata_remote_tall_skinny.obs["cat"] == "a", :]
    # all codes read in for subset (from 4 chunks as set in the fixture)
    remote_store_tall_skinny.assert_access_count("obs/cat/codes", 4)
    for elem_name in non_obs_elem_names:
        remote_store_tall_skinny.assert_access_count(elem_name, 0)
def test_access_count_subset_column_compute(
    remote_store_tall_skinny: AccessTrackingStore,
    adata_remote_tall_skinny: AnnData,
):
    """Computing ``obs["int64"]`` on a single-row subset costs one store access."""
    tracked_key = "obs/int64"
    remote_store_tall_skinny.initialize_key_trackers([tracked_key])
    # index one row in the middle of the obs axis, keeping every variable
    middle_row = adata_remote_tall_skinny.shape[0] // 2
    single_row_view = adata_remote_tall_skinny[middle_row, :]
    single_row_view.obs["int64"].compute()
    # the single-row subset is expected to need exactly one access of the column
    remote_store_tall_skinny.assert_access_count(tracked_key, 1)
def test_access_count_index(
    remote_store_tall_skinny: AccessTrackingStore,
):
    """``read_lazy`` touches ``obs/_index`` only when the annotation index is loaded."""
    index_key = "obs/_index"
    remote_store_tall_skinny.initialize_key_trackers([index_key])

    # opting out of loading the annotation index must leave the index unread
    read_lazy(remote_store_tall_skinny, load_annotation_index=False)
    remote_store_tall_skinny.assert_access_count(index_key, 0)

    # the default call reads it; 4 is the number of chunks
    read_lazy(remote_store_tall_skinny)
    remote_store_tall_skinny.assert_access_count(index_key, 4)
def test_access_count_dtype(
    remote_store_tall_skinny: AccessTrackingStore,
    adata_remote_tall_skinny: AnnData,
):
    """Repeated ``dtype`` lookups on ``obs["cat"]`` read the categories only once."""
    categories_key = "obs/cat/categories"
    remote_store_tall_skinny.initialize_key_trackers([categories_key])
    remote_store_tall_skinny.assert_access_count(categories_key, 0)
    # Ask for the dtype three times; this should only cause categories to be
    # read in once
    for _ in range(3):
        adata_remote_tall_skinny.obs["cat"].dtype  # noqa: B018
    remote_store_tall_skinny.assert_access_count(categories_key, 1)
def test_uns_uses_dask(adata_remote: AnnData):
    """The array nested two levels deep inside ``uns`` must be a ``DaskArray``."""
    nested_array = adata_remote.uns["nested"]["nested_further"]["array"]
    assert isinstance(nested_array, DaskArray)
def test_to_memory(adata_remote: AnnData, adata_orig: AnnData):
    """``to_memory`` on the remote object must compare equal to the original."""
    assert_equal(adata_remote.to_memory(), adata_orig)
def test_access_counts_obsm_df(tmp_path: Path):
    """Lazily opening a store must not access a dataframe stored under ``obsm``."""
    n_obs, n_var = 100, 20
    adata = AnnData(X=np.array(np.random.rand(n_obs, n_var)))
    # two random float columns, indexed like the observations
    columns = {name: np.random.rand(n_obs) for name in ("col1", "col2")}
    adata.obsm["df"] = pd.DataFrame(columns, index=adata.obs_names)
    adata.write_zarr(tmp_path)

    tracked_key = "obsm/df"
    store = AccessTrackingStore(tmp_path)
    store.initialize_key_trackers([tracked_key])
    read_lazy(store, load_annotation_index=False)
    store.assert_access_count(tracked_key, 0)
def test_view_to_memory(adata_remote: AnnData, adata_orig: AnnData):
    """A boolean subset along one axis of the remote object matches the original's."""
    # obs axis: keep the rows whose ``obs_cat`` equals the first category
    first_obs_cat = adata_orig.obs["obs_cat"].cat.categories[0]
    obs_mask = adata_orig.obs["obs_cat"] == first_obs_cat
    expected_rows = adata_orig[obs_mask, :]
    assert_equal(expected_rows, adata_remote[obs_mask, :].to_memory())

    # var axis: keep the columns whose ``var_cat`` equals the first category
    first_var_cat = adata_orig.var["var_cat"].cat.categories[0]
    var_mask = adata_orig.var["var_cat"] == first_var_cat
    expected_cols = adata_orig[:, var_mask]
    assert_equal(expected_cols, adata_remote[:, var_mask].to_memory())
def test_view_of_view_to_memory(adata_remote: AnnData, adata_orig: AnnData):
    """Subsetting a subset of the remote object matches doing the same on the original."""
    # obs axis: keep the first two categories, then narrow down to the second
    obs_categories = adata_orig.obs["obs_cat"].cat.categories
    is_first_obs = adata_orig.obs["obs_cat"] == obs_categories[0]
    is_second_obs = adata_orig.obs["obs_cat"] == obs_categories[1]
    outer_obs_mask = is_first_obs | is_second_obs
    outer_obs_view = adata_orig[outer_obs_mask, :]
    inner_obs_mask = outer_obs_view.obs["obs_cat"] == obs_categories[1]
    expected_obs = outer_obs_view[inner_obs_mask, :]
    assert_equal(
        expected_obs,
        adata_remote[outer_obs_mask, :][inner_obs_mask, :].to_memory(),
    )

    # var axis: same two-step narrowing, applied to the columns
    var_categories = adata_orig.var["var_cat"].cat.categories
    is_first_var = adata_orig.var["var_cat"] == var_categories[0]
    is_second_var = adata_orig.var["var_cat"] == var_categories[1]
    outer_var_mask = is_first_var | is_second_var
    outer_var_view = adata_orig[:, outer_var_mask]
    inner_var_mask = outer_var_view.var["var_cat"] == var_categories[1]
    expected_var = outer_var_view[:, inner_var_mask]
    assert_equal(
        expected_var,
        adata_remote[:, outer_var_mask][:, inner_var_mask].to_memory(),
    )
@pytest.mark.zarr_io
def test_unconsolidated(tmp_path: Path, mtx_format):
    """A zarr store without ``.zmetadata`` warns on read but still round-trips."""
    adata = gen_adata((10, 10), mtx_format, **GEN_ADATA_NO_XARRAY_ARGS)
    zarr_path = tmp_path / "orig.zarr"
    adata.write_zarr(zarr_path)
    # delete ``.zmetadata`` so the read below triggers the
    # "not consolidated" warning
    metadata_file = zarr_path / ".zmetadata"
    metadata_file.unlink()

    store = AccessTrackingStore(zarr_path)
    store.initialize_key_trackers(["obs/.zgroup", ".zgroup"])
    with pytest.warns(UserWarning, match=r"Did not read zarr as consolidated"):
        remote = read_lazy(store)
    assert_equal(remote.to_memory(), adata)
    store.assert_access_count("obs/.zgroup", 1)
def test_h5_file_obj(tmp_path: Path):
    """Lazily reading an ``.h5ad`` path leaves the file open and records the path."""
    adata = gen_adata((10, 10), **GEN_ADATA_NO_XARRAY_ARGS)
    h5ad_path = tmp_path / "adata.h5ad"
    adata.write_h5ad(h5ad_path)

    lazy_adata = read_lazy(h5ad_path)
    # the backing file must still be open and must point at what was written
    assert lazy_adata.file.is_open
    assert lazy_adata.filename == h5ad_path
    assert_equal(lazy_adata.to_memory(), adata)
@pytest.fixture(scope="session")
def df_group(tmp_path_factory) -> zarr.Group:
    """Write ``gen_typed_df(120)`` to a zarr v2 store (chunks of 25) and return its group.

    The store is written once per session and re-opened read-only before the
    ``"foo"`` entry is handed to the tests.
    """
    typed_df = gen_typed_df(120)
    store_path = tmp_path_factory.mktemp("foo.zarr")
    root = zarr.open_group(store_path, mode="w", zarr_format=2)
    write_elem(root, "foo", typed_df, dataset_kwargs={"chunks": 25})
    read_only_root = zarr.open(store_path, mode="r")
    return read_only_root["foo"]
@pytest.mark.parametrize(
    ("chunks", "expected_chunks"),
    [((1,), (1,)), ((-1,), (120,)), (None, (25,))],
    ids=["small", "minus_one_uses_full", "none_uses_ondisk_chunking"],
)
def test_chunks_df(
    tmp_path: Path,
    chunks: tuple[int] | None,
    expected_chunks: tuple[int],
    df_group: zarr.Group,
):
    """Every dask-backed column read via ``read_elem_lazy`` has the expected chunk size.

    The parametrization covers an explicit chunk size, ``-1`` and ``None``.
    """
    ds = read_elem_lazy(df_group, chunks=chunks)
    for key in ds:
        column_data = ds[key].data
        # only dask-backed columns carry a chunk size to check
        if not isinstance(column_data, DaskArray):
            continue
        assert column_data.chunksize == expected_chunks
|