File: test_indexing.py

package info (click to toggle)
python-thinc 9.1.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 5,896 kB
  • sloc: python: 17,122; javascript: 1,559; ansic: 342; makefile: 15; sh: 13
file content (66 lines) | stat: -rw-r--r-- 1,896 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import numpy
import pytest
from numpy.testing import assert_allclose

from thinc.types import Pairs, Ragged


@pytest.fixture
def ragged():
    data = numpy.zeros((20, 4), dtype="f")
    lengths = numpy.array([4, 2, 8, 1, 4], dtype="i")
    data[0] = 0
    data[1] = 1
    data[2] = 2
    data[3] = 3
    data[4] = 4
    data[5] = 5
    return Ragged(data, lengths)


def test_ragged_empty():
    data = numpy.zeros((0, 4), dtype="f")
    lengths = numpy.array([], dtype="i")
    ragged = Ragged(data, lengths)
    assert_allclose(ragged[0:0].data, ragged.data)
    assert_allclose(ragged[0:0].lengths, ragged.lengths)
    assert_allclose(ragged[0:2].data, ragged.data)
    assert_allclose(ragged[0:2].lengths, ragged.lengths)
    assert_allclose(ragged[1:2].data, ragged.data)
    assert_allclose(ragged[1:2].lengths, ragged.lengths)


def test_ragged_starts_ends(ragged):
    starts = ragged._get_starts()
    ends = ragged._get_ends()
    assert list(starts) == [0, 4, 6, 14, 15]
    assert list(ends) == [4, 6, 14, 15, 19]


def test_ragged_simple_index(ragged, i=1):
    r = ragged[i]
    assert_allclose(r.data, ragged.data[4:6])
    assert_allclose(r.lengths, ragged.lengths[i : i + 1])


def test_ragged_slice_index(ragged, start=0, end=2):
    r = ragged[start:end]
    size = ragged.lengths[start:end].sum()
    assert r.data.shape == (size, r.data.shape[1])
    assert_allclose(r.lengths, ragged.lengths[start:end])


def test_ragged_array_index(ragged):
    arr = numpy.array([2, 1, 4], dtype="i")
    r = ragged[arr]
    assert r.data.shape[0] == ragged.lengths[arr].sum()


def test_pairs_arrays():
    one = numpy.zeros((128, 45), dtype="f")
    two = numpy.zeros((128, 12), dtype="f")
    pairs = Pairs(one, two)
    assert pairs[:2].one.shape == (2, 45)
    assert pairs[0].two.shape == (12,)
    assert pairs[-1:].one.shape == (1, 45)
    assert pairs[-1:].two.shape == (1, 12)