1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
|
# Owner(s): ["oncall: profiler"]
# if tqdm is not shutdown properly, it will leave the monitor thread alive.
# This causes an issue in the multithreading test because we check all events
# in that test with their tids. The events that correspond to these lingering
# threads all have TID of (uint64_t)(-1) which is invalid.
# The workaround is turning off the monitoring thread when tqdm is loaded.
# Since these are unit tests, it is safe to turn off monitor thread.
# tqdm is optional. When it is installed, disable its monitor thread so that
# no lingering thread shows up in the profiler events; when it is missing
# there is nothing to configure.
try:
    import tqdm
except ImportError:
    pass
else:
    # Only the import itself is guarded; configuring tqdm happens here so an
    # unrelated failure is never mistaken for "tqdm is not installed".
    tqdm.tqdm.monitor_interval = 0
from typing import Any, Dict
import torch
import torch.optim
import torch.utils.data
import torch.utils.data.datapipes as dp
from torch.autograd import (
_record_function_with_args_enter,
_record_function_with_args_exit,
)
from torch.autograd.profiler import profile as _profile
from torch.profiler import kineto_available, record_function
from torch.testing._internal.common_utils import run_tests, TestCase
Json = Dict[str, Any]
class TestRecordFunction(TestCase):
def _record_function_with_param(self):
u = torch.randn(3, 4, 5, requires_grad=True)
with _profile(
with_stack=True, use_kineto=kineto_available(), record_shapes=True
) as prof:
with record_function("## TEST 1 ##", "1, 2, 3"):
rf_handle = _record_function_with_args_enter(
"## TEST 2 ##", 1, False, 2.5, [u, u], "hello", u
)
_record_function_with_args_exit(rf_handle)
with record_function("## TEST 3 ##"):
rf_handle = _record_function_with_args_enter("## TEST 4 ##")
_record_function_with_args_exit(rf_handle)
return prof
def test_record_function(self):
prof_result = self._record_function_with_param()
found_test_1 = False
found_test_2 = False
found_test_3 = False
found_test_4 = False
for e in prof_result.function_events:
if "## TEST 1 ##" == e.name:
found_test_1 = True
self.assertTrue(e.input_shapes == [[]])
elif "## TEST 2 ##" == e.name:
found_test_2 = True
self.assertTrue(e.input_shapes == [[], [], [], [], [], [3, 4, 5]])
elif "## TEST 3 ##" == e.name:
found_test_3 = True
self.assertTrue(e.input_shapes == [])
elif "## TEST 4 ##" == e.name:
found_test_4 = True
self.assertTrue(e.input_shapes == [])
self.assertTrue(found_test_1)
self.assertTrue(found_test_2)
self.assertTrue(found_test_3)
self.assertTrue(found_test_4)
def test_datapipe_with_record_function(self):
with _profile(
with_stack=True, use_kineto=kineto_available(), record_shapes=True
) as prof:
input_dp1 = dp.iter.IterableWrapper(range(4))
input_dp2 = dp.iter.IterableWrapper(range(4, 8))
input_dp3 = dp.iter.IterableWrapper(range(8, 12))
output_dp = input_dp1.mux(input_dp2, input_dp3)
output = list(output_dp)
has_iter = False
has_mux = False
for e in prof.function_events:
if has_iter and has_mux:
break
if not has_iter and "IterableWrapper" in e.name:
has_iter = True
if not has_mux and "Multiplexer" in e.name:
has_mux = True
self.assertTrue(has_iter)
self.assertTrue(has_mux)
def test_datapipe_delegation_with_profiler(self):
class IDPIterator(torch.utils.data.IterDataPipe):
def __init__(self) -> None:
self.data = list(range(10))
self._idx = 0
def __iter__(self):
return self
def __next__(self):
if self._idx >= 10:
self._idx = 0
raise StopIteration
self._idx += 1
return self.data[self._idx - 1]
def get_value(self, idx):
return self.data[idx]
dp1 = IDPIterator() # The object itself is an iterator
self.assertEqual(5, dp1.get_value(5))
it_dp1 = iter(dp1) # This creates the 1st iterator
self.assertEqual(5, it_dp1.get_value(5)) # type: ignore[attr-defined]
self.assertEqual(list(range(10)), list(it_dp1))
class IDPDelegator(torch.utils.data.IterDataPipe):
def __init__(self, datapipe):
self.datapipe = datapipe
def __iter__(self):
return iter(self.datapipe)
dp2 = IDPDelegator(dp1)
it_dp2 = iter(dp2)
self.assertEqual(5, it_dp2.get_value(5))
self.assertEqual(list(range(10)), list(it_dp2))
def test_datapipe_with_record_function_fork(self):
with _profile(
with_stack=True, use_kineto=kineto_available(), record_shapes=True
) as prof:
input_dp = dp.iter.IterableWrapper(range(10))
dp1, dp2, dp3 = input_dp.fork(num_instances=3)
output1 = list(dp1)
has_iter = False
has_child = False
for e in prof.function_events:
if has_iter and has_child:
break
if not has_iter and "IterableWrapper" in e.name:
has_iter = True
if not has_child and "_ChildDataPipe" in e.name:
has_child = True
self.assertTrue(has_iter)
self.assertTrue(has_child)
# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
|