1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
|
# Owner(s): ["module: inductor"]
import os
import sys
import unittest
import torch
import torch._dynamo
import torch.utils.cpp_extension
from torch._C import FileCheck
try:
from extension_backends.cpp.extension_codegen_backend import ( # @manual=fbcode//caffe2/test/inductor/extension_backends:extension_codegen_backend # noqa: B950
ExtensionCppWrapperCodegen,
ExtensionScheduling,
ExtensionWrapperCodegen,
)
except ImportError:
from .extension_backends.cpp.extension_codegen_backend import (
ExtensionCppWrapperCodegen,
ExtensionScheduling,
ExtensionWrapperCodegen,
)
from filelock import FileLock, Timeout
import torch._inductor.config as config
from torch._inductor import cpu_vec_isa, metrics
from torch._inductor.codegen import cpp_utils
from torch._inductor.codegen.common import (
get_scheduling_for_device,
get_wrapper_codegen_for_device,
register_backend_for_device,
)
from torch.testing._internal.common_utils import IS_FBCODE, IS_MACOS, xfailIfS390X
try:
try:
from . import test_torchinductor
except ImportError:
import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library
except unittest.SkipTest:
if __name__ == "__main__":
sys.exit(0)
raise
# Short local aliases for helpers defined in test_torchinductor, so the tests
# below can use them without the module prefix.
run_and_get_cpp_code = test_torchinductor.run_and_get_cpp_code
TestCase = test_torchinductor.TestCase
@xfailIfS390X
class BaseExtensionBackendTests(TestCase):
    """Shared fixture that JIT-builds the ``extension_device`` C++ extension.

    ``setUpClass`` compiles ``extension_backends/cpp/extension_device.cpp``
    with ``torch.utils.cpp_extension.load`` and stores the loaded module on
    ``cls.module``. A file lock serializes the build so that only one test
    builds the extension at a time. Each test runs with the working
    directory set to this file's directory and has it restored afterwards.
    """

    # Loaded extension module; populated by setUpClass.
    module = None

    # Use a lock file so that only one test can build this extension at a time
    lock_file = "extension_device.lock"
    lock = FileLock(lock_file)

    @classmethod
    def _release_build_lock(cls):
        """Delete the lock file (if it exists) and release the build lock.

        The release happens in a ``finally`` so that a failure to delete the
        file cannot leave the lock held.
        """
        try:
            if os.path.exists(cls.lock_file):
                os.remove(cls.lock_file)
        finally:
            cls.lock.release()

    @classmethod
    def setUpClass(cls):
        """Acquire the build lock and build/load the extension module.

        Raises:
            Whatever the build raises; the build lock is released first.
        """
        super().setUpClass()
        try:
            cls.lock.acquire(timeout=600)
        except Timeout:
            # This shouldn't happen, still attempt to build the extension anyway
            pass

        try:
            # Build Extension
            torch.testing._internal.common_utils.remove_cpp_extensions_build_root()
            source_file_path = os.path.dirname(os.path.abspath(__file__))
            source_file = os.path.join(
                source_file_path, "extension_backends/cpp/extension_device.cpp"
            )
            cls.module = torch.utils.cpp_extension.load(
                name="extension_device",
                sources=[
                    str(source_file),
                ],
                extra_cflags=["-g"],
                verbose=True,
            )
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # the lock has to be given back here or it would stay held.
            cls._release_build_lock()
            raise

    @classmethod
    def tearDownClass(cls):
        """Tear down class state, clear the build root and release the lock."""
        try:
            # NOTE(review): `_stack` is not defined in this file; presumably it
            # is provided by the TestCase base class - confirm.
            cls._stack.close()
            super().tearDownClass()
            torch.testing._internal.common_utils.remove_cpp_extensions_build_root()
        finally:
            # Always give the lock back, even if an earlier step failed.
            cls._release_build_lock()

    def setUp(self):
        """Reset dynamo and switch the working directory to this file's dir."""
        torch._dynamo.reset()
        super().setUp()
        # Check before touching the working directory: tearDown is not run
        # when setUp fails, so a failure after chdir would never be undone.
        assert self.module is not None

        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
        """Reset dynamo and restore the working directory saved by setUp."""
        try:
            super().tearDown()
            torch._dynamo.reset()
        finally:
            # return the working directory (see setUp), even if the steps
            # above raised
            os.chdir(self.old_working_dir)
@unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now")
class ExtensionBackendTests(BaseExtensionBackendTests):
    """Tests registering an Inductor backend for the extension device."""

    def test_open_device_registration(self):
        """Register the extension backend, then compile and run on the device.

        Renames the PrivateUse1 backend to ``extension_device``, registers the
        extension's scheduling and wrapper-codegen classes for it, and then
        compiles ``a * b + c`` both with and without the cpp wrapper. For each
        mode the generated code is checked with FileCheck and the result is
        compared against a CPU reference.
        """
        device_name = "extension_device"

        torch.utils.rename_privateuse1_backend(device_name)
        torch._register_device_module(device_name, self.module)
        register_backend_for_device(
            device_name,
            ExtensionScheduling,
            ExtensionWrapperCodegen,
            ExtensionCppWrapperCodegen,
        )

        # The lookups must hand back exactly the classes registered above.
        self.assertTrue(get_scheduling_for_device(device_name) == ExtensionScheduling)
        self.assertTrue(
            get_wrapper_codegen_for_device(device_name) == ExtensionWrapperCodegen
        )
        self.assertTrue(
            get_wrapper_codegen_for_device(device_name, True)
            == ExtensionCppWrapperCodegen
        )

        # The extension's custom op flag is unset before the first tensor is
        # created on the device and set afterwards.
        self.assertFalse(self.module.custom_op_called())
        device = self.module.custom_device()
        x = torch.empty(2, 16).to(device=device).fill_(1)
        self.assertTrue(self.module.custom_op_called())
        y = torch.empty(2, 16).to(device=device).fill_(2)
        z = torch.empty(2, 16).to(device=device).fill_(3)
        expected = torch.empty(2, 16).fill_(5)  # 1 * 2 + 3

        for tensor in (x, y, z):
            self.assertTrue(tensor.device == device)

        def fn(a, b, c):
            return a * b + c

        cpp_utils.DEVICE_TO_ATEN[device_name] = "at::kPrivateUse1"

        for use_cpp_wrapper in (True, False):
            with config.patch({"cpp_wrapper": use_cpp_wrapper}):
                metrics.reset()
                compiled = torch.compile()(fn)
                _, generated = run_and_get_cpp_code(compiled, x, y, z)

                # The expected load expression depends on whether a vector
                # ISA is available and not disabled via ATEN_CPU_CAPABILITY.
                vectorized = (
                    cpu_vec_isa.valid_vec_isa_list()
                    and os.getenv("ATEN_CPU_CAPABILITY") != "default"
                )
                expected_load = (
                    "loadu" if vectorized else " = in_ptr0[static_cast<long>(i0)];"
                )

                checker = (
                    FileCheck().check("void").check(expected_load).check(device_name)
                )
                checker.run(generated)

                compiled(x, y, z)
                result = compiled(x, y, z)
                self.assertEqual(expected, result.to(device="cpu"))
if __name__ == "__main__":
    from torch._inductor.test_case import run_tests
    from torch.testing._internal.inductor_utils import HAS_CPU

    # Only run where a CPU backend is available and the platform is neither
    # macOS nor fbcode (cpp_extension doesn't work in fbcode right now).
    if HAS_CPU and not (IS_MACOS or IS_FBCODE):
        run_tests(needs="filelock")
|