File: test_transforms.py

package info (click to toggle)
python-thinc 9.1.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 5,896 kB
  • sloc: python: 17,122; javascript: 1,559; ansic: 342; makefile: 15; sh: 13
file content (79 lines) | stat: -rw-r--r-- 1,976 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import numpy
import pytest

from thinc.api import NumpyOps, Ragged, registry, strings2arrays

from ..util import get_data_checker


@pytest.fixture(
    params=[
        [],  # no arrays at all
        [(10, 2)],  # a single array
        [(5, 3), (1, 3)],  # two arrays with different row counts
        [(2, 3), (0, 3), (1, 3)],  # includes an array with zero rows
    ]
)
def shapes(request):
    """Return one parametrized list of 2-D array shapes.

    Each tuple is ``(rows, columns)``; within a case every shape shares the
    same column count so the arrays can be flattened together.
    """
    case_shapes = request.param
    return case_shapes


@pytest.fixture
def ops():
    """Provide a fresh ``NumpyOps`` instance for building the fixtures."""
    numpy_ops = NumpyOps()
    return numpy_ops


@pytest.fixture
def list_data(shapes):
    """Build one zero-filled float32 (``dtype="f"``) array per entry of ``shapes``."""
    return [numpy.zeros(dims, dtype="f") for dims in shapes]


@pytest.fixture
def ragged_data(ops, list_data):
    """Pack ``list_data`` into a ``Ragged`` of flattened rows plus per-array lengths.

    The lengths are the row count (``len``) of each array, stored with
    ``dtype="i"``.
    """
    sizes = numpy.array([len(arr) for arr in list_data], dtype="i")
    # With no arrays there is nothing to flatten, so an explicit empty (0, 0)
    # float array is used instead of calling ops.flatten on an empty list.
    flat = ops.flatten(list_data) if list_data else ops.alloc2f(0, 0)
    return Ragged(flat, sizes)


@pytest.fixture
def padded_data(ops, list_data):
    """Return the padded representation of ``list_data`` from ``ops.list2padded``."""
    padded = ops.list2padded(list_data)
    return padded


@pytest.fixture
def array_data(ragged_data):
    """Expose the ``data`` array backing ``ragged_data`` (the flattened rows)."""
    flat_array = ragged_data.data
    return flat_array


def check_transform(transform, in_data, out_data):
    """Build a registered layer, run it forward and backward, and validate both passes.

    Args:
        transform: Registry name of the layer, e.g. ``"list2array.v1"``.
        in_data: Example input fed to the layer.
        out_data: Expected output; used to initialize the model and as the
            reference the forward result is checked against.
    """
    config = {"config": {"@layers": transform}}
    model = registry.resolve(config)["config"]
    # NOTE(review): get_data_checker comes from ..util; presumably it returns a
    # callable that compares a value against the given reference data.
    check_input = get_data_checker(in_data)
    check_output = get_data_checker(out_data)
    model.initialize(in_data, out_data)
    # Training-mode call so the model also hands back a backprop callback.
    outputs, backprop = model(in_data, is_train=True)
    check_output(outputs, out_data)
    # The forward output is reused as the incoming gradient; the resulting
    # input gradient is then checked against the original input.
    d_inputs = backprop(outputs)
    check_input(d_inputs, in_data)


def test_list2array(list_data, array_data):
    """list2array.v1 maps a list of arrays onto their flattened array."""
    check_transform("list2array.v1", in_data=list_data, out_data=array_data)


def test_list2ragged(list_data, ragged_data):
    """list2ragged.v1 maps a list of arrays onto the equivalent Ragged."""
    check_transform("list2ragged.v1", in_data=list_data, out_data=ragged_data)


def test_list2padded(list_data, padded_data):
    """list2padded.v1 matches the result of ops.list2padded on the same list."""
    check_transform("list2padded.v1", in_data=list_data, out_data=padded_data)


def test_ragged2list(ragged_data, list_data):
    """ragged2list.v1 recovers the original list of arrays from a Ragged."""
    check_transform("ragged2list.v1", in_data=ragged_data, out_data=list_data)


def test_padded2list(padded_data, list_data):
    """padded2list.v1 recovers the original list of arrays from padded data."""
    check_transform("padded2list.v1", in_data=padded_data, out_data=list_data)


def test_strings2arrays():
    """strings2arrays yields one array per input string and an empty backprop.

    Besides the number of outputs, each output's length is checked against
    the length of the corresponding string.
    """
    strings = ["hello", "world"]
    model = strings2arrays()
    Y, backprop = model.begin_update(strings)
    assert len(Y) == len(strings)
    # NOTE(review): assumes the layer emits one entry per element of each input
    # sequence (here, per character) — confirm against thinc's strings2arrays.
    for array, string in zip(Y, strings):
        assert len(array) == len(string)
    # The backward pass is expected to return an empty list.
    assert backprop([]) == []