File: util.py

package info (click to toggle)
python-thinc 9.1.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 5,896 kB
  • sloc: python: 17,122; javascript: 1,559; ansic: 342; makefile: 15; sh: 13
file content (121 lines) | stat: -rw-r--r-- 3,600 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import contextlib
import shutil
import tempfile
from pathlib import Path

import numpy
import pytest

from thinc.api import ArgsKwargs, Linear, Padded, Ragged
from thinc.util import has_cupy, is_cupy_array, is_numpy_array


@contextlib.contextmanager
def make_tempdir():
    """Yield a freshly created temporary directory and remove it afterwards.

    The directory is created with ``tempfile.mkdtemp`` and handed to the
    ``with`` body as a ``pathlib.Path``. It is deleted recursively when the
    body finishes, whether it completes normally or raises.
    """
    d = Path(tempfile.mkdtemp())
    try:
        yield d
    finally:
        # Without the ``finally`` an exception in the ``with`` body (e.g. a
        # failing assertion in a test) would skip cleanup and leak the
        # directory on disk.
        shutil.rmtree(str(d))


def get_model(W_b_input, cls=Linear):
    """Build and initialize a ``cls`` model whose parameters are preset.

    ``W_b_input`` is a 3-item sequence ``(W, b, input_)``; only ``W`` and
    ``b`` are used here. The model is constructed as ``cls(nr_out, nr_in)``
    with both sizes read from ``W.shape``, its "W" and "b" parameters are
    set, and ``initialize()`` is called before the model is returned.
    """
    weights, bias, _ = W_b_input
    n_outputs, n_inputs = weights.shape
    model = cls(n_outputs, n_inputs)
    for param_name, value in (("W", weights), ("b", bias)):
        model.set_param(param_name, value)
    model.initialize()
    return model


def get_shape(W_b_input):
    """Return ``(input_.shape[0], W.shape[0], W.shape[1])``.

    ``W_b_input`` is a 3-item sequence ``(W, b, input_)``; ``b`` is ignored.
    """
    weights, _, inputs = W_b_input
    first_input_dim = inputs.shape[0]
    weight_rows = weights.shape[0]
    weight_cols = weights.shape[1]
    return (first_input_dim, weight_rows, weight_cols)


def get_data_checker(inputs):
    """Pick the ``assert_*_match`` function that fits the type of ``inputs``.

    Dispatch order: ``Ragged``, ``Padded``, ``list``, then tuples of length
    4 (padded data) or 2 (ragged data). Anything else, including tuples of
    other lengths, falls back to the plain array checker.
    """
    if isinstance(inputs, Ragged):
        return assert_raggeds_match
    if isinstance(inputs, Padded):
        return assert_paddeds_match
    if isinstance(inputs, list):
        return assert_lists_match
    if isinstance(inputs, tuple):
        # Tuple length tells the raw Padded fields apart from the raw
        # Ragged fields.
        checker_by_length = {
            4: assert_padded_data_match,
            2: assert_ragged_data_match,
        }
        return checker_by_length.get(len(inputs), assert_arrays_match)
    return assert_arrays_match


def assert_arrays_match(X, Y):
    """Assert that two arrays share a dtype and a first dimension.

    Returns ``True`` when both assertions pass so the call can itself be
    wrapped in an ``assert``.
    """
    assert X.dtype == Y.dtype
    # Only the leading (batch) dimension is compared: a transformation may
    # legitimately change the last dimension.
    rows_x, rows_y = X.shape[0], Y.shape[0]
    assert rows_x == rows_y
    return True


def assert_lists_match(X, Y):
    """Assert that two lists have equal length and pairwise-matching arrays.

    Both arguments must be ``list`` instances. Each aligned pair of items is
    checked with ``assert_arrays_match``. Returns ``True`` on success.
    """
    for candidate in (X, Y):
        assert isinstance(candidate, list)
    assert len(X) == len(Y)
    for pair in zip(X, Y):
        assert_arrays_match(*pair)
    return True


def assert_raggeds_match(X, Y):
    """Assert that two ``Ragged`` objects have matching lengths and data.

    Both arguments must be ``Ragged`` instances. Their ``lengths`` and
    ``data`` attributes are each compared with ``assert_arrays_match``.
    Returns ``True`` on success.
    """
    assert isinstance(X, Ragged)
    assert isinstance(Y, Ragged)
    for field in ("lengths", "data"):
        assert_arrays_match(getattr(X, field), getattr(Y, field))
    return True


