1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
|
"""
In this test, aggregate_numpy is taken as the reference implementation and its
results are compared against the results of the other implementations. Implementations
may raise NotImplementedError in order to signal missing functionality without causing
test errors.
"""
from itertools import product
import numpy as np
import pytest
from . import (
_impl_name,
_wrap_notimplemented_skip,
aggregate_numba,
aggregate_numpy,
aggregate_numpy_ufunc,
aggregate_pandas,
aggregate_purepy,
func_list,
)
class AttrDict(dict):
    """Dictionary whose items can also be read as attributes.

    ``d.key`` returns ``d["key"]``. A missing key raises ``AttributeError``
    (not ``KeyError``), so ``hasattr(d, "key")`` and
    ``getattr(d, "key", default)`` behave as the attribute protocol expects.
    """

    def __getattr__(self, name):
        # __getattr__ is only consulted after normal attribute lookup has
        # failed, so regular dict methods are never shadowed by keys.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None
# Pairs of implementations to compare; each entry is one parameter of the
# aggregate_cmp fixture, which treats "np/py" as a special case.
TEST_PAIRS = [
    "np/py",
    "ufunc/np",
    "numba/np",
    "pandas/np",
]
@pytest.fixture(params=TEST_PAIRS, scope="module")
def aggregate_cmp(request, seed=100):
    """Provide the implementation pair and shared input data for one test pair.

    The return value is an AttrDict snapshot of this function's local
    variables, so every local name below (``func``, ``func_ref``,
    ``group_idx``, ``a``, ``nana``, ``somea``, ``test_pair``, ...) is part of
    the fixture's contents. Renaming or adding locals changes what the
    fixture exposes.

    The fixture is skipped when the implementation under test is falsy
    (presumably: could not be imported).
    """
    test_pair = request.param
    if test_pair == "np/py":
        # purepy does not implement every function; the reference call is
        # therefore routed through the NotImplemented wrapper.
        func_ref = _wrap_notimplemented_skip(aggregate_purepy.aggregate)
        func = aggregate_numpy.aggregate
        group_cnt = 100
    else:
        group_cnt = 1000
        func_ref = aggregate_numpy.aggregate
        # Pick the implementation under test from the pair's name.
        impl = None
        if "ufunc" in test_pair:
            impl = aggregate_numpy_ufunc
        elif "numba" in test_pair:
            impl = aggregate_numba
        elif "pandas" in test_pair:
            impl = aggregate_pandas

        if not impl:
            pytest.skip("Implementation not available")
        name = _impl_name(impl)
        func = _wrap_notimplemented_skip(impl.aggregate, "aggregate_" + name)

    rnd = np.random.RandomState(seed=seed)

    # Every group id occurs twice; after shuffling, each occurrence is
    # expanded into a run of 10 identical consecutive indices.
    group_idx = np.arange(group_cnt).repeat(2)
    rnd.shuffle(group_idx)
    group_idx = group_idx.repeat(10)

    a = rnd.randn(group_idx.size)

    # Variant of `a` with the whole first half and every third entry NaN.
    nana = a.copy()
    nana[: nana.size // 2] = np.nan
    nana[::3] = np.nan

    # Variant of `a` with values below 0.3 zeroed and every 31st entry NaN.
    somea = a.copy()
    somea[a < 0.3] = 0
    somea[::31] = np.nan

    return AttrDict(locals())
def _deselect_purepy(aggregate_cmp, *args, **kwargs):
# purepy implementation does not handle ndim arrays
# This is a won't fix and should be deselected instead of skipped
return aggregate_cmp.endswith("py")
def _deselect_purepy_nanfuncs(aggregate_cmp, func, *args, **kwargs):
# purepy implementation does not handle nan values correctly
# This is a won't fix and should be deselected instead of skipped
return "nan" in getattr(func, "__name__", func) and aggregate_cmp.endswith("py")
def func_arbitrary(iterator):
    """Return the sum of the squares of all items in ``iterator``.

    An empty iterable yields 0.
    """
    total = 0
    for value in iterator:
        squared = value * value
        total += squared
    return total
def func_preserve_order(iterator):
    """Return the sum of each item raised to its 1-based position.

    The result depends on the order of the items, so it changes if the
    items arrive in a different order. An empty iterable yields 0.
    """
    total = 0
    exponent = 0
    for value in iterator:
        exponent += 1
        total += value**exponent
    return total
@pytest.mark.filterwarnings("ignore:numpy.ufunc size changed")
@pytest.mark.deselect_if(func=_deselect_purepy_nanfuncs)
@pytest.mark.parametrize("fill_value", [0, 1, np.nan])
@pytest.mark.parametrize("func", func_list, ids=lambda x: getattr(x, "__name__", x))
def test_cmp(aggregate_cmp, func, fill_value, decimal=10):
    """Compare the tested implementation against the reference for one func.

    Functions whose name contains "nan" are run on the nan-containing input
    (``aggregate_cmp.nana``), all others on ``aggregate_cmp.a``.

    * If the reference raises ValueError, the tested implementation must
      raise ValueError as well.
    * Otherwise the dtypes must match (when the reference returns an
      ndarray) and the values must agree within ``rtol=10**-decimal``.
    """
    # Resolve the name once so that both string and callable entries of
    # func_list are handled the same way everywhere below.
    func_name = getattr(func, "__name__", func)
    a = aggregate_cmp.nana if "nan" in func_name else aggregate_cmp.a

    try:
        ref = aggregate_cmp.func_ref(aggregate_cmp.group_idx, a, func=func, fill_value=fill_value)
    except ValueError:
        # The reference rejects this combination; the tested one must too.
        with pytest.raises(ValueError):
            aggregate_cmp.func(aggregate_cmp.group_idx, a, func=func, fill_value=fill_value)
        return

    try:
        res = aggregate_cmp.func(aggregate_cmp.group_idx, a, func=func, fill_value=fill_value)
    except ValueError:
        if np.isnan(fill_value) and aggregate_cmp.test_pair.endswith("py"):
            pytest.skip(
                "pure python version uses lists and does not raise ValueErrors when inserting nan into integers"
            )
        raise

    if isinstance(ref, np.ndarray):
        assert res.dtype == ref.dtype

    try:
        np.testing.assert_allclose(res, ref, rtol=10**-decimal)
    except AssertionError:
        if "arg" in func_name and aggregate_cmp.test_pair.startswith("pandas"):
            pytest.skip("pandas doesn't fill indices for all-nan groups with fill_value, but with -inf instead")
        raise
@pytest.mark.deselect_if(func=_deselect_purepy)
@pytest.mark.parametrize(["ndim", "order"], product([2, 3], ["C", "F"]))
def test_cmp_ndim(aggregate_cmp, ndim, order, outsize=100, decimal=14, seed=100):
    """Compare both implementations on multi-dimensional group indices.

    ``outsize**ndim`` random index tuples are drawn into an output of shape
    ``(outsize,) * ndim``. The tested implementation must return an array of
    that shape, with the requested memory order, whose values match the
    reference to ``decimal`` places.

    ``seed`` makes the random input reproducible, so that a failure can be
    re-run with identical data.
    """
    # Local, seeded generator instead of the global np.random state.
    rnd = np.random.RandomState(seed=seed)
    nindices = int(outsize**ndim)
    outshape = (outsize,) * ndim
    group_idx = rnd.randint(0, outsize, size=(ndim, nindices))
    a = rnd.random_sample(nindices)

    res = aggregate_cmp.func(group_idx, a, size=outshape, order=order)
    ref = aggregate_cmp.func_ref(group_idx, a, size=outshape, order=order)

    if ndim > 1 and order == "F":
        # 1d arrays always return False here
        assert np.isfortran(res)
    else:
        assert not np.isfortran(res)
    assert res.shape == outshape
    np.testing.assert_array_almost_equal(res, ref, decimal=decimal)
|