File: store_ops_test_util.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (76 lines) | stat: -rw-r--r-- 2,213 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
## @package store_ops_test_util
# Module caffe2.distributed.store_ops_test_util





from multiprocessing import Process, Queue

import numpy as np

from caffe2.python import core, workspace


class StoreOpsTests(object):
    """Reusable checks for the caffe2 ``StoreSet`` / ``StoreGet`` operators.

    Each public ``test_*`` classmethod takes ``create_store_handler_fn``, a
    zero-argument callable whose return value is passed as the first input
    (the store handler) of the store operators.
    """

    @classmethod
    def _test_set_get(cls, queue, create_store_handler_fn, index, num_procs):
        """Subprocess body for ``test_set_get``.

        The process with the highest ``index`` feeds a one-element float32
        array of ones into the workspace and stores it under ``"blob"`` with
        ``StoreSet``. Every process (including that one) then reads it back
        with ``StoreGet`` into ``"output_blob"`` and checks it equals 1.

        Args:
            queue: multiprocessing queue that receives the ``AssertionError``
                raised if the fetched value does not match.
            create_store_handler_fn: zero-argument callable returning the
                store handler used as the operators' first input.
            index: position of this process among the workers.
            num_procs: total number of worker processes.

        Any other exception propagates and terminates this subprocess with a
        non-zero exit code, which ``test_set_get`` reports.
        """
        store_handler = create_store_handler_fn()
        blob = "blob"
        value = np.full(1, 1, np.float32)

        # Use the last process to set the blob so the other processes are
        # likely already waiting for it before it is set (process start order
        # makes this probable, not guaranteed).
        if index == (num_procs - 1):
            workspace.FeedBlob(blob, value)
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "StoreSet",
                    [store_handler, blob],
                    [],
                    blob_name=blob))

        output_blob = "output_blob"
        workspace.RunOperatorOnce(
            core.CreateOperator(
                "StoreGet",
                [store_handler],
                [output_blob],
                blob_name=blob))

        try:
            np.testing.assert_array_equal(workspace.FetchBlob(output_blob), 1)
        except AssertionError as err:
            # Hand the failure back to the parent; an exception raised here
            # would only be visible in this subprocess.
            queue.put(err)

        workspace.ResetWorkspace()

    @classmethod
    def test_set_get(cls, create_store_handler_fn):
        """Check that a blob set by one process can be read by all of them.

        Starts four processes running ``_test_set_get`` and waits for all of
        them to finish.

        Args:
            create_store_handler_fn: zero-argument callable returning a store
                handler; it is called once inside each subprocess.

        Raises:
            AssertionError: the first value mismatch reported by a subprocess.
            RuntimeError: if any subprocess exited with a non-zero exit code
                (e.g. an operator raised). Without this check, such a crash
                would leave the queue empty and the test would pass.
        """
        # Queue for assertion errors on subprocesses
        queue = Queue()

        # Start N processes in the background
        num_procs = 4
        procs = []
        for index in range(num_procs):
            proc = Process(
                target=cls._test_set_get,
                args=(queue, create_store_handler_fn, index, num_procs))
            proc.start()
            procs.append(proc)

        # Test complete, join background processes. Joining before reading
        # the queue is acceptable here only because each child puts at most
        # one small error object.
        for proc in procs:
            proc.join()

        # Raise first error we find, if any. A child that put an item does
        # not terminate until that item is flushed to the queue's pipe, so
        # after join() the check below sees everything the children sent.
        if not queue.empty():
            raise queue.get()

        # A subprocess that died from anything other than the value check
        # puts nothing on the queue; surface it through its exit code.
        failed = [
            (index, proc.exitcode)
            for index, proc in enumerate(procs)
            if proc.exitcode != 0
        ]
        if failed:
            raise RuntimeError(
                "StoreOpsTests.test_set_get: subprocess(es) exited abnormally "
                "(index, exitcode): {}".format(failed))

    @classmethod
    def test_get_timeout(cls, create_store_handler_fn):
        """Run ``StoreGet`` for a blob name that this method never sets.

        Nothing here catches a failure from ``workspace.RunNetOnce``. Going by
        the method name, ``StoreGet`` is presumably expected to time out and
        raise, and callers wrap this call in an ``assertRaises``.
        NOTE(review): confirm against the calling test cases.

        Args:
            create_store_handler_fn: zero-argument callable returning the
                store handler used as the operator's first input.
        """
        store_handler = create_store_handler_fn()
        net = core.Net('get_missing_blob')
        # The ``1`` presumably asks the net builder for one auto-named output.
        net.StoreGet([store_handler], 1, blob_name='blob')
        workspace.RunNetOnce(net)