File: functional.py

package info (click to toggle)
pytorch-text 0.14.1-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 11,560 kB
  • sloc: python: 14,197; cpp: 2,404; sh: 214; makefile: 20
file content (138 lines) | stat: -rw-r--r-- 4,881 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from typing import Any, List, Optional

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence

# Public functional API of this module; these are the names exported by ``from ... import *``.
__all__ = ["to_tensor", "truncate", "add_token", "str_to_int"]


def to_tensor(input: Any, padding_value: Optional[int] = None, dtype: torch.dtype = torch.long) -> Tensor:
    r"""Convert input to torch tensor

    A single sequence (``List[int]``) becomes a 1-D tensor. A batch (``List[List[int]]``)
    becomes a 2-D tensor. When ``padding_value`` is given, shorter sequences in the batch
    are right-padded up to the length of the longest one.

    :param input: Sequence or batch of token ids
    :type input: Union[List[int], List[List[int]]]
    :param padding_value: Pad value to make each input in the batch of length equal to the longest sequence in the batch.
        Only used for batches. When ``None``, every sequence in the batch must already have the same length,
        otherwise tensor construction fails.
    :type padding_value: Optional[int]
    :param dtype: :class:`torch.dtype` of output tensor (applied to single sequences and batches alike)
    :type dtype: :class:`torch.dtype`
    :raises TypeError: if ``input`` is neither ``List[int]`` nor ``List[List[int]]``
    :rtype: Tensor
    """
    if torch.jit.isinstance(input, List[int]):
        # Honour the requested dtype. This branch previously hard-coded torch.long,
        # which silently ignored ``dtype`` for single sequences.
        return torch.tensor(input, dtype=dtype)
    elif torch.jit.isinstance(input, List[List[int]]):
        if padding_value is None:
            return torch.tensor(input, dtype=dtype)
        else:
            # pad_sequence declares its padding_value as a float, hence the conversion.
            return pad_sequence(
                [torch.tensor(ids, dtype=dtype) for ids in input], batch_first=True, padding_value=float(padding_value)
            )
    else:
        raise TypeError("Input type not supported")


def truncate(input: Any, max_seq_len: int) -> Any:
    """Clip a sequence, or every sequence in a batch, to its first ``max_seq_len`` items.

    :param input: A single sequence or a batch of sequences whose elements are ``int`` or ``str``
    :type input: Union[List[Union[str, int]], List[List[Union[str, int]]]]
    :param max_seq_len: Number of leading items to keep; everything after that is dropped
    :type max_seq_len: int
    :return: The clipped sequence or batch, built from slices (new list objects)
    :rtype: Union[List[Union[str, int]], List[List[Union[str, int]]]]
    :raises TypeError: if ``input`` is not one of the supported list types
    """
    # One branch per concrete container type, so that torch.jit.isinstance can
    # narrow ``input`` when this function is compiled with TorchScript.
    if torch.jit.isinstance(input, List[int]):
        return input[:max_seq_len]
    if torch.jit.isinstance(input, List[str]):
        return input[:max_seq_len]
    if torch.jit.isinstance(input, List[List[int]]):
        clipped_ints: List[List[int]] = [seq[:max_seq_len] for seq in input]
        return clipped_ints
    if torch.jit.isinstance(input, List[List[str]]):
        clipped_strs: List[List[str]] = [seq[:max_seq_len] for seq in input]
        return clipped_strs
    raise TypeError("Input type not supported")


def add_token(input: Any, token_id: Any, begin: bool = True) -> Any:
    """Attach ``token_id`` to a sequence, or to every sequence in a batch.

    The element type of ``input`` and the type of ``token_id`` must agree
    (both ``int`` or both ``str``). New lists are returned; ``input`` itself is
    not modified.

    :param input: A single sequence or a batch of sequences
    :type input: Union[List[Union[str, int]], List[List[Union[str, int]]]]
    :param token_id: The token to attach
    :type token_id: Union[str, int]
    :param begin: Prepend when True (the default), append when False
    :type begin: bool, optional
    :return: The sequence or batch with ``token_id`` attached at the chosen end
    :rtype: Union[List[Union[str, int]], List[List[Union[str, int]]]]
    :raises TypeError: if the input/token type combination is not supported
    """
    if torch.jit.isinstance(input, List[int]) and torch.jit.isinstance(token_id, int):
        return [token_id] + input if begin else input + [token_id]
    if torch.jit.isinstance(input, List[str]) and torch.jit.isinstance(token_id, str):
        return [token_id] + input if begin else input + [token_id]
    if torch.jit.isinstance(input, List[List[int]]) and torch.jit.isinstance(token_id, int):
        if begin:
            prefixed_ints: List[List[int]] = [[token_id] + seq for seq in input]
            return prefixed_ints
        suffixed_ints: List[List[int]] = [seq + [token_id] for seq in input]
        return suffixed_ints
    if torch.jit.isinstance(input, List[List[str]]) and torch.jit.isinstance(token_id, str):
        if begin:
            prefixed_strs: List[List[str]] = [[token_id] + seq for seq in input]
            return prefixed_strs
        suffixed_strs: List[List[str]] = [seq + [token_id] for seq in input]
        return suffixed_strs
    raise TypeError("Input type not supported")


def str_to_int(input: Any) -> Any:
    """Parse string tokens into integers, for one sequence or a batch of sequences.

    :param input: A single sequence of string tokens or a batch of such sequences
    :type input: Union[List[str], List[List[str]]]
    :return: The same structure with each token converted via ``int()``
    :rtype: Union[List[int], List[List[int]]]
    :raises TypeError: if ``input`` is not ``List[str]`` or ``List[List[str]]``
    :raises ValueError: from ``int()`` when a token is not an integer literal
    """
    if torch.jit.isinstance(input, List[str]):
        parsed: List[int] = [int(token) for token in input]
        return parsed
    if torch.jit.isinstance(input, List[List[str]]):
        parsed_batch: List[List[int]] = [[int(token) for token in seq] for seq in input]
        return parsed_batch
    raise TypeError("Input type not supported")