# Owner(s): ["module: cuda"]
import sys
import unittest
import unittest.mock
import torch
import torch.utils._cuda_trace as cuda_trace
from torch.testing._internal.common_utils import TestCase, run_tests
# NOTE: Each test needs to be run in a brand new process, to reset the registered hooks
# and make sure the CUDA streams are initialized for each test that uses them.
# We cannot import TEST_CUDA from torch.testing._internal.common_cuda here,
# because if we do that, the TEST_CUDNN line from torch.testing._internal.common_cuda will be executed
# multiple times as well during the execution of this test suite, and it will
# cause CUDA OOM error on Windows.
# Probe the runtime directly rather than importing TEST_CUDA from common_cuda
# (the comment above explains why that import is avoided).
TEST_CUDA = torch.cuda.is_available()

if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    # Rebind TestCase to ``object`` so that the class deriving from it below is
    # a plain class instead of a unittest test case; this is how the tests are
    # skipped on machines without CUDA.
    TestCase = object  # noqa: F811
class TestCudaTrace(TestCase):
    """Checks that callbacks registered through ``cuda_trace`` are invoked.

    Every test has the same shape: register ``self.mock`` (a ``MagicMock``)
    for one kind of CUDA activity, perform that activity, then assert on how
    the mock was called.

    Per-test notes are ``#`` comments rather than docstrings so that
    unittest's verbose output keeps printing the test names.
    """

    def setUp(self):
        # Activate tracing, then give each test a fresh mock so that call
        # counts never carry over from another test.
        torch._C._activate_cuda_trace()
        self.mock = unittest.mock.MagicMock()

    @staticmethod
    def _recorded_event():
        # Shared setup for the event tests: create an event and record it
        # (no stream argument is passed). The torch.cuda.Event documentation
        # states that the underlying CUDA event is lazily initialized on the
        # first record, which is why the event is recorded before any test
        # reads its handle.
        event = torch.cuda.Event()
        event.record()
        return event

    def test_event_creation_callback(self):
        # The creation callback must receive the event's raw handle.
        cuda_trace.register_callback_for_cuda_event_creation(self.mock)

        event = self._recorded_event()

        self.mock.assert_called_once_with(event._as_parameter_.value)

    def test_event_deletion_callback(self):
        # The handle is saved first because the event object is gone by the
        # time the assertion runs.
        cuda_trace.register_callback_for_cuda_event_deletion(self.mock)

        event = self._recorded_event()
        event_id = event._as_parameter_.value
        del event

        self.mock.assert_called_once_with(event_id)

    def test_event_record_callback(self):
        # The record callback must receive (event handle, stream handle); the
        # stream expected here is the default stream.
        cuda_trace.register_callback_for_cuda_event_record(self.mock)

        event = self._recorded_event()

        expected_event = event._as_parameter_.value
        expected_stream = torch.cuda.default_stream().cuda_stream
        self.mock.assert_called_once_with(expected_event, expected_stream)

    def test_event_wait_callback(self):
        # The wait callback must receive (event handle, stream handle); the
        # stream expected here is the default stream.
        cuda_trace.register_callback_for_cuda_event_wait(self.mock)

        event = self._recorded_event()
        event.wait()

        expected_event = event._as_parameter_.value
        expected_stream = torch.cuda.default_stream().cuda_stream
        self.mock.assert_called_once_with(expected_event, expected_stream)

    def test_memory_allocation_callback(self):
        # The allocation callback must receive the new tensor's data pointer.
        cuda_trace.register_callback_for_cuda_memory_allocation(self.mock)

        tensor = torch.empty(10, 4, device="cuda")

        self.mock.assert_called_once_with(tensor.data_ptr())

    def test_memory_deallocation_callback(self):
        # The pointer is saved before ``del`` because the tensor cannot be
        # queried afterwards.
        cuda_trace.register_callback_for_cuda_memory_deallocation(self.mock)

        tensor = torch.empty(3, 8, device="cuda")
        data_ptr = tensor.data_ptr()
        del tensor

        self.mock.assert_called_once_with(data_ptr)

    def test_stream_creation_callback(self):
        # Only checks that the callback fired; the call arguments are not
        # inspected. The stream is deliberately left unbound.
        cuda_trace.register_callback_for_cuda_stream_creation(self.mock)

        torch.cuda.Stream()

        self.mock.assert_called()

    def test_device_synchronization_callback(self):
        # Only checks that the callback fired; the call arguments are not
        # inspected.
        cuda_trace.register_callback_for_cuda_device_synchronization(self.mock)

        torch.cuda.synchronize()

        self.mock.assert_called()

    def test_stream_synchronization_callback(self):
        # Synchronizing a stream must report that stream's handle.
        cuda_trace.register_callback_for_cuda_stream_synchronization(self.mock)

        stream = torch.cuda.Stream()
        stream.synchronize()

        self.mock.assert_called_once_with(stream.cuda_stream)

    def test_event_synchronization_callback(self):
        # Synchronizing an event must report that event's handle.
        cuda_trace.register_callback_for_cuda_event_synchronization(self.mock)

        event = self._recorded_event()
        event.synchronize()

        self.mock.assert_called_once_with(event._as_parameter_.value)

    def test_memcpy_synchronization(self):
        # ``nonzero()`` is expected to trigger exactly one synchronization of
        # the default stream -- presumably because the size of its result has
        # to be read back on the host; confirm against the implementation.
        cuda_trace.register_callback_for_cuda_stream_synchronization(self.mock)

        tensor = torch.rand(5, device="cuda")
        tensor.nonzero()

        self.mock.assert_called_once_with(torch.cuda.default_stream().cuda_stream)

    def test_all_trace_callbacks_called(self):
        # Two callbacks registered for the same activity must both be invoked,
        # each exactly once and with the same pointer.
        other = unittest.mock.MagicMock()
        cuda_trace.register_callback_for_cuda_memory_allocation(self.mock)
        cuda_trace.register_callback_for_cuda_memory_allocation(other)

        tensor = torch.empty(10, 4, device="cuda")

        expected_ptr = tensor.data_ptr()
        self.mock.assert_called_once_with(expected_ptr)
        other.assert_called_once_with(expected_ptr)
# Run the suite only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()