File: list2ragged.py

package info (click to toggle)
python-thinc 9.1.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 5,896 kB
  • sloc: python: 17,122; javascript: 1,559; ansic: 342; makefile: 15; sh: 13
file content (25 lines) | stat: -rw-r--r-- 863 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from typing import Callable, List, Tuple, TypeVar, cast

from ..config import registry
from ..model import Model
from ..types import ArrayXd, ListXd, Ragged

# Input type of the layer: a list-like container of arrays (thinc's ListXd),
# one array per sequence. forward() takes len() of each entry to get lengths.
InT = TypeVar("InT", bound=ListXd)
# Output type of the layer: a single Ragged built from the concatenated data
# plus the per-sequence lengths.
OutT = Ragged


@registry.layers("list2ragged.v1")
def list2ragged() -> Model[InT, OutT]:
    """Create a layer that concatenates a list of arrays into a Ragged array.

    In the forward pass, the input list is flattened into one contiguous
    ``data`` array (via ``model.ops.flatten``) and the length of each entry is
    recorded in ``lengths``. The backward pass splits the incoming gradient's
    ``data`` back into a list of arrays using those ``lengths`` (via
    ``model.ops.unflatten``).

    The input is always treated as a list of arrays. Inputs that are already
    ``Ragged`` are not special-cased or passed through unchanged.

    RETURNS (Model[ListXd, Ragged]): The list-to-ragged layer.
    """
    return Model("list2ragged", forward)


def forward(model: Model[InT, OutT], Xs: InT, is_train: bool) -> Tuple[OutT, Callable]:
    """Concatenate the arrays in ``Xs`` into one Ragged and return a backprop.

    model: The list2ragged layer; only its ``ops`` are used.
    Xs: The sequence of arrays to join. ``len()`` of each entry becomes its
        length in the output.
    is_train: Accepted for the layer-forward signature; not used here.

    RETURNS: A ``(Ragged, backprop)`` pair. ``backprop`` maps a Ragged
        gradient back to a list of arrays split by the gradient's lengths.
    """

    def backprop_to_list(d_ragged: OutT) -> InT:
        # Split the concatenated gradient back into per-sequence arrays.
        pieces = model.ops.unflatten(d_ragged.data, d_ragged.lengths)
        return cast(InT, pieces)

    ops = model.ops
    # Lengths are computed before flattening, as in the original order.
    sizes = ops.asarray1i([len(seq) for seq in Xs])
    joined = ops.flatten(Xs)
    result = Ragged(joined, sizes)
    return result, backprop_to_list