1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
|
import importlib
import itertools
import operator
import os
import sparse
import pytest
import numpy as np
import scipy.sparse as sps
# Fraction of stored (nonzero) entries in each random benchmark matrix; passed
# as ``density=`` to ``scipy.sparse.random`` when building the operands.
DENSITY = 1e-3
def get_test_id(side):
    """Return the pytest ID shown for one ``elemwise_args`` parameter.

    The ID is ``"side="`` followed by ``repr(side)``, e.g. ``side=100``.
    """
    return "side=" + repr(side)
@pytest.fixture(params=[100, 500, 1000], ids=get_test_id)
def elemwise_args(request, seed, max_size):
    """Build the pair of SciPy CSR operands for the element-wise benchmarks.

    Each operand is a ``side x side`` matrix from ``scipy.sparse.random`` with
    density ``DENSITY``, scaled by 10, in CSR format. Both are drawn from one
    ``numpy.random.Generator`` seeded with ``seed``, so the pair is
    reproducible for a given seed.

    Args:
        request: pytest request; ``request.param`` is the matrix side length.
        seed: seed for the RNG (a fixture defined elsewhere, presumably in
            conftest — confirm).
        max_size: upper bound on ``side**2`` (also a fixture defined
            elsewhere); parametrizations at or above it are skipped.

    Returns:
        ``(s1_sps, s2_sps)``: two SciPy CSR matrices.
    """
    side = request.param
    if side**2 >= max_size:
        # Give the skip a reason so the report explains why this size was dropped.
        pytest.skip(f"side**2 ({side**2}) >= max_size ({max_size})")
    rng = np.random.default_rng(seed=seed)

    def make_operand():
        # One random CSR operand with duplicate entries summed in place.
        mat = sps.random(side, side, format="csr", density=DENSITY, random_state=rng) * 10
        mat.sum_duplicates()
        return mat

    # Draw s1 before s2 so the RNG stream is consumed in the same order as before.
    s1_sps = make_operand()
    s2_sps = make_operand()
    return s1_sps, s2_sps
def get_elemwise_id(param):
    """Return the pytest ID for one ``(operator, backend name)`` parameter.

    Both parts are rendered with ``repr``, e.g.
    ``f=<built-in function add>-backend='SciPy'``.
    """
    op, backend_name = param
    return "-".join((f"f={op!r}", f"backend={backend_name!r}"))
@pytest.fixture(
    params=itertools.product([operator.add, operator.mul, operator.gt], ["SciPy", "Numba", "Finch"]),
    scope="function",
    ids=get_elemwise_id,
)
def backend(request):
    """Select a ``sparse`` backend for one test and restore the environment after.

    Sets the environment variable named by ``sparse._ENV_VAR_NAME`` to the
    backend name and reloads ``sparse`` (presumably the package reads that
    variable at import time to pick its backend — confirm).

    Yields:
        ``(f, sparse, backend_name)``: the binary operator under test, the
        reloaded ``sparse`` module, and the backend name string.

    On teardown, the variable is restored to the value it had before the test,
    or removed if it was unset. ``sparse`` is then reloaded again. Previously
    the variable was deleted unconditionally, which discarded any value set
    outside the test run.
    """
    f, backend_name = request.param
    env_var = sparse._ENV_VAR_NAME
    previous = os.environ.get(env_var)
    os.environ[env_var] = backend_name
    importlib.reload(sparse)
    try:
        yield f, sparse, backend_name
    finally:
        if previous is None:
            # pop() with a default avoids KeyError if the test removed it itself.
            os.environ.pop(env_var, None)
        else:
            os.environ[env_var] = previous
        importlib.reload(sparse)
def test_elemwise(benchmark, backend, elemwise_args):
    """Benchmark one binary element-wise operator on one backend.

    Args:
        benchmark: the ``benchmark`` fixture (presumably pytest-benchmark or a
            compatible plugin); used as a decorator around the timed call.
        backend: ``(f, sparse_module, backend_name)`` from the ``backend``
            fixture.
        elemwise_args: ``(s1_sps, s2_sps)`` SciPy CSR matrices from the
            ``elemwise_args`` fixture.

    Raises:
        ValueError: if ``backend_name`` is not one this test knows how to
            prepare operands for. Without this branch, ``s1`` would be unbound
            and the test would fail with an opaque ``UnboundLocalError``.
    """
    s1_sps, s2_sps = elemwise_args
    f, sparse_module, backend_name = backend

    if backend_name == "SciPy":
        s1 = s1_sps
        s2 = s2_sps
    elif backend_name == "Numba":
        s1 = sparse_module.asarray(s1_sps)
        s2 = sparse_module.asarray(s2_sps)
    elif backend_name == "Finch":
        # Finch operands are built from CSC copies with an explicit format="csc".
        s1 = sparse_module.asarray(s1_sps.asformat("csc"), format="csc")
        s2 = sparse_module.asarray(s2_sps.asformat("csc"), format="csc")
    else:
        raise ValueError(f"Unsupported backend: {backend_name!r}")

    # Untimed call before benchmarking; presumably warms up JIT compilation or
    # caches so that one-time cost is excluded from the measurement — confirm.
    f(s1, s2)

    @benchmark
    def bench():
        f(s1, s2)
|