# Owner(s): ["module: inductor"]
import sys
import unittest

from torch.testing._internal.common_utils import (
    IS_CI,
    IS_WINDOWS,
    skipIfRocm,
    skipIfXpu,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu

if IS_WINDOWS and IS_CI:
    sys.stderr.write(
        "Windows CI does not have necessary dependencies for test_memory_planning yet\n"
    )
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires sympy/functorch/filelock")  # noqa: F821

import torch
from torch._C import FileCheck
from torch._dynamo.utils import same
from torch._inductor import config
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_cpp_code
from torch.export import Dim
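

# With memory_planning=True, Inductor allocates one backing pool up front and
# carves the graph's intermediate buffers out of it (alloc_from_pool) at
# aligned offsets. Each test below compiles the same small graph through a
# different wrapper codegen path and FileChecks the generated code for that
# pooled-allocation pattern.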
@requires_gpu()
@config.patch(memory_planning=True)
class TestMemoryPlanning(TestCase):
    device = GPU_TYPE

    def _generate(self, *, device):
        """
        Generate a simple test case that has multiple simultaneously-live
        intermediate tensors.
        """

        class Foo(torch.nn.Module):
            def forward(self, x, y, z):
                t0 = x.matmul(y)
                t1 = x.matmul(z)
                t0 = x.transpose(0, 1).matmul(t1)
                t1 = x.matmul(t0)
                return t0.sum() + t1.sum()

        x = torch.randn((3, 2), device=device)
        y = torch.randn((2, 4), device=device)
        z = torch.randn((2, 3), device=device)
        return (Foo(), (x, y, z))
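
    # dynamic=True makes the input sizes symbolic (s0, s1), so the pool size
    # in the generated Python wrapper is a symbolic expression built from
    # align() rather than a constant.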
    def test_python_wrapper(self):
        f, args = self._generate(device=GPU_TYPE)
        compiled = torch.compile(f, dynamic=True)
        result, code = run_and_get_cpp_code(compiled, *args)

        FileCheck().check(
            "pool1 = empty_strided_" + GPU_TYPE + "((4*s0*s1 + align(4*s0*s0), ), (1, )"
        ).check_next(
            "buf0 = alloc_from_pool(pool1, 0, torch.float32, (s0, s0), (s0, 1))"
        ).check(
            "buf1 = alloc_from_pool(pool1, align(4*s0*s0),"
        ).run(
            code
        )
        self.assertTrue(same(f(*args), result))
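
    # Same graph through the C++ wrapper: allocations go through the AOTI C
    # shim, with aoti_torch__alloc_from_pool carving buffers out of pool1 and
    # RAIIAtenTensorHandle managing their lifetimes.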
    def test_cpp_wrapper(self):
        f, args = self._generate(device=GPU_TYPE)
        compiled = torch.compile(f, dynamic=True)

        with config.patch({"cpp_wrapper": True}):
            result, code = run_and_get_cpp_code(compiled, *args)

        FileCheck().check(
            "aoti_torch__alloc_from_pool(pool1, 0, cached_torch_dtype_float32, 2, int_array_4, int_array_5, &tmp_tensor_handle_1)"
        ).check_next(
            "auto buf0 = RAIIAtenTensorHandle(tmp_tensor_handle_1);"
        ).check(
            "auto buf1 = RAIIAtenTensorHandle(tmp_tensor_handle_2);"
        ).run(
            code
        )
        self.assertTrue(same(f(*args), result))
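
    # Full AOTInductor compilation of the same graph: the generated C++
    # should allocate pool1 with aoti_torch_empty_strided and then serve
    # individual buffers from it via aoti_torch__alloc_from_pool.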
    @skipIfRocm(msg="test_aot_inductor doesn't work on ROCm")
    @skipIfXpu(msg="aoti doesn't work on XPU")
    def test_aoti(self):
        try:
            from .test_aot_inductor import AOTIRunnerUtil
        except ImportError:
            from test_aot_inductor import (  # @manual=fbcode//caffe2/test/inductor:test_aot_inductor-library
                AOTIRunnerUtil,
            )

        f, args = self._generate(device=GPU_TYPE)
        dim0_x = Dim("dim0_x", min=1, max=2048)
        dynamic_shapes = ({0: dim0_x}, None, None)
        result, code = run_and_get_cpp_code(
            lambda: AOTIRunnerUtil.run(GPU_TYPE, f, args, dynamic_shapes=dynamic_shapes)
        )
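
        # int_array_2/int_array_3 are the pool's (symbolic) byte size and
        # stride; int_array_4/int_array_5 are the size and stride of the
        # first buffer carved out of the pool.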
        FileCheck().check(
            "int64_t int_array_2[] = {24L + align(12L*s0), };"
        ).check_next(
            "int64_t int_array_3[] = {1L, };"
        ).check_next(
            "AtenTensorHandle pool1_handle;"
        ).check_next(
            "aoti_torch_empty_strided(1, int_array_2, int_array_3,"
        ).check_next(
            "RAIIAtenTensorHandle pool1(pool1_handle);"
        ).check_next(
            "int64_t int_array_4[] = {s0, 3L};"
        ).check_next(
            "int64_t int_array_5[] = {3L, 1L};"
        ).check_next(
            "AtenTensorHandle tmp_tensor_handle_1;"
        ).check_next(
            "aoti_torch__alloc_from_pool(pool1, 0"
        ).run(
            code
        )
        self.assertTrue(same(f(*args), result))


if __name__ == "__main__":
    if HAS_GPU:
        run_tests()