File: test_at.py

package info (click to toggle)
python-npx 0.1.6-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 168 kB
  • sloc: python: 330; makefile: 3
file content (35 lines) | stat: -rw-r--r-- 728 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np

import npx


def test_sum_at():
    """Check that ``npx.sum_at`` sums the values of ``a`` into bins given by ``idx``.

    Index 0 occurs twice, so its bin must receive 1.0 + 3.0.  With
    ``minlength=4`` the result is expected to have length 4, with zeros in
    the bins beyond the largest index used (1).
    """
    a = [1.0, 2.0, 3.0]
    idx = [0, 1, 0]
    out = npx.sum_at(a, idx, minlength=4)

    tol = 1.0e-13
    ref = np.array([4.0, 2.0, 0.0, 0.0])
    # Same mixed absolute/relative criterion as |out - ref| <= tol * (1 + |ref|),
    # but assert_allclose also rejects a shape mismatch (instead of silently
    # broadcasting) and reports the offending entries on failure.
    np.testing.assert_allclose(out, ref, rtol=tol, atol=tol)


def test_add_at():
    """Check that ``npx.add_at`` adds ``a`` into ``out`` at positions ``idx``, in place.

    Index 0 occurs twice, so ``out[0]`` must accumulate both contributions
    (1.0 + 3.0) rather than keep only one of them.
    """
    a = [1.0, 2.0, 3.0]
    idx = [0, 1, 0]
    out = np.zeros(2)
    npx.add_at(out, idx, a)

    tol = 1.0e-13
    ref = np.array([4.0, 2.0])
    # Same mixed absolute/relative criterion as |out - ref| <= tol * (1 + |ref|),
    # plus a shape check and a readable failure report.
    np.testing.assert_allclose(out, ref, rtol=tol, atol=tol)


def test_subtract_at():
    """Check that ``npx.subtract_at`` subtracts ``a`` from ``out`` at positions ``idx``, in place.

    Starting from ones, ``out[0]`` must lose both contributions at the
    repeated index 0 (1 - 1.0 - 3.0 = -3.0) and ``out[1]`` becomes
    1 - 2.0 = -1.0.
    """
    a = [1.0, 2.0, 3.0]
    idx = [0, 1, 0]
    out = np.ones(2)
    npx.subtract_at(out, idx, a)

    tol = 1.0e-13
    ref = np.array([-3.0, -1.0])
    # Same mixed absolute/relative criterion as |out - ref| <= tol * (1 + |ref|),
    # plus a shape check and a readable failure report.
    np.testing.assert_allclose(out, ref, rtol=tol, atol=tol)