File: test_io_partial.py

package info (click to toggle)
python-anndata 0.12.0~rc1-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 2,704 kB
  • sloc: python: 19,721; makefile: 22; sh: 14
file content (100 lines) | stat: -rw-r--r-- 3,402 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from __future__ import annotations

import warnings
from importlib.util import find_spec
from pathlib import Path

import h5py
import numpy as np
import pytest
import zarr
from scipy.sparse import csr_matrix

import anndata
from anndata import AnnData
from anndata._io.specs.registry import read_elem_partial
from anndata.io import read_elem, write_h5ad, write_zarr

# Small 3x3 float32 fixture matrix that the tests write to disk.
X = np.array(
    [
        [1.0, 0.0, 3.0],
        [4.0, 0.0, 6.0],
        [0.0, 8.0, 0.0],
    ],
    dtype="float32",
)
# Expected values after selecting rows [1, 2] and columns [0, 1] of ``X``.
X_check = np.array(
    [
        [4.0, 0.0],
        [0.0, 8.0],
    ],
    dtype="float32",
)

# Keyed by the ``diskfmt`` value each test receives: the function that writes
# an AnnData to that format, and the callable that opens the result read-only.
WRITER = {"h5ad": write_h5ad, "zarr": write_zarr}
READER = {"h5ad": h5py.File, "zarr": zarr.open}


@pytest.mark.parametrize("typ", [np.asarray, csr_matrix])
def test_read_partial_X(tmp_path, typ, diskfmt):
    """Read a row/column subset of ``X`` back from disk and check its values.

    ``X`` is stored dense (``np.asarray``) or as a CSR matrix in the
    ``diskfmt`` format. Rows ``[1, 2]`` and columns ``[0, 1]`` are then read
    back with ``read_elem_partial``, and the result must equal ``X_check``.
    """
    adata = AnnData(X=typ(X))

    path = Path(tmp_path) / f"test_tp_X.{diskfmt}"
    WRITER[diskfmt](path, adata)

    store = READER[diskfmt](path, mode="r")
    try:
        if diskfmt == "zarr":
            X_part = read_elem_partial(store["X"], indices=([1, 2], [0, 1]))
        else:
            # h5py doesn't allow fancy indexing across multiple dimensions,
            # so rows are selected on disk and columns in memory.
            X_part = read_elem_partial(store["X"], indices=([1, 2],))
            X_part = X_part[:, [0, 1]]
    finally:
        # Release the h5py handle even if the read above fails; previously it
        # was only closed on success and leaked otherwise.
        if diskfmt != "zarr":
            store.close()

    assert np.all(X_check == X_part)


@pytest.mark.skipif(not find_spec("scanpy"), reason="Scanpy is not installed")
def test_read_partial_adata(tmp_path, diskfmt):
    """Partially read elements of a written AnnData and compare to subsetting.

    Scanpy's ``pbmc68k_reduced`` dataset is written in the ``diskfmt`` format.
    ``X``, ``obs``, ``var``, every ``obsm``/``varm``/``obsp`` entry, and the
    top-level ``uns`` keys are then read back, partially where applicable.
    Each result is compared against the same subset taken from the in-memory
    object.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", message=r"Importing read_.* from `anndata` is deprecated"
        )
        import scanpy as sc

    adata = sc.datasets.pbmc68k_reduced()
    # zarr v3 can't write recarray
    # https://github.com/zarr-developers/zarr-python/issues/2134
    if anndata.settings.zarr_write_format == 3 and isinstance(adata, AnnData):
        del adata.uns["rank_genes_groups"]["scores"]
        del adata.uns["rank_genes_groups"]["names"]

    path = Path(tmp_path) / f"test_rp.{diskfmt}"
    WRITER[diskfmt](path, adata)

    storage = READER[diskfmt](path, mode="r")
    # Everything that touches ``storage`` runs inside try/finally so the h5py
    # file handle is released; the original never closed it at all.
    try:
        obs_idx = [1, 2]
        var_idx = [0, 3]
        adata_sbs = adata[obs_idx, var_idx]

        if diskfmt == "zarr":
            part = read_elem_partial(storage["X"], indices=(obs_idx, var_idx))
        else:
            # h5py doesn't allow fancy indexing across multiple dimensions
            part = read_elem_partial(storage["X"], indices=(obs_idx,))
            part = part[:, var_idx]
        assert np.all(part == adata_sbs.X)

        # assert_array_equal reports a readable diff (including length
        # mismatches) instead of raising from an elementwise ``==``.
        part = read_elem_partial(storage["obs"], indices=(obs_idx,))
        np.testing.assert_array_equal(part.keys(), adata_sbs.obs.keys())
        np.testing.assert_array_equal(part.index, adata_sbs.obs.index)

        part = read_elem_partial(storage["var"], indices=(var_idx,))
        np.testing.assert_array_equal(part.keys(), adata_sbs.var.keys())
        np.testing.assert_array_equal(part.index, adata_sbs.var.index)

        for key in storage["obsm"].keys():
            part = read_elem_partial(storage["obsm"][key], indices=(obs_idx,))
            assert np.all(part == adata_sbs.obsm[key])

        for key in storage["varm"].keys():
            part = read_elem_partial(storage["varm"][key], indices=(var_idx,))
            np.testing.assert_equal(part, adata_sbs.varm[key])

        for key in storage["obsp"].keys():
            part = read_elem_partial(
                storage["obsp"][key], indices=(obs_idx, obs_idx)
            )
            part = part.toarray()
            assert np.all(part == adata_sbs.obsp[key])

        # check uns just in case
        np.testing.assert_equal(read_elem(storage["uns"]).keys(), adata.uns.keys())
    finally:
        if diskfmt != "zarr":
            storage.close()