File: aggregate_numpy_ufunc.py

package info (click to toggle)
python-numpy-groupies 0.10.2-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 476 kB
  • sloc: python: 2,346; makefile: 12
file content (128 lines) | stat: -rw-r--r-- 3,591 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import numpy as np

from .aggregate_numpy import _aggregate_base
from .utils import (
    aggregate_common_doc,
    aliasing,
    check_boolean,
    get_func,
    maxval,
    minimum_dtype,
    minimum_dtype_scalar,
    minval,
)


def _anynan(group_idx, a, size, fill_value, dtype=None):
    """Per-group test for the presence of at least one NaN in ``a``.

    The NaN mask of ``a`` is computed with ``np.isnan`` and then reduced
    by ``_any``; ``size``, ``fill_value`` and ``dtype`` are forwarded as-is.
    """
    nan_mask = np.isnan(a)
    return _any(group_idx, nan_mask, size, fill_value=fill_value, dtype=dtype)


def _allnan(group_idx, a, size, fill_value, dtype=None):
    """Per-group test for every element of ``a`` being NaN.

    The NaN mask of ``a`` is computed with ``np.isnan`` and then reduced
    by ``_all``; ``size``, ``fill_value`` and ``dtype`` are forwarded as-is.
    """
    nan_mask = np.isnan(a)
    return _all(group_idx, nan_mask, size, fill_value=fill_value, dtype=dtype)


def _any(group_idx, a, size, fill_value, dtype=None):
    """Logical OR of the values of ``a`` within each group.

    Returns a boolean array of shape ``size``. Positions that are not
    referenced by ``group_idx`` keep ``fill_value``. ``dtype`` is accepted
    but not used; the result is always ``bool``.
    """
    check_boolean(fill_value)
    flags = np.full(size, fill_value, dtype=bool)
    if fill_value:
        # Groups that do occur must accumulate from False, otherwise a
        # truthy fill value would leak into the OR result.
        flags[group_idx] = False
    np.logical_or.at(flags, group_idx, a)
    return flags


def _all(group_idx, a, size, fill_value, dtype=None):
    """Logical AND of the values of ``a`` within each group.

    Returns a boolean array of shape ``size``. Positions that are not
    referenced by ``group_idx`` keep ``fill_value``. ``dtype`` is accepted
    but not used; the result is always ``bool``.
    """
    check_boolean(fill_value)
    flags = np.full(size, fill_value, dtype=bool)
    if not fill_value:
        # Groups that do occur must accumulate from True, otherwise a
        # falsy fill value would force the AND result to False.
        flags[group_idx] = True
    np.logical_and.at(flags, group_idx, a)
    return flags


def _sum(group_idx, a, size, fill_value, dtype=None):
    """Sum the values of ``a`` within each group using ``np.add.at``.

    The output dtype is whatever ``minimum_dtype_scalar`` returns for
    ``fill_value``, ``dtype`` and ``a``. Positions that are not referenced
    by ``group_idx`` keep ``fill_value``.
    """
    out_dtype = minimum_dtype_scalar(fill_value, dtype, a)
    totals = np.full(size, fill_value, dtype=out_dtype)
    if fill_value != 0:
        # Occupied groups accumulate from the additive identity.
        totals[group_idx] = 0
    np.add.at(totals, group_idx, a)
    return totals


def _len(group_idx, a, size, fill_value, dtype=None):
    """Count the entries of each group.

    Implemented as a ``_sum`` of the constant ``1`` per entry. The ``a``
    and ``dtype`` arguments are not used; ``int`` is always requested as
    the dtype of the sum.
    """
    return _sum(group_idx, a=1, size=size, fill_value=fill_value, dtype=int)


def _prod(group_idx, a, size, fill_value, dtype=None):
    """Multiply the values of ``a`` within each group using ``np.multiply.at``.

    The output dtype is whatever ``minimum_dtype_scalar`` returns for
    ``fill_value``, ``dtype`` and ``a``. Positions that are not referenced
    by ``group_idx`` keep ``fill_value``.
    """
    out_dtype = minimum_dtype_scalar(fill_value, dtype, a)
    products = np.full(size, fill_value, dtype=out_dtype)
    if fill_value != 1:
        # Occupied groups accumulate from the multiplicative identity.
        products[group_idx] = 1
    np.multiply.at(products, group_idx, a)
    return products


def _min(group_idx, a, size, fill_value, dtype=None):
    """Minimum of the values of ``a`` within each group via ``np.minimum.at``.

    When ``dtype`` is falsy (e.g. ``None``) the dtype of ``a`` is used as
    the requested dtype; ``minimum_dtype`` then picks the output dtype.
    Positions that are not referenced by ``group_idx`` keep ``fill_value``.
    """
    out_dtype = minimum_dtype(fill_value, dtype or a.dtype)
    upper = maxval(fill_value, out_dtype)
    lowest = np.full(size, fill_value, dtype=out_dtype)
    if fill_value != upper:
        # Occupied groups start from the maximum so any value replaces it.
        lowest[group_idx] = upper
    np.minimum.at(lowest, group_idx, a)
    return lowest


def _max(group_idx, a, size, fill_value, dtype=None):
    """Maximum of the values of ``a`` within each group via ``np.maximum.at``.

    When ``dtype`` is falsy (e.g. ``None``) the dtype of ``a`` is used as
    the requested dtype; ``minimum_dtype`` then picks the output dtype.
    Positions that are not referenced by ``group_idx`` keep ``fill_value``.
    """
    out_dtype = minimum_dtype(fill_value, dtype or a.dtype)
    lower = minval(fill_value, out_dtype)
    highest = np.full(size, fill_value, dtype=out_dtype)
    if fill_value != lower:
        # Occupied groups start from the minimum so any value replaces it.
        highest[group_idx] = lower
    np.maximum.at(highest, group_idx, a)
    return highest


# Maps each supported function name to its ``ufunc.at`` based implementation.
_impl_dict = {
    "min": _min,
    "max": _max,
    "sum": _sum,
    "prod": _prod,
    "all": _all,
    "any": _any,
    "allnan": _allnan,
    "anynan": _anynan,
    "len": _len,
}


def aggregate(group_idx, a, func="sum", size=None, fill_value=0, order="C", dtype=None, axis=None, **kwargs):
    # NOTE: the docstring of this function is assigned after its definition.
    # Resolve ``func`` through the alias table against the implementations
    # available in this module.
    func = get_func(func, aliasing, _impl_dict)
    if isinstance(func, str):
        # A string result is handed on to the shared driver together with
        # this module's implementation table.
        return _aggregate_base(
            group_idx,
            a,
            size=size,
            fill_value=fill_value,
            order=order,
            dtype=dtype,
            func=func,
            axis=axis,
            _impl_dict=_impl_dict,
            **kwargs,
        )
    # Anything that did not resolve to a string is not supported here.
    raise NotImplementedError("No such ufunc available")


# Put the note specific to this implementation in front of the documentation
# text shared via ``aggregate_common_doc``.
aggregate.__doc__ = "".join(
    [
        """
    Unlike ``aggregate_numpy``, which in most cases does some custom
    optimisations, this version simply uses ``numpy``'s ``ufunc.at``.

    As of version 1.14 this gives fairly poor performance. There should
    normally be no need to use this version, it is intended to be used in
    testing and benchmarking only.
    """,
        aggregate_common_doc,
    ]
)