File: util.py

package info (click to toggle)
python-thinc 8.1.7-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 5,804 kB
  • sloc: python: 15,818; javascript: 1,554; ansic: 342; makefile: 20; sh: 13
file content (119 lines) | stat: -rw-r--r-- 3,598 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import contextlib
from pathlib import Path
import tempfile
import shutil
from thinc.api import Linear, Ragged, Padded, ArgsKwargs
import numpy
import pytest
from thinc.util import has_cupy, is_cupy_array, is_numpy_array


@contextlib.contextmanager
def make_tempdir():
    """Create a fresh temporary directory and delete it when the block exits.

    Yields:
        Path: the newly created directory (from ``tempfile.mkdtemp``).

    The removal runs in a ``finally`` clause so the directory is cleaned up
    even when the body of the ``with`` statement raises; otherwise a failing
    test would leak the directory.
    """
    d = Path(tempfile.mkdtemp())
    try:
        yield d
    finally:
        shutil.rmtree(str(d))


def get_model(W_b_input, cls=Linear):
    """Build and initialize a ``cls`` layer from a ``(W, b, input_)`` triple.

    The layer is constructed as ``cls(nr_out, nr_in)`` using the two sizes of
    ``W.shape``. Its "W" and "b" params are then overwritten with the given
    arrays, and it is initialized. The third element of the triple is unused.
    """
    weights, bias, _ = W_b_input
    n_out, n_in = weights.shape
    layer = cls(n_out, n_in)
    for param_name, value in (("W", weights), ("b", bias)):
        layer.set_param(param_name, value)
    layer.initialize()
    return layer


def get_shape(W_b_input):
    """Return ``(input_.shape[0], W.shape[0], W.shape[1])`` for a ``(W, b, input_)`` triple.

    The leading dimension of ``input_`` is presumably the batch size; ``W``'s
    first two dimensions are the output and input widths.
    """
    weights, _bias, inputs = W_b_input
    n_out = weights.shape[0]
    n_in = weights.shape[1]
    return inputs.shape[0], n_out, n_in


def get_data_checker(inputs):
    """Return the assertion helper suited to the container type of ``inputs``.

    ``Ragged``, ``Padded`` and ``list`` instances map to their dedicated
    checkers. A plain 4-tuple is compared as unpacked ``Padded`` fields, and a
    2-tuple as unpacked ``Ragged`` fields. Everything else falls back to a
    plain array comparison.
    """
    # Order matters: the first matching container type wins.
    for container_type, checker in (
        (Ragged, assert_raggeds_match),
        (Padded, assert_paddeds_match),
        (list, assert_lists_match),
    ):
        if isinstance(inputs, container_type):
            return checker
    if isinstance(inputs, tuple):
        by_arity = {4: assert_padded_data_match, 2: assert_ragged_data_match}
        return by_arity.get(len(inputs), assert_arrays_match)
    return assert_arrays_match


def assert_arrays_match(X, Y):
    """Assert that two arrays share a dtype and the same size along axis 0.

    Only the first axis is compared: transformations are allowed to change
    the last dimension, but not the batch size. Returns ``True`` so the call
    can be used inside another ``assert``.
    """
    same_dtype = X.dtype == Y.dtype
    assert same_dtype
    rows_x, rows_y = X.shape[0], Y.shape[0]
    assert rows_x == rows_y
    return True


def assert_lists_match(X, Y):
    """Assert that two lists have equal length and pairwise-compatible arrays.

    Each aligned pair of elements is checked with ``assert_arrays_match``.
    Returns ``True`` on success.
    """
    for seq in (X, Y):
        assert isinstance(seq, list)
    assert len(X) == len(Y)
    for pair in zip(X, Y):
        assert_arrays_match(*pair)
    return True


def assert_raggeds_match(X, Y):
    """Assert that two ``Ragged`` objects have compatible lengths and data.

    Both ``lengths`` and ``data`` are compared with ``assert_arrays_match``
    (same dtype, same size along axis 0). Returns ``True`` on success.
    """
    for ragged in (X, Y):
        assert isinstance(ragged, Ragged)
    for x_arr, y_arr in ((X.lengths, Y.lengths), (X.data, Y.data)):
        assert_arrays_match(x_arr, y_arr)
    return True


