File: ctc_ops_test.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (108 lines) | stat: -rw-r--r-- 3,915 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108




import numpy as np
from caffe2.proto import caffe2_pb2

from caffe2.python import core, workspace, dyndep, test_util

# Load the warp-ctc contrib operator library; presumably this is what makes
# the ``CTC`` operator used by the tests below available -- TODO confirm.
dyndep.InitOpsLibrary('@/caffe2/caffe2/contrib/warpctc:ctc_ops')
# Run Caffe2's global initialization with an argv-style list that holds only
# a program name (no extra flags).
workspace.GlobalInit(["python"])


def softmax(w):
    """Return the softmax of ``w`` computed along its last axis.

    The maximum of each slice along the last axis is subtracted before
    exponentiating, so large inputs do not overflow; the shift cancels out in
    the normalization.
    """
    shifted = w - np.max(w, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)


class CTCOpsTest(test_util.TestCase):
    """Checks the ``CTC`` operator's cost and gradient output on a tiny input.

    Every test runs the operator on one fixed two-step, single-sequence
    example, both with and without the optional ``input_lengths`` blob.
    """

    def verify_cost(self, device_option, is_test, skip_input_lengths=False):
        """Run the CTC op once and compare its cost with a closed form.

        Args:
            device_option: ``caffe2_pb2.DeviceOption`` given to the CTC op
                and used when feeding the ``inputs`` blob.
            is_test: forwarded to the op as ``is_test``. When false, the op
                also gets an ``inputs_grad_to_be_copied`` output, gradient
                operators are added for ``costs``, and the resulting
                ``inputs_grad`` blob is compared with that output.
            skip_input_lengths: when true, the ``input_lengths`` blob is
                neither fed nor passed to the op.
        """
        alphabet_size = 5
        N = 1  # number of sequences
        T = 2  # number of time steps
        target = [1, 2]

        inputs = np.asarray(
            [
                [[0.1, 0.6, 0.1, 0.1, 0.1]],
                [[0.1, 0.1, 0.6, 0.1, 0.1]],
            ]
        ).reshape(T, N, alphabet_size).astype(np.float32)

        # Size the label blobs from the target itself rather than from T;
        # the two only happen to be equal in this example.
        labels = np.asarray(target).astype(np.int32).reshape(len(target))
        label_lengths = np.asarray([len(target)]).astype(np.int32).reshape(N)
        input_lengths = np.asarray([T]).astype(np.int32)

        net = core.Net("test-net")
        input_blobs = ["inputs", "labels", "label_lengths"]
        if not skip_input_lengths:
            input_blobs.append("input_lengths")
        output_blobs = ["costs", "workspace"]
        if not is_test:
            output_blobs.insert(0, "inputs_grad_to_be_copied")
        net.CTC(input_blobs,
                output_blobs,
                is_test=is_test,
                device_option=device_option)
        if not is_test:
            net.AddGradientOperators(["costs"])

        self.ws.create_blob("inputs").feed(inputs, device_option=device_option)
        self.ws.create_blob("labels").feed(labels)
        self.ws.create_blob("label_lengths").feed(label_lengths)
        if not skip_input_lengths:
            self.ws.create_blob("input_lengths").feed(input_lengths)
        self.ws.run(net)

        # The test takes the target's likelihood to be the probability of
        # label 1 at step 0 times the probability of label 2 at step 1, and
        # expects the cost to be its negative log.
        probs = softmax(inputs)
        expected = probs[0, 0, 1] * probs[1, 0, 2]

        costs = self.ws.blobs["costs"].fetch()
        self.assertEqual(costs.shape, (N,))
        self.assertEqual(costs.dtype, np.float32)
        # All values are float32, so compare with a relative tolerance rather
        # than assertAlmostEqual's default of 7 decimal places.
        np.testing.assert_allclose(np.exp(-costs[0]), expected, rtol=1e-5)

        if not is_test:
            # Make sure inputs_grad was added by AddGradientOperators and
            # it is equal to the inputs_grad_to_be_copied blob returned by
            # the CTC op. assertTrue (unlike a bare assert) still runs when
            # Python is started with -O.
            self.assertTrue(
                np.array_equal(
                    self.ws.blobs["inputs_grad"].fetch(),
                    self.ws.blobs["inputs_grad_to_be_copied"].fetch(),
                ),
                "inputs_grad differs from inputs_grad_to_be_copied",
            )

    def _verify_with_and_without_input_lengths(self, device_option, is_test):
        """Run ``verify_cost`` with the ``input_lengths`` blob, then without."""
        self.verify_cost(device_option, is_test=is_test)
        self.verify_cost(
            device_option, is_test=is_test, skip_input_lengths=True)

    def _cpu_device_option(self):
        """Return a CPU ``DeviceOption``."""
        return caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU)

    def _gpu_device_option(self):
        """Return a CUDA ``DeviceOption`` for device 0, or skip the test.

        The calling test is skipped when ``workspace.has_gpu_support`` is
        false, so it does not error on builds without GPU support.
        """
        if not workspace.has_gpu_support:
            self.skipTest("No GPU support.")
        return caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CUDA,
                                       device_id=0)

    def test_ctc_cost_cpu(self):
        """Cost and gradient check on CPU."""
        self._verify_with_and_without_input_lengths(
            self._cpu_device_option(), is_test=False)

    def test_ctc_cost_gpu(self):
        """Cost and gradient check on CUDA device 0."""
        self._verify_with_and_without_input_lengths(
            self._gpu_device_option(), is_test=False)

    def test_ctc_forward_only_cpu(self):
        """Forward-only (``is_test=True``) cost check on CPU."""
        self._verify_with_and_without_input_lengths(
            self._cpu_device_option(), is_test=True)

    def test_ctc_forward_only_gpu(self):
        """Forward-only (``is_test=True``) cost check on CUDA device 0."""
        self._verify_with_and_without_input_lengths(
            self._gpu_device_option(), is_test=True)