File: blobs_queue_db_test.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (109 lines) | stat: -rw-r--r-- 4,045 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109





import unittest
import numpy as np

import caffe2.proto.caffe2_pb2 as caffe2_pb2
from caffe2.python import core, workspace, timeout_guard


@unittest.skipIf(not workspace.C.use_mkldnn, "No MKLDNN support.")
class BlobsQueueDBTest(unittest.TestCase):
    """Round-trip serialized TensorProtos through a BlobsQueueDB on IDEEP.

    Each test enqueues serialized ``TensorProtos`` records into a
    ``BlobsQueue``, wraps that queue with ``CreateBlobsQueueDB``, reads it
    back through ``TensorProtosDBInput`` and checks every batch's
    ``image``/``label`` contents and ordering.
    """

    def test_create_blobs_queue_db_string(self):
        """Records are enqueued as the raw serialized bytes."""
        device_opt = core.DeviceOption(caffe2_pb2.IDEEP, 0)
        with core.DeviceScope(device_opt):
            add_blobs = self._make_add_blobs(lambda payload: payload)
            self._test_create_blobs_queue_db(add_blobs)

    def test_create_blobs_queue_db_tensor(self):
        """Records are enqueued wrapped in a one-element numpy string array."""
        device_opt = core.DeviceOption(caffe2_pb2.IDEEP, 0)
        with core.DeviceScope(device_opt):
            add_blobs = self._make_add_blobs(
                lambda payload: np.array([payload], dtype=str)
            )
            self._test_create_blobs_queue_db(add_blobs)

    def _make_add_blobs(self, to_payload):
        """Build an ``add_blobs(queue, num_samples)`` callback.

        The returned callback enqueues ``num_samples`` records into ``queue``.
        The i-th record is ``to_payload(self._create_test_tensor_protos(i))``.

        Args:
            to_payload: converts one serialized TensorProtos record into the
                value that is fed into the blob before it is enqueued.
        """
        def add_blobs(queue, num_samples):
            blob = core.BlobReference("blob")
            status = core.BlobReference("blob_status")
            for i in range(num_samples):
                payload = to_payload(self._create_test_tensor_protos(i))
                self._add_blob_to_queue(queue, payload, blob, status)
        return add_blobs

    def _test_create_blobs_queue_db(self, add_blobs_fun):
        """Fill a queue via ``add_blobs_fun`` and verify every batch read back.

        Args:
            add_blobs_fun: callable ``(queue, num_samples)`` that enqueues
                ``num_samples`` serialized TensorProtos records into ``queue``.
        """
        device_opt = core.DeviceOption(caffe2_pb2.IDEEP, 0)
        with core.DeviceScope(device_opt):
            num_samples = 10000
            batch_size = 10
            init_net = core.Net('init_net')
            net = core.Net('test_create_blobs_queue_db')
            # Capacity equals num_samples so every record fits before any read.
            queue = init_net.CreateBlobsQueue([], 'queue', capacity=num_samples)
            reader = init_net.CreateBlobsQueueDB(
                [queue],
                'blobs_queue_db_reader',
                value_blob_index=0,
                timeout_secs=0.1,
            )
            workspace.RunNetOnce(init_net)
            add_blobs_fun(queue, num_samples)

            net.TensorProtosDBInput(
                [reader],
                ['image', 'label'],
                batch_size=batch_size
            )
            workspace.CreateNet(net)

            close_net = core.Net('close_net')
            close_net.CloseBlobsQueue([queue], [])

            try:
                for i in range(num_samples // batch_size):
                    # Guard each batch read so a stuck read fails the run
                    # instead of blocking indefinitely (presumably aborts
                    # after 2 s — see timeout_guard).
                    with timeout_guard.CompleteInTimeOrDie(2.0):
                        workspace.RunNet(net)

                    images = workspace.FetchBlob('image')
                    labels = workspace.FetchBlob('label')
                    self.assertEqual(batch_size, len(images))
                    self.assertEqual(batch_size, len(labels))
                    for idx, item in enumerate(images):
                        # Records must come back in enqueue order.
                        self.assertEqual(
                            "foo{}".format(i * batch_size + idx).encode('utf-8'),
                            item,
                        )
                    for item in labels:
                        self.assertEqual(1, item)
            finally:
                # Close the queue even when an assertion above fails, so it is
                # not left open in the process-wide workspace.
                workspace.RunNetOnce(close_net)

    def _add_blob_to_queue(self, queue, data, blob, status):
        """Feed ``data`` into ``blob`` and enqueue it into ``queue``.

        The blob is fed with a CPU device option. ``SafeEnqueueBlobs`` writes
        outputs ``[blob, status]``; ``status`` is not inspected here.
        """
        device_opt = core.DeviceOption(caffe2_pb2.IDEEP, 0)
        with core.DeviceScope(device_opt):
            workspace.FeedBlob(blob, data, core.DeviceOption(caffe2_pb2.CPU, 0))
            op = core.CreateOperator(
                "SafeEnqueueBlobs",
                [queue, blob],
                [blob, status],
            )

            workspace.RunOperatorOnce(op)

    def _create_test_tensor_protos(self, idx):
        """Return a serialized TensorProtos record for sample ``idx``.

        The record holds two protos: a STRING tensor containing
        ``b"foo<idx>"`` and an INT32 tensor containing the label ``1``.
        """
        item = caffe2_pb2.TensorProtos()
        data = item.protos.add()
        data.data_type = core.DataType.STRING
        data.string_data.append("foo{}".format(idx).encode('utf-8'))
        label = item.protos.add()
        label.data_type = core.DataType.INT32
        label.int32_data.append(1)

        return item.SerializeToString()

if __name__ == "__main__":
    # `unittest` is already imported at module top; run the test suite.
    unittest.main()