def assert_paddeds_match(X, Y):
    """Assert that two ``Padded`` objects have compatible metadata and data.

    The ``size_at_t``, ``lengths`` and ``indices`` arrays must each agree in
    dtype and axis-0 size. ``data`` must agree in dtype and in its first two
    dimensions. Returns ``True`` on success.
    """
    for padded in (X, Y):
        assert isinstance(padded, Padded)
    # assert_arrays_match raises on mismatch, so no outer assert is needed.
    for x_arr, y_arr in (
        (X.size_at_t, Y.size_at_t),
        (X.lengths, Y.lengths),
        (X.indices, Y.indices),
    ):
        assert_arrays_match(x_arr, y_arr)
    x_data, y_data = X.data, Y.data
    assert x_data.dtype == y_data.dtype
    assert x_data.shape[1] == y_data.shape[1]
    assert x_data.shape[0] == y_data.shape[0]
    return True


def assert_padded_data_match(X, Y):
    """Rebuild ``Padded`` objects from two field tuples and compare them."""
    padded_x = Padded(*X)
    padded_y = Padded(*Y)
    return assert_paddeds_match(padded_x, padded_y)


def assert_ragged_data_match(X, Y):
    """Rebuild ``Ragged`` objects from two field tuples and compare them."""
    ragged_x = Ragged(*X)
    ragged_y = Ragged(*Y)
    return assert_raggeds_match(ragged_x, ragged_y)


def check_input_converters(Y, backprop, data, n_args, kwargs_keys, type_):
    """Validate the forward output and the backprop result of an input converter.

    Args:
        Y: converted forward output. Must be an ``ArgsKwargs`` with ``n_args``
            positional args and keyword names equal to ``kwargs_keys`` (in
            order), every value being an instance of ``type_``.
        backprop: callback applied to ``Y`` to obtain the gradient ``dX``.
        data: the original converter input. Its type decides which
            structural checks are applied to ``dX``.
        n_args: expected number of positional args.
        kwargs_keys: expected keyword names, in order.
        type_: expected type of every positional and keyword value in ``Y``.

    ``dX`` must have the same container type as ``data`` (a list input maps
    to a tuple gradient), or be a NumPy/CuPy array. Any ``data`` that is not
    a dict, list, tuple, ``ArgsKwargs`` or NumPy/CuPy array fails the test.
    """
    assert isinstance(Y, ArgsKwargs)
    assert len(Y.args) == n_args
    assert list(Y.kwargs.keys()) == kwargs_keys
    assert all(isinstance(arg, type_) for arg in Y.args)
    assert all(isinstance(arg, type_) for arg in Y.kwargs.values())
    dX = backprop(Y)

    def is_supported_backend_array(arr):
        return is_cupy_array(arr) or is_numpy_array(arr)

    # List inputs are expected to come back as tuple gradients.
    input_type = type(data) if not isinstance(data, list) else tuple
    assert isinstance(dX, input_type) or is_supported_backend_array(dX)

    if isinstance(data, dict):
        assert list(dX.keys()) == kwargs_keys
        assert all(is_supported_backend_array(arr) for arr in dX.values())
    elif isinstance(data, (list, tuple)):
        assert isinstance(dX, tuple)
        assert all(is_supported_backend_array(arr) for arr in dX)
    elif isinstance(data, ArgsKwargs):
        assert len(dX.args) == n_args
        assert list(dX.kwargs.keys()) == kwargs_keys

        assert all(is_supported_backend_array(arg) for arg in dX.args)
        assert all(is_supported_backend_array(arg) for arg in dX.kwargs.values())
    elif not is_supported_backend_array(data):
        # Accept CuPy as well as NumPy inputs, consistent with how dX is
        # checked above. Report the offending *input* type, not the gradient.
        pytest.fail(f"Bad data type: {type(data)!r}")