File: concat_op_cost_test.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (84 lines) | stat: -rw-r--r-- 2,858 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from collections import namedtuple

import numpy as np
from caffe2.python import core, workspace
from caffe2.python.test_util import TestCase


class TestConcatOpCost(TestCase):
    """Check the Concat operator's output and the cost it reports.

    ``workspace.GetOperatorCost`` yields ``(flops, bytes_written, bytes_read)``.
    The tests below expect Concat (and Split) to report zero FLOPs. They also
    expect byte counts that match the sizes of the blobs the operator reads
    and writes.
    """

    # Expected-cost record; the field order matches the unpacking order of
    # workspace.GetOperatorCost used in _check_columnwise_concat.
    _ExpectedCost = namedtuple(
        "expected_cost", ["flops", "bytes_written", "bytes_read"]
    )

    # np.float was a deprecated alias for the builtin float (i.e. float64) and
    # was removed in NumPy 1.24, so the concrete dtype is spelled out here.
    _COLUMNWISE_DTYPES = (np.int64, np.float64, np.half, np.int8)

    def _check_columnwise_concat(self, dtype):
        """Concat a (2, 3) and a (2, 1) array of ``dtype`` and verify result and cost."""
        workspace.ResetWorkspace()
        workspace.FeedBlob("input_1", np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype))
        workspace.FeedBlob("input_2", np.array([[7], [8]], dtype=dtype))
        concat_op = core.CreateOperator(
            "Concat",
            ["input_1", "input_2"],
            ["output", "split_info"],
        )
        workspace.RunOperatorOnce(concat_op)

        # No axis argument is given; the expected (2, 4) result shows the
        # inputs being joined column-wise.
        output = workspace.FetchBlob("output")
        self.assertTupleEqual(output.shape, (2, 4))
        np.testing.assert_array_equal(output, [[1, 2, 3, 7], [4, 5, 6, 8]])

        flops, bytes_written, bytes_read = workspace.GetOperatorCost(
            concat_op, concat_op.input
        )

        self.assertEqual(flops, 0)
        # Concat only moves data: bytes read should equal the total size of
        # its inputs, and bytes written the total size of its outputs
        # (including the auxiliary split_info blob).
        self.assertEqual(
            bytes_read,
            sum(workspace.FetchBlob(b).nbytes for b in concat_op.input),
        )
        self.assertEqual(
            bytes_written,
            sum(workspace.FetchBlob(b).nbytes for b in concat_op.output),
        )

    def test_columnwise_concat(self):
        """Run the column-wise Concat check once per supported dtype."""
        for dtype in self._COLUMNWISE_DTYPES:
            # subTest labels a failure with the offending dtype.
            with self.subTest(dtype=dtype):
                self._check_columnwise_concat(dtype)

    def test_split_then_concat(self):
        """Split a (2, 3) int32 matrix into columns, re-concat, and check costs."""
        workspace.ResetWorkspace()
        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
        workspace.FeedBlob("split", np.array([1, 1, 1], dtype=np.int32))
        split_op = core.CreateOperator(
            "Split",
            ["input", "split"],
            ["output_1", "output_2", "output_3"],
            axis=1,
            add_axis=1,
        )
        workspace.RunOperatorOnce(split_op)

        concat_op = core.CreateOperator(
            "Concat",
            ["output_1", "output_2", "output_3"],
            ["output", "split_info"],
            axis=1,
            add_axis=1,
        )
        workspace.RunOperatorOnce(concat_op)

        # Split followed by Concat with matching axis/add_axis must round-trip.
        np.testing.assert_array_equal(
            workspace.FetchBlob("input"), workspace.FetchBlob("output")
        )

        # Split reads "input" (6 x int32 = 24 bytes) plus "split"
        # (3 x int32 = 12 bytes) = 36 bytes. It writes the same 6 int32 values
        # spread over three outputs = 24 bytes.
        split_cost = workspace.GetOperatorCost(split_op, split_op.input)
        self.assertTupleEqual(split_cost, self._ExpectedCost(0, 24, 36))

        # Concat reads the three split outputs (24 bytes). It writes "output"
        # (24 bytes) plus split_info, which accounts for the remaining
        # 12 bytes.
        concat_cost = workspace.GetOperatorCost(concat_op, concat_op.input)
        self.assertTupleEqual(concat_cost, self._ExpectedCost(0, 36, 24))