1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
|
from caffe2.python import scope, core, workspace
import unittest
import threading
import time
# Number of worker threads (see thread_runner) that finished all of their
# checks without a failed assertion. Several threads increment it, so every
# update is done while holding _SUCCESS_LOCK.
SUCCESS_COUNT = 0
_SUCCESS_LOCK = threading.Lock()


def thread_runner(idx, testobj):
    """Check, from a worker thread, that name/device scopes stay thread-local.

    Asserts that no scope is active on entry, then enters a device scope and
    a name scope derived from ``idx``. It sleeps for a duration staggered by
    ``idx`` so that several workers hold their scopes at the same time. It
    re-checks that its own scopes are still the active ones, and verifies
    that both are cleared again after leaving them. On success it increments
    the module-level ``SUCCESS_COUNT``.

    Args:
        idx: Worker index. It is used as the name-scope suffix, as the GPU
            device id and to stagger the sleep.
        testobj: ``unittest.TestCase`` whose assertion methods are used. A
            failed assertion raises inside this worker thread and is not
            reported by the test runner directly. The caller must therefore
            inspect ``SUCCESS_COUNT`` to detect failures.
    """
    global SUCCESS_COUNT
    testobj.assertEqual(scope.CurrentNameScope(), "")
    testobj.assertIsNone(scope.CurrentDeviceScope())
    namescope = "namescope_{}".format(idx)
    dsc = core.DeviceOption(workspace.GpuDeviceType, idx)
    with scope.DeviceScope(dsc):
        with scope.NameScope(namescope):
            testobj.assertEqual(scope.CurrentNameScope(), namescope + "/")
            testobj.assertEqual(scope.CurrentDeviceScope(), dsc)
            # Let the workers' scopes overlap in time. If scopes leaked
            # across threads, another worker's scope would show up below.
            time.sleep(0.01 + idx * 0.01)
            testobj.assertEqual(scope.CurrentNameScope(), namescope + "/")
            testobj.assertEqual(scope.CurrentDeviceScope(), dsc)
    testobj.assertEqual(scope.CurrentNameScope(), "")
    testobj.assertIsNone(scope.CurrentDeviceScope())
    # `+=` on a global is a read-modify-write that is not atomic across
    # threads, so serialize it.
    with _SUCCESS_LOCK:
        SUCCESS_COUNT += 1
class _IntentionalError(Exception):
    """Raised on purpose inside a scope to test cleanup when a body raises.

    A dedicated type is used so that ``assertRaises`` cannot also absorb an
    ``AssertionError`` from a failed check made inside the scope.
    """


