File: lengths_pad_op_test.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (57 lines) | stat: -rw-r--r-- 1,625 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57





from caffe2.python import core
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
import hypothesis.strategies as st
import numpy as np


class TestLengthsPadOp(serial.SerializedTestCase):
    """Tests for the Caffe2 ``LengthsPad`` operator.

    The operator's output is compared against a pure-NumPy reference
    implementation through ``assertReferenceChecks``.
    """

    @serial.given(
        inputs=hu.lengths_tensor(
            dtype=np.float32,
            min_value=1,
            max_value=5,
            allow_empty=True,
        ),
        delta_length=st.integers(0, 10),
        padding_value=st.floats(-10.0, 10.0),
        **hu.gcs
    )
    def test_lengths_pad(self, inputs, delta_length, padding_value, gc, dc):
        """Check ``LengthsPad`` against a NumPy reference.

        ``inputs`` is a ``(data, lengths)`` pair: segment ``i`` is the next
        ``lengths[i]`` rows of ``data``. Each segment is padded with
        ``padding_value`` up to ``target_length`` rows. ``target_length`` is
        the longest segment length plus ``delta_length``, and never less
        than 1.
        """
        data, lengths = inputs

        # np.max raises on a zero-size array, so the empty case is handled
        # separately.
        if len(lengths) == 0:
            longest = 0
        else:
            longest = np.max(lengths)
        target_length = max(longest + delta_length, 1)

        op = core.CreateOperator(
            "LengthsPad",
            ["data", "lengths"],
            ["data_padded"],
            target_length=target_length,
            padding_value=padding_value,
        )

        def lengths_pad_op(data, lengths):
            # Start from an all-padding buffer with target_length rows per
            # segment, then copy each segment into the head of its slot.
            num_segments = len(lengths)
            padded = np.full(
                (num_segments * target_length,) + data.shape[1:],
                padding_value,
                dtype=np.float32,
            )
            read_pos = 0
            for seg_idx, seg_len in enumerate(lengths):
                write_pos = seg_idx * target_length
                padded[write_pos:write_pos + seg_len] = (
                    data[read_pos:read_pos + seg_len]
                )
                read_pos += seg_len
            return [padded]

        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[data, lengths],
            reference=lengths_pad_op,
        )