1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450
|
# Owner(s): ["module: inductor"]
import copy
import functools
import io
import sys
import tempfile
import unittest
from typing import Callable
from parameterized import parameterized_class
import torch
from torch._inductor.package import AOTICompiledModel, load_package, package_aoti
from torch._inductor.test_case import TestCase
from torch._inductor.utils import fresh_inductor_cache
from torch.export import Dim
from torch.testing._internal.common_utils import IS_FBCODE, TEST_CUDA
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
def skipif(predicate: Callable[[str, bool], bool], reason: str):
    """Build a decorator that conditionally skips a parameterized test method.

    The check happens when the test runs, not when it is defined, because
    ``device`` and ``package_cpp_only`` are only known on the instance of a
    generated (parameterized) test class.

    Args:
        predicate: called as ``predicate(self.device, self.package_cpp_only)``;
            a truthy result means the test should be skipped.
        reason: message passed to ``self.skipTest`` when skipping.
    """

    def _apply(test_fn):
        @functools.wraps(test_fn)
        def _guarded(self, *args, **kwargs):
            should_skip = predicate(self.device, self.package_cpp_only)
            if should_skip:
                # skipTest normally raises, so the test body below is not run.
                self.skipTest(reason)
            return test_fn(self, *args, **kwargs)

        return _guarded

    return _apply
def compile(
    model,
    args,
    kwargs=None,
    *,
    dynamic_shapes=None,
    package_path=None,
    inductor_configs=None,
) -> AOTICompiledModel:
    """Export ``model``, AOT-compile it into a package, then load that package.

    Note: this intentionally shadows the builtin ``compile`` inside this test
    module; callers in this file rely on the name.

    Args:
        model: the module to export.
        args: positional example inputs used for export.
        kwargs: optional keyword example inputs used for export.
        dynamic_shapes: forwarded to ``torch.export.export``.
        package_path: where ``aoti_compile_and_package`` writes the package.
        inductor_configs: inductor config overrides for compilation.

    Returns:
        The model loaded back from the produced package.
    """
    exported_program = torch.export.export(
        model, args, kwargs, dynamic_shapes=dynamic_shapes, strict=False
    )
    packaged_at = torch._inductor.aoti_compile_and_package(
        exported_program,
        package_path=package_path,
        inductor_configs=inductor_configs,
    )  # type: ignore[arg-type]
    return load_package(packaged_at)