class TestScope(unittest.TestCase):
    """Tests for caffe2 name scopes and device scopes."""

    def testNamescopeBasic(self):
        """A name scope is active inside ``with`` and removed on exit."""
        self.assertEqual(scope.CurrentNameScope(), "")
        with scope.NameScope("test_scope"):
            self.assertEqual(scope.CurrentNameScope(), "test_scope/")
        self.assertEqual(scope.CurrentNameScope(), "")

    def testNamescopeAssertion(self):
        """A name scope is removed even when its body raises."""
        self.assertEqual(scope.CurrentNameScope(), "")
        # Only _IntentionalError is expected. A failing assertEqual inside
        # the scope raises AssertionError, which propagates and fails the
        # test instead of being swallowed.
        with self.assertRaises(_IntentionalError):
            with scope.NameScope("test_scope"):
                self.assertEqual(scope.CurrentNameScope(), "test_scope/")
                raise _IntentionalError()
        self.assertEqual(scope.CurrentNameScope(), "")

    def testEmptyNamescopeBasic(self):
        """EmptyNameScope hides the enclosing name scope temporarily."""
        self.assertEqual(scope.CurrentNameScope(), "")
        with scope.NameScope("test_scope"):
            with scope.EmptyNameScope():
                self.assertEqual(scope.CurrentNameScope(), "")
            # The enclosing scope is restored once the empty scope exits.
            self.assertEqual(scope.CurrentNameScope(), "test_scope/")
        self.assertEqual(scope.CurrentNameScope(), "")

    def testDevicescopeBasic(self):
        """A device scope is active inside ``with`` and removed on exit."""
        self.assertIsNone(scope.CurrentDeviceScope())
        dsc = core.DeviceOption(workspace.GpuDeviceType, 9)
        with scope.DeviceScope(dsc):
            self.assertEqual(scope.CurrentDeviceScope(), dsc)
        self.assertIsNone(scope.CurrentDeviceScope())

    def testEmptyDevicescopeBasic(self):
        """EmptyDeviceScope hides the enclosing device scope temporarily."""
        self.assertIsNone(scope.CurrentDeviceScope())
        dsc = core.DeviceOption(workspace.GpuDeviceType, 9)
        with scope.DeviceScope(dsc):
            self.assertEqual(scope.CurrentDeviceScope(), dsc)
            with scope.EmptyDeviceScope():
                self.assertIsNone(scope.CurrentDeviceScope())
            self.assertEqual(scope.CurrentDeviceScope(), dsc)
        self.assertIsNone(scope.CurrentDeviceScope())

    def testDevicescopeAssertion(self):
        """A device scope is removed even when its body raises."""
        self.assertIsNone(scope.CurrentDeviceScope())
        dsc = core.DeviceOption(workspace.GpuDeviceType, 9)
        # See testNamescopeAssertion: a dedicated exception type keeps a
        # failed check inside the scope from being silently swallowed.
        with self.assertRaises(_IntentionalError):
            with scope.DeviceScope(dsc):
                self.assertEqual(scope.CurrentDeviceScope(), dsc)
                raise _IntentionalError()
        self.assertIsNone(scope.CurrentDeviceScope())

    def testTags(self):
        """Nested device scopes accumulate ``extra_info`` entries in order."""
        self.assertIsNone(scope.CurrentDeviceScope())
        extra_info1 = ["key1:value1"]
        extra_info2 = ["key2:value2"]
        extra_info3 = ["key3:value3"]
        extra_info_1_2 = ["key1:value1", "key2:value2"]
        extra_info_1_2_3 = ["key1:value1", "key2:value2", "key3:value3"]
        with scope.DeviceScope(core.DeviceOption(0, extra_info=extra_info1)):
            self.assertEqual(scope.CurrentDeviceScope().extra_info, extra_info1)
            with scope.DeviceScope(core.DeviceOption(0, extra_info=extra_info2)):
                self.assertEqual(
                    scope.CurrentDeviceScope().extra_info, extra_info_1_2
                )
                with scope.DeviceScope(
                    core.DeviceOption(0, extra_info=extra_info3)
                ):
                    self.assertEqual(
                        scope.CurrentDeviceScope().extra_info, extra_info_1_2_3
                    )
                self.assertEqual(
                    scope.CurrentDeviceScope().extra_info, extra_info_1_2
                )
            self.assertEqual(scope.CurrentDeviceScope().extra_info, extra_info1)
        self.assertIsNone(scope.CurrentDeviceScope())

    def testMultiThreaded(self):
        """
        Test that name/device scope are properly local to the thread
        and don't interfere
        """
        global SUCCESS_COUNT
        num_threads = 4
        # Reset the shared counter so that the final check counts only the
        # workers started here. Otherwise a repeated run in the same process
        # would see a stale total.
        SUCCESS_COUNT = 0
        self.assertEqual(scope.CurrentNameScope(), "")
        self.assertIsNone(scope.CurrentDeviceScope())
        threads = [
            threading.Thread(target=thread_runner, args=(i, self))
            for i in range(num_threads)
        ]
        for t in threads:
            t.start()
        with scope.NameScope("master"):
            # The workers' scopes must not leak into this (main) thread.
            self.assertIsNone(scope.CurrentDeviceScope())
            self.assertEqual(scope.CurrentNameScope(), "master/")
            for t in threads:
                t.join()
            self.assertEqual(scope.CurrentNameScope(), "master/")
            self.assertIsNone(scope.CurrentDeviceScope())
        # Assertion failures inside worker threads do not fail this test on
        # their own, so require that every worker reported success.
        self.assertEqual(SUCCESS_COUNT, num_threads)
|