File: test_compressed_2d.py

package info (click to toggle)
python-sparse 0.16.0a9-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,948 kB
  • sloc: python: 9,959; makefile: 8; sh: 3
file content (131 lines) | stat: -rw-r--r-- 3,452 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import sparse
from sparse import COO
from sparse.numba_backend._compressed.compressed import CSC, CSR, GCXS
from sparse.numba_backend._utils import assert_eq

import pytest

import numpy as np
import scipy.sparse
import scipy.stats


@pytest.fixture(scope="module", params=[CSR, CSC])
def cls(request):
    """Provide each 2-D compressed format class under test (CSR, then CSC)."""
    compressed_cls = request.param
    return compressed_cls


@pytest.fixture(scope="module", params=["f8", "f4", "i8", "i4"])
def dtype(request):
    """Provide each NumPy dtype code under test: two float and two integer widths."""
    dtype_code = request.param
    return dtype_code


def _random_compressed(cls, dtype, rng, shape, bound):
    """Build a random ``cls`` array of ``shape`` at 25% density, cast to ``dtype``.

    For integer dtypes the nonzero values are drawn uniformly from
    ``[-bound, bound)`` with ``rng``. For float dtypes value generation is left
    to ``sparse.random``'s default.

    ``rng`` is also passed as ``random_state`` so the sparsity pattern (and the
    default float values) come from the same generator as the integer values,
    rather than from unseeded global randomness. This assumes the ``rng``
    fixture (defined outside this file) is seeded; confirm in conftest.
    """
    if np.issubdtype(dtype, np.integer):

        def data_rvs(n):
            return rng.integers(-bound, bound, n)

    else:
        data_rvs = None
    coo = sparse.random(shape, density=0.25, random_state=rng, data_rvs=data_rvs)
    return cls(coo.astype(dtype))


@pytest.fixture(scope="module")
def random_sparse(cls, dtype, rng):
    """Random 20x30 ``cls`` array of ``dtype``; integer values lie in [-1000, 1000)."""
    return _random_compressed(cls, dtype, rng, (20, 30), 1000)


@pytest.fixture(scope="module")
def random_sparse_small(cls, dtype, rng):
    """Random square 20x20 ``cls`` array of ``dtype``; integer values lie in [-10, 10).

    The square shape allows ``x @ x``. The smaller integer bound presumably
    keeps matmul results small; verify if changing it.
    """
    return _random_compressed(cls, dtype, rng, (20, 20), 10)


def test_repr(random_sparse):
    """The repr of a compressed array must mention its concrete class name."""
    expected_name = type(random_sparse).__name__
    rendered = repr(random_sparse)
    assert expected_name in rendered


def test_bad_constructor_input(cls):
    """Constructing from a plain string must raise a ValueError mentioning 'shape'."""
    bogus_arg = "hello world"
    with pytest.raises(ValueError, match=r".*shape.*"):
        cls(arg=bogus_arg)


@pytest.mark.parametrize("n", [0, 1, 3])
def test_bad_nd_input(cls, n):
    """Dense input that is not 2-D (0-D, 1-D, 3-D) is rejected with its rank in the message."""
    dense = np.ones(shape=(5,) * n)
    expected_msg = f"{n}-d"
    with pytest.raises(ValueError, match=expected_msg):
        cls(dense)


@pytest.mark.parametrize("source_type", ["gcxs", "coo"])
def test_from_sparse(cls, source_type):
    """Wrapping a GCXS or COO array in ``cls`` preserves its values."""
    # Named generically: depending on the parameter this is GCXS *or* COO.
    source = sparse.random((20, 30), density=0.25, format=source_type)
    converted = cls(source)
    assert_eq(converted, source)


@pytest.mark.parametrize("scipy_type", ["coo", "csr", "csc", "lil"])
@pytest.mark.parametrize("CLS", [CSR, CSC, GCXS])
def test_from_scipy_sparse(scipy_type, CLS, dtype):
    """``CLS.from_scipy_sparse`` and ``CLS(...)`` both accept SciPy COO/CSR/CSC/LIL input.

    The conversion performed by ``COO.from_scipy_sparse`` is the reference result.
    """
    scipy_arr = scipy.sparse.random(20, 30, density=0.2, format=scipy_type, dtype=dtype)
    expected = COO.from_scipy_sparse(scipy_arr)

    # Explicit classmethod conversion first, then the constructor path.
    assert_eq(expected, CLS.from_scipy_sparse(scipy_arr))
    assert_eq(expected, CLS(scipy_arr))


@pytest.mark.parametrize("cls_str", ["coo", "dok", "csr", "csc", "gcxs"])
def test_to_sparse(cls_str, random_sparse):
    """``asformat`` into each listed format keeps the array's values intact."""
    assert_eq(random_sparse, random_sparse.asformat(cls_str))


@pytest.mark.parametrize("copy", [True, False])
def test_transpose(random_sparse, copy):
    """Check transpose shape, round-trip, buffer sharing and explicit ``axes``.

    The transposed array has the reversed shape, and transposing twice gives
    back an equal array of the same class. ``copy`` decides buffer sharing:
    with ``copy=True`` the transposed array must not share the original's
    ``data``/``indices``/``indptr`` objects; with ``copy=False`` it must
    reuse them.
    """
    from operator import is_, is_not

    t = random_sparse.transpose(copy=copy)
    tt = t.transpose(copy=copy)

    # Object identity, not equality: a copy must own new buffers, a view must not.
    check = is_not if copy else is_
    for attr in ("data", "indices", "indptr"):
        assert check(getattr(random_sparse, attr), getattr(t, attr)), attr

    assert random_sparse.shape == t.shape[::-1]

    # Double transpose restores the values and the concrete class; compare
    # classes by identity (E721) rather than ``==``.
    assert_eq(random_sparse, tt)
    assert type(random_sparse) is type(tt)

    # Explicit axes: the identity permutation is a no-op, the swap equals ``t``.
    assert_eq(random_sparse.transpose(axes=(0, 1)), random_sparse)
    assert_eq(random_sparse.transpose(axes=(1, 0)), t)
    with pytest.raises(ValueError, match="Invalid transpose axes"):
        random_sparse.transpose(axes=0)


def test_transpose_error(random_sparse):
    """Passing a bare integer as ``axes`` must raise ValueError."""
    bad_axes = 1
    with pytest.raises(ValueError):
        random_sparse.transpose(axes=bad_axes)


def test_matmul(random_sparse_small):
    """Sparse ``x @ x`` must match the dense NumPy product of the same operands."""
    dense = random_sparse_small.todense()
    assert_eq(random_sparse_small @ random_sparse_small, dense @ dense)