def assert_paddeds_match(X, Y):
    """Assert that two ``Padded`` objects are structurally compatible.

    Both arguments must be ``Padded`` instances. ``size_at_t``, ``lengths``
    and ``indices`` are compared with ``assert_arrays_match``; ``data`` must
    agree in dtype and in its first two dimensions. Returns ``True`` on
    success.
    """
    assert isinstance(X, Padded)
    assert isinstance(Y, Padded)
    assert_arrays_match(X.size_at_t, Y.size_at_t)
    for field in ("lengths", "indices"):
        assert assert_arrays_match(getattr(X, field), getattr(Y, field))
    x_data, y_data = X.data, Y.data
    assert x_data.dtype == y_data.dtype
    # Only the first two dimensions of the data are compared.
    assert x_data.shape[1] == y_data.shape[1]
    assert x_data.shape[0] == y_data.shape[0]
    return True


def assert_padded_data_match(X, Y):
    """Wrap two argument tuples in ``Padded`` and compare them.

    Each tuple is unpacked positionally into ``Padded(...)``; the result of
    ``assert_paddeds_match`` on the two objects is returned.
    """
    padded_x = Padded(*X)
    padded_y = Padded(*Y)
    return assert_paddeds_match(padded_x, padded_y)


def assert_ragged_data_match(X, Y):
    """Wrap two argument tuples in ``Ragged`` and compare them.

    Each tuple is unpacked positionally into ``Ragged(...)``; the result of
    ``assert_raggeds_match`` on the two objects is returned.
    """
    ragged_x = Ragged(*X)
    ragged_y = Ragged(*Y)
    return assert_raggeds_match(ragged_x, ragged_y)


def check_input_converters(Y, backprop, data, n_args, kwargs_keys, type_):
    """Validate a converted output ``Y`` and the value ``backprop(Y)`` returns.

    Forward checks: ``Y`` is an ``ArgsKwargs`` with ``n_args`` positional
    values and keyword names equal to ``kwargs_keys`` (order included), and
    every positional and keyword value is an instance of ``type_``.

    Backward checks: ``dX = backprop(Y)`` must either be a numpy/cupy array
    or have the type of ``data`` (a ``tuple`` when ``data`` is a list). Its
    contents are then checked according to the type of ``data``:

    * dict: keys equal ``kwargs_keys`` and all values are backend arrays;
    * list or tuple: ``dX`` is a tuple of backend arrays;
    * ``ArgsKwargs``: ``n_args`` positional values, keyword names equal to
      ``kwargs_keys``, all values backend arrays;
    * ``numpy.ndarray``: no further checks;
    * anything else: the test is failed through ``pytest.fail``.
    """

    def _is_backend_array(candidate):
        return is_cupy_array(candidate) or is_numpy_array(candidate)

    assert isinstance(Y, ArgsKwargs)
    assert len(Y.args) == n_args
    assert list(Y.kwargs.keys()) == kwargs_keys
    assert all(isinstance(value, type_) for value in Y.args)
    assert all(isinstance(value, type_) for value in Y.kwargs.values())

    dX = backprop(Y)

    # A list input is expected to come back as a tuple.
    expected_type = tuple if isinstance(data, list) else type(data)
    assert isinstance(dX, expected_type) or _is_backend_array(dX)

    if isinstance(data, dict):
        assert list(dX.keys()) == kwargs_keys
        assert all(_is_backend_array(grad) for grad in dX.values())
        return

    if isinstance(data, (list, tuple)):
        assert isinstance(dX, tuple)
        assert all(_is_backend_array(grad) for grad in dX)
        return

    if isinstance(data, ArgsKwargs):
        assert len(dX.args) == n_args
        assert list(dX.kwargs.keys()) == kwargs_keys
        assert all(_is_backend_array(grad) for grad in dX.args)
        assert all(_is_backend_array(grad) for grad in dX.kwargs.values())
        return

    if not isinstance(data, numpy.ndarray):
        pytest.fail(f"Bad data type: {dX}")