@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
@parameterized_class(
    [
        {"device": "cpu", "package_cpp_only": False},
    ]
    + (
        [
            # FIXME: disabled in fbcode, where this variant fails with
            # "AssertionError: AOTInductor compiled library does not exist at ..."
            {"device": "cpu", "package_cpp_only": True}
        ]
        if not IS_FBCODE
        else []
    )
    + (
        [
            {"device": GPU_TYPE, "package_cpp_only": False},
            {"device": GPU_TYPE, "package_cpp_only": True},
        ]
        if sys.platform != "darwin"
        else []
    ),
    class_name_func=lambda cls, _, params: (
        f"{cls.__name__}"
        f"{'Cpp' if params['package_cpp_only'] else ''}"
        f"_{params['device']}"
    ),
)
class TestAOTInductorPackage(TestCase):
    """AOTInductor packaging tests, parameterized over device and packaging mode.

    ``parameterized_class`` generates one subclass per parameter dict and sets
    ``self.device`` and ``self.package_cpp_only`` on it. Most tests forward
    ``package_cpp_only`` into the ``aot_inductor.package_cpp_only`` config.
    """

    def check_model(
        self: TestCase,
        model,
        example_inputs,
        inductor_configs=None,
        dynamic_shapes=None,
        disable_constraint_solver=False,
        atol=None,
        rtol=None,
    ) -> AOTICompiledModel:
        """Compile ``model`` into an AOTI package and compare it with eager mode.

        The eager reference runs on deep copies of the model and inputs. The
        compiled package is written to a temporary ``.pt2`` file, loaded back,
        and its output is compared against the reference.

        Args:
            model: module to compile; it is moved to ``self.device`` first.
            example_inputs: positional inputs used for export and for the run.
            inductor_configs: optional inductor config overrides. The caller's
                dict is left untouched. ``aot_inductor.package_cpp_only`` is
                set from ``self.package_cpp_only`` on a private copy.
            dynamic_shapes: forwarded to ``torch.export.export``.
            disable_constraint_solver: currently unused; kept so existing
                callers keep working.
            atol: absolute tolerance forwarded to ``assertEqual``.
            rtol: relative tolerance forwarded to ``assertEqual``.

        Returns:
            The loaded compiled model, so callers can make further checks.
        """
        with torch.no_grad():
            torch.manual_seed(0)
            model = model.to(self.device)
            ref_model = copy.deepcopy(model)
            ref_inputs = copy.deepcopy(example_inputs)
            expected = ref_model(*ref_inputs)

            # Copy so a dict supplied by the caller is never mutated here.
            inductor_configs = dict(inductor_configs or {})
            inductor_configs["aot_inductor.package_cpp_only"] = self.package_cpp_only

            torch.manual_seed(0)
            with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
                compiled_model = compile(
                    model,
                    example_inputs,
                    dynamic_shapes=dynamic_shapes,
                    inductor_configs=inductor_configs,
                    package_path=f.name,
                )
                actual = compiled_model(*example_inputs)

        self.assertEqual(actual, expected, atol=atol, rtol=rtol)
        return compiled_model

    def test_add(self):
        """Elementwise add of two inputs round-trips through a package."""

        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    def test_remove_intermediate_files(self):
        """A package still loads and runs after intermediate files are gone.

        For CUDA, the generated cpp files contain absolute paths to the
        generated cubin files. With the package artifact, those cubin paths
        should be overridden at run time. This test compiles under
        ``fresh_inductor_cache`` so the intermediate files are removed
        before loading, which verifies that behavior.
        """

        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        model = Model()
        with torch.no_grad():
            torch.manual_seed(0)
            model = model.to(self.device)
            ref_model = copy.deepcopy(model)
            ref_inputs = copy.deepcopy(example_inputs)
            expected = ref_model(*ref_inputs)

            torch.manual_seed(0)
            with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
                ep = torch.export.export(
                    model,
                    example_inputs,
                )
                with fresh_inductor_cache():
                    # cubin files are removed when exiting this context
                    package_path = torch._inductor.aoti_compile_and_package(
                        ep,
                        package_path=f.name,
                    )  # type: ignore[arg-type]
                loaded = torch._inductor.aoti_load_package(package_path)
                actual = loaded(*example_inputs)

            self.assertEqual(actual, expected)

    def test_linear(self):
        """A module with parameters (nn.Linear) round-trips through a package."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x, y):
                return x + self.linear(y)

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    def test_metadata(self):
        """Metadata passed through ``aot_inductor.metadata`` is readable after loading."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x, y):
                return x + self.linear(y)

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        metadata = {"dummy": "moo"}
        compiled_model = self.check_model(
            Model(),
            example_inputs,
            inductor_configs={"aot_inductor.metadata": metadata},
        )

        loaded_metadata = compiled_model.get_metadata()  # type: ignore[attr-defined]

        self.assertEqual(loaded_metadata.get("dummy"), "moo")

    def test_bool_input(self):
        """A Python bool argument is accepted; export specializes on its value."""

        # Specialize on whichever branch the example input for b is
        class Model(torch.nn.Module):
            def forward(self, x, b):
                if b:
                    return x * x
                else:
                    return x + x

        example_inputs = (torch.randn(3, 3, device=self.device), True)
        self.check_model(Model(), example_inputs)

    def test_multiple_methods(self):
        """Two separately compiled models can live in one package under distinct names."""
        options = {
            "aot_inductor.package": True,
            "aot_inductor.package_cpp_only": self.package_cpp_only,
        }

        class Model1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.cat([a, b], dim=0)

        dim0_a = Dim("dim0_a", min=1, max=10)
        dim0_b = Dim("dim0_b", min=1, max=20)
        dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_b}}
        example_inputs1 = (
            torch.randn(2, 4, device=self.device),
            torch.randn(3, 4, device=self.device),
        )
        ep1 = torch.export.export(
            Model1(), example_inputs1, dynamic_shapes=dynamic_shapes
        )
        aoti_files1 = torch._inductor.aot_compile(
            ep1.module(), example_inputs1, options=options
        )

        class Model2(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.device = device

            def forward(self, x):
                t = torch.tensor(x.size(-1), device=self.device, dtype=torch.float)
                t = torch.sqrt(t * 3)
                return x * t

        example_inputs2 = (torch.randn(5, 5, device=self.device),)
        ep2 = torch.export.export(Model2(self.device), example_inputs2)
        aoti_files2 = torch._inductor.aot_compile(
            ep2.module(), example_inputs2, options=options
        )

        with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
            package_path = package_aoti(
                f.name, {"model1": aoti_files1, "model2": aoti_files2}
            )
            loaded1 = load_package(package_path, "model1")
            loaded2 = load_package(package_path, "model2")

            self.assertEqual(loaded1(*example_inputs1), ep1.module()(*example_inputs1))
            self.assertEqual(loaded2(*example_inputs2), ep2.module()(*example_inputs2))

    @unittest.skipIf(not TEST_CUDA, "requires cuda")
    def test_duplicate_calls(self):
        """A CUDA build and a CPU build of the same model coexist in one package."""
        options = {
            "aot_inductor.package": True,
        }

        # NOTE: the device is fixed here (CUDA first, then CPU) regardless of
        # self.device, so every parameterization runs the same scenario.
        device = "cuda"

        class Model1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.cat([a, b], dim=0)

        dim0_a = Dim("dim0_a", min=1, max=10)
        dim0_b = Dim("dim0_b", min=1, max=20)
        dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_b}}
        example_inputs1 = (
            torch.randn(2, 4, device=device),
            torch.randn(3, 4, device=device),
        )
        self.check_model(Model1(), example_inputs1)
        ep1 = torch.export.export(
            Model1(), example_inputs1, dynamic_shapes=dynamic_shapes
        )
        aoti_files1 = torch._inductor.aot_compile(
            ep1.module(), example_inputs1, options=options
        )

        device = "cpu"
        example_inputs2 = (
            torch.randn(2, 4, device=device),
            torch.randn(3, 4, device=device),
        )
        ep2 = torch.export.export(
            Model1(), example_inputs2, dynamic_shapes=dynamic_shapes
        )
        aoti_files2 = torch._inductor.aot_compile(
            ep2.module(), example_inputs2, options=options
        )

        with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
            package_path = package_aoti(
                f.name, {"model1": aoti_files1, "model2": aoti_files2}
            )
            loaded1 = load_package(package_path, "model1")
            loaded2 = load_package(package_path, "model2")

            self.assertTrue(
                torch.allclose(loaded1(*example_inputs1), ep1.module()(*example_inputs1))
            )
            self.assertTrue(
                torch.allclose(loaded2(*example_inputs2), ep2.module()(*example_inputs2))
            )

    def test_specified_output_dir(self):
        """Packaging works when ``aot_inductor.output_path`` is set explicitly."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.cat([a, b], dim=0)

        example_inputs = (
            torch.randn(2, 4, device=self.device),
            torch.randn(3, 4, device=self.device),
        )
        ep = torch.export.export(Model(), example_inputs)
        aoti_files = torch._inductor.aot_compile(
            ep.module(),
            example_inputs,
            options={
                # NOTE(review): relative path; presumably resolved against the
                # current working directory, so artifacts may be left there.
                "aot_inductor.output_path": "tmp_output_",
                "aot_inductor.package": True,
                "aot_inductor.package_cpp_only": self.package_cpp_only,
            },
        )
        with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
            package_path = package_aoti(f.name, {"model1": aoti_files})
            loaded = load_package(package_path, "model1")
            self.assertTrue(
                torch.allclose(loaded(*example_inputs), ep.module()(*example_inputs))
            )

    def test_save_buffer(self):
        """A package written to an in-memory buffer can be loaded repeatedly."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.cat([a, b], dim=0)

        example_inputs = (
            torch.randn(2, 4, device=self.device),
            torch.randn(3, 4, device=self.device),
        )
        ep = torch.export.export(Model(), example_inputs)

        buffer = io.BytesIO()
        buffer = torch._inductor.aoti_compile_and_package(
            ep, package_path=buffer
        )  # type: ignore[arg-type]
        # Load twice to check that the same buffer can be consumed more than once.
        for _ in range(2):
            loaded = load_package(buffer)
            self.assertTrue(
                torch.allclose(loaded(*example_inputs), ep.module()(*example_inputs))
            )

    @skipif(
        lambda device, package_cpp_only: device == "cpu" or package_cpp_only,
        "No support for cpp only and cpu",
    )
    def test_package_without_weight(self):
        """Weights are left out of the .so and supplied afterwards via load_constants."""

        class Model(torch.nn.Module):
            def __init__(self, n, k, device):
                super().__init__()
                self.linear = torch.nn.Linear(k, n, device=device)

            def forward(self, a):
                return self.linear(a)

        M, N, K = 128, 2048, 4096
        model = Model(N, K, self.device)
        example_inputs = (torch.randn(M, K, device=self.device),)

        inductor_configs = {
            "always_keep_tensor_constants": True,
            "aot_inductor.package_constants_in_so": False,
        }
        compiled = compile(model, example_inputs, inductor_configs=inductor_configs)

        self.assertEqual(
            set(compiled.get_constant_fqns()), set(model.state_dict().keys())
        )

        compiled.load_constants(model.state_dict(), check_full_update=True)

        test_inputs = torch.randn(M, K, device=self.device)
        expected = model(test_inputs)
        output = compiled(test_inputs)
        self.assertEqual(expected, output)

    @skipif(
        lambda device, package_cpp_only: device == "cpu" or package_cpp_only,
        "No support for cpp only and cpu",
    )
    def test_update_weights(self):
        """Replacing weights with load_constants changes the compiled model's output."""

        class Model(torch.nn.Module):
            def __init__(self, n, k, device):
                super().__init__()
                self.linear = torch.nn.Linear(k, n, device=device)

            def forward(self, a):
                return self.linear(a)

        M, N, K = 128, 2048, 4096
        model = Model(N, K, self.device)
        example_inputs = (torch.randn(M, K, device=self.device),)

        compiled = self.check_model(model, example_inputs)

        new_state_dict = {
            "linear.weight": torch.randn(N, K, device=self.device),
            "linear.bias": torch.randn(N, device=self.device),
        }
        model.load_state_dict(new_state_dict)

        compiled.load_constants(model.state_dict(), check_full_update=True)

        test_inputs = torch.randn(M, K, device=self.device)
        expected = model(test_inputs)
        output = compiled(test_inputs)
        self.assertEqual(expected, output)
if __name__ == "__main__":
    from torch._inductor.test_case import run_tests as _run_tests

    # cpp_extension is not available in fbcode. The suite is only launched
    # when a GPU is present or when running on macOS.
    _should_run = sys.platform == "darwin" or HAS_GPU
    if _should_run:
        _run_tests(needs="filelock")
|