File: parallel_workers_test.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (119 lines) | stat: -rw-r--r-- 3,501 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119





import unittest

from caffe2.python import workspace, core
import caffe2.python.parallel_workers as parallel_workers


def create_queue():
    """Create the blobs queue and pre-create every blob the tests touch.

    Runs a ``CreateBlobsQueue`` operator (``num_blobs=1``, ``capacity=1000``)
    whose output blob is named ``'queue'``, then creates ``blob_<i>`` and
    ``status_blob_<i>`` for ``i`` in ``0..99`` plus ``dequeue_blob`` and
    ``status_blob`` in the current workspace.

    Returns:
        The name of the queue blob (``'queue'``).
    """
    queue_name = 'queue'
    create_queue_op = core.CreateOperator(
        "CreateBlobsQueue", [], [queue_name], num_blobs=1, capacity=1000
    )
    workspace.RunOperatorOnce(create_queue_op)

    def precreate(blob_name):
        workspace.C.Workspace.current.create_blob(blob_name)

    # Blob creation is technically not thread safe. The tests in this file go
    # through RunOperatorOnce rather than CreateNet+RunNet, so every blob has
    # to exist up front.
    for index in range(100):
        suffix = str(index)
        precreate("blob_" + suffix)
        precreate("status_blob_" + suffix)
    for shared_name in ("dequeue_blob", "status_blob"):
        precreate(shared_name)

    return queue_name


def create_worker(queue, get_blob_data):
    """Build a worker function that enqueues one blob into ``queue``.

    Args:
        queue: Name of the blobs queue to enqueue into.
        get_blob_data: Callable taking the worker id and returning the value
            to feed into that worker's blob.

    Returns:
        A function of ``worker_id`` that feeds ``blob_<worker_id>`` with
        ``get_blob_data(worker_id)`` and then runs a ``SafeEnqueueBlobs``
        operator on it, with ``status_blob_<worker_id>`` as the status output.
    """
    def dummy_worker(worker_id):
        suffix = str(worker_id)
        data_blob = 'blob_' + suffix
        status_blob = 'status_blob_' + suffix

        workspace.FeedBlob(data_blob, get_blob_data(worker_id))

        enqueue_op = core.CreateOperator(
            'SafeEnqueueBlobs', [queue, data_blob], [data_blob, status_blob]
        )
        workspace.RunOperatorOnce(enqueue_op)

    return dummy_worker


def dequeue_value(queue):
    """Dequeue one entry from ``queue`` and return the fetched blob value.

    Runs a ``SafeDequeueBlobs`` operator whose outputs are ``dequeue_blob``
    and ``status_blob``, then fetches ``dequeue_blob`` from the workspace.

    Args:
        queue: Name of the blobs queue to dequeue from.

    Returns:
        The result of ``workspace.FetchBlob('dequeue_blob')``.
    """
    target_blob = 'dequeue_blob'
    dequeue_op = core.CreateOperator(
        "SafeDequeueBlobs", [queue], [target_blob, 'status_blob']
    )
    workspace.RunOperatorOnce(dequeue_op)

    return workspace.FetchBlob(target_blob)


class ParallelWorkersTest(unittest.TestCase):
    """Tests for ``parallel_workers.init_workers`` using enqueueing workers.

    Each test resets the workspace, builds a queue with ``create_queue`` and
    a worker with ``create_worker``, then starts the coordinator returned by
    ``init_workers``.
    """

    def testParallelWorkers(self):
        """Workers enqueue their own id; dequeued values must be b'0' or b'1'."""
        workspace.ResetWorkspace()

        queue = create_queue()
        dummy_worker = create_worker(queue, lambda worker_id: str(worker_id))
        worker_coordinator = parallel_workers.init_workers(dummy_worker)
        worker_coordinator.start()

        try:
            for _ in range(10):
                value = dequeue_value(queue)
                self.assertIn(
                    value, [b'0', b'1'], 'Got unexpected value ' + str(value)
                )
        finally:
            # Always stop the workers, even when an assertion above fails, so
            # they do not keep running into the following tests.
            stopped = worker_coordinator.stop()

        self.assertTrue(stopped)

    def testParallelWorkersInitFun(self):
        """``init_fun`` overwrites 'data' before workers enqueue its value."""
        workspace.ResetWorkspace()

        queue = create_queue()
        dummy_worker = create_worker(
            queue, lambda worker_id: workspace.FetchBlob('data')
        )
        workspace.FeedBlob('data', 'not initialized')

        def init_fun(worker_coordinator, global_coordinator):
            workspace.FeedBlob('data', 'initialized')

        worker_coordinator = parallel_workers.init_workers(
            dummy_worker, init_fun=init_fun
        )
        worker_coordinator.start()

        try:
            for _ in range(10):
                value = dequeue_value(queue)
                self.assertEqual(
                    value, b'initialized', 'Got unexpected value ' + str(value)
                )
        finally:
            # A best effort attempt at a clean shutdown; the result is
            # deliberately not asserted. Running it in ``finally`` keeps a
            # failed assertion from leaving the workers running.
            worker_coordinator.stop()

    def testParallelWorkersShutdownFun(self):
        """``shutdown_fun`` has overwritten 'data' once ``stop()`` returns."""
        workspace.ResetWorkspace()

        queue = create_queue()
        dummy_worker = create_worker(queue, lambda worker_id: str(worker_id))
        workspace.FeedBlob('data', 'not shutdown')

        def shutdown_fun():
            workspace.FeedBlob('data', 'shutdown')

        worker_coordinator = parallel_workers.init_workers(
            dummy_worker, shutdown_fun=shutdown_fun
        )
        worker_coordinator.start()

        self.assertTrue(worker_coordinator.stop())

        data = workspace.FetchBlob('data')
        self.assertEqual(data, b'shutdown', 'Got unexpected value ' + str(data))