# Owner(s): ["module: mtia"]

import os
import tempfile
import unittest
import torch
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.testing._internal.common_utils import (
IS_ARM64,
IS_LINUX,
skipIfTorchDynamo,
TEST_CUDA,
TEST_PRIVATEUSE1,
TEST_XPU,
)
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME

# define TEST_ROCM before changing TEST_CUDA
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
TEST_CUDA = TEST_CUDA and CUDA_HOME is not None


@unittest.skipIf(
IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_PRIVATEUSE1 or TEST_ROCM or TEST_XPU,
"Only on linux platform and mutual exclusive to other backends",
)
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionMTIABackend(common.TestCase):
"""Tests MTIA backend with C++ extensions."""
module = None

    def setUp(self):
super().setUp()
# cpp extensions use relative paths. Those paths are relative to
# this file, so we'll change the working directory temporarily
self.old_working_dir = os.getcwd()
os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
super().tearDown()
        # restore the original working directory (see setUp)
os.chdir(self.old_working_dir)

    @classmethod
def tearDownClass(cls):
torch.testing._internal.common_utils.remove_cpp_extensions_build_root()

    @classmethod
def setUpClass(cls):
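        # Start from a clean cpp-extension build root so stale builds from earlier runs do not interfere.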
torch.testing._internal.common_utils.remove_cpp_extensions_build_root()
build_dir = tempfile.mkdtemp()
# Load the fake device guard impl.
cls.module = torch.utils.cpp_extension.load(
name="mtia_extension",
sources=["cpp_extensions/mtia_extension.cpp"],
build_directory=build_dir,
extra_include_paths=[
"cpp_extensions",
"path / with spaces in it",
"path with quote'",
],
is_python_module=False,
verbose=True,
)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
def test_get_device_module(self):
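        # torch.get_device_module("mtia:0") should resolve to torch.mtia; its current stream reports the MTIA device type.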
device = torch.device("mtia:0")
default_stream = torch.get_device_module(device).current_stream()
self.assertEqual(
default_stream.device_type, int(torch._C._autograd.DeviceType.MTIA)
)
print(torch._C.Stream.__mro__)
print(torch.cuda.Stream.__mro__)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
def test_stream_basic(self):
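        # A newly created stream must differ from the default stream, become current inside torch.mtia.stream(), and report completion via query()/synchronize().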
default_stream = torch.mtia.current_stream()
user_stream = torch.mtia.Stream()
self.assertEqual(torch.mtia.current_stream(), default_stream)
self.assertNotEqual(default_stream, user_stream)
        # See mtia_extension.cpp: the default stream id starts at 0.
self.assertEqual(default_stream.stream_id, 0)
self.assertNotEqual(user_stream.stream_id, 0)
with torch.mtia.stream(user_stream):
self.assertEqual(torch.mtia.current_stream(), user_stream)
self.assertTrue(user_stream.query())
default_stream.synchronize()
self.assertTrue(default_stream.query())

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
def test_stream_context(self):
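        # Each torch.mtia.stream() context should make its stream the current stream.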
mtia_stream_0 = torch.mtia.Stream(device="mtia:0")
mtia_stream_1 = torch.mtia.Stream(device="mtia:0")
print(mtia_stream_0)
print(mtia_stream_1)
with torch.mtia.stream(mtia_stream_0):
current_stream = torch.mtia.current_stream()
msg = f"current_stream {current_stream} should be {mtia_stream_0}"
self.assertTrue(current_stream == mtia_stream_0, msg=msg)
with torch.mtia.stream(mtia_stream_1):
current_stream = torch.mtia.current_stream()
msg = f"current_stream {current_stream} should be {mtia_stream_1}"
self.assertTrue(current_stream == mtia_stream_1, msg=msg)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
def test_stream_context_different_device(self):
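        # A stream context on another device should switch the current device for its duration and restore it afterwards.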
device_0 = torch.device("mtia:0")
device_1 = torch.device("mtia:1")
mtia_stream_0 = torch.mtia.Stream(device=device_0)
mtia_stream_1 = torch.mtia.Stream(device=device_1)
print(mtia_stream_0)
print(mtia_stream_1)
orig_current_device = torch.mtia.current_device()
with torch.mtia.stream(mtia_stream_0):
current_stream = torch.mtia.current_stream()
self.assertTrue(torch.mtia.current_device() == device_0.index)
msg = f"current_stream {current_stream} should be {mtia_stream_0}"
self.assertTrue(current_stream == mtia_stream_0, msg=msg)
self.assertTrue(torch.mtia.current_device() == orig_current_device)
with torch.mtia.stream(mtia_stream_1):
current_stream = torch.mtia.current_stream()
self.assertTrue(torch.mtia.current_device() == device_1.index)
msg = f"current_stream {current_stream} should be {mtia_stream_1}"
self.assertTrue(current_stream == mtia_stream_1, msg=msg)
self.assertTrue(torch.mtia.current_device() == orig_current_device)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
def test_device_context(self):
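        # torch.mtia.device() should set the current device index inside the context.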
device_0 = torch.device("mtia:0")
device_1 = torch.device("mtia:1")
with torch.mtia.device(device_0):
self.assertTrue(torch.mtia.current_device() == device_0.index)
with torch.mtia.device(device_1):
self.assertTrue(torch.mtia.current_device() == device_1.index)


if __name__ == "__main__":
common.run_tests()