1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219
|
import torch
import torch._inductor.config
from torch._export import aot_compile
from torch.export import Dim
# Seed the default RNG so the randomly initialized weights and inputs drawn
# below are reproducible from run to run.
torch.random.manual_seed(1337)
class Net(torch.nn.Module):
    """Small test model computing ``x @ (relu(w_pre^T) + w_add)``.

    ``w_pre`` and ``w_add`` are random ``(size, size)`` tensors created directly
    on ``device``. They are assigned as plain tensors (not wrapped in
    ``nn.Parameter`` nor registered via ``register_buffer``).
    """

    def __init__(self, device, size=4):
        super().__init__()
        square = (size, size)
        # Draw order (w_pre first, then w_add) is fixed so that seeded runs
        # reproduce the same weights.
        self.w_pre = torch.randn(square, device=device)
        self.w_add = torch.randn(square, device=device)

    def forward(self, x):
        activated = torch.nn.functional.relu(torch.transpose(self.w_pre, 0, 1))
        return torch.matmul(x, activated + self.w_add)
class NetWithTensorConstants(torch.nn.Module):
    """Model whose forward gathers a fixed list of rows from ``w * x * y``.

    ``w`` is a random ``(30, 1)`` tensor created on ``device``. ``forward``
    multiplies it elementwise with ``x`` and ``y``, then indexes the result
    with a constant Python list of row indices.

    Args:
        device: Device on which ``w`` is allocated. Defaults to ``"cuda"``,
            the previously hard-coded value, so existing callers are
            unaffected. Passing ``"cpu"`` allows use without a GPU.
    """

    def __init__(self, device="cuda") -> None:
        super().__init__()
        self.w = torch.randn(30, 1, device=device)

    def forward(self, x, y):
        z = self.w * x * y
        # NOTE(review): row 16 is absent from this index list. This is
        # presumably deliberate (a non-contiguous gather); confirm before
        # "fixing" it.
        return z[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17]]
# Module-level result containers. They are filled in by the generate_*
# functions below and serialized to .pt files at the end of the script.
data, large_data, cuda_alloc_data, data_with_tensor_constants = {}, {}, {}, {}
# Basic AOTI model test generation.
def generate_basic_tests():
    """Build and record AOTInductor artifacts for ``Net`` on each available device.

    Covers each device (``cpu``, plus ``cuda`` when available) and each setting
    of ``aot_inductor.use_runtime_constant_folding``; the ``True`` setting is
    skipped on CPU. For each case this:

    * creates a fresh ``Net`` and a random ``(4, 4)`` input ``x``;
    * computes an eager reference output under ``torch.no_grad()``;
    * compiles the model with ``aot_compile`` (shared library) and with
      ``aoti_compile_and_package`` (``.pt2`` package), marking dim 0 of ``x``
      as dynamic in ``[1, 1024]``;
    * stores artifact paths, inputs, reference outputs and weights in the
      module-level ``data`` dict, under keys ending in ``_<device>`` or
      ``_<device>_use_runtime_constant_folding``.
    """
    candidate_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
    for device in candidate_devices:
        for use_runtime_constant_folding in [True, False]:
            # Runtime constant folding is not exercised in CPU mode.
            if use_runtime_constant_folding and device == "cpu":
                continue

            model = Net(device).to(device=device)
            x = torch.randn((4, 4), device=device)
            with torch.no_grad():
                ref_output = model(x)

            torch._dynamo.reset()
            with torch.no_grad():
                dynamic_shapes = {"x": {0: Dim("dim0_x", min=1, max=1024)}}
                model_so_path = aot_compile(
                    model,
                    (x,),
                    dynamic_shapes=dynamic_shapes,
                    options={
                        "aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding
                    },
                )
                # Also produce a .pt2 package via aoti_compile_and_package.
                exported_program = torch.export.export(
                    model, (x,), dynamic_shapes=dynamic_shapes
                )
                pt2_package_path = torch._inductor.aoti_compile_and_package(
                    exported_program,
                    inductor_configs={
                        "aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding
                    },
                )

            suffix = device
            if use_runtime_constant_folding:
                suffix = f"{device}_use_runtime_constant_folding"
            data[f"model_so_path_{suffix}"] = model_so_path
            data[f"pt2_package_path_{suffix}"] = pt2_package_path
            data[f"inputs_{suffix}"] = [x]
            data[f"outputs_{suffix}"] = [ref_output]
            data[f"w_pre_{suffix}"] = model.w_pre
            data[f"w_add_{suffix}"] = model.w_add
def generate_basic_tests_consts_cpp():
    """Re-run ``generate_basic_tests`` with ``use_consts_asm_build`` disabled.

    Temporarily sets
    ``torch._inductor.config.aot_inductor.use_consts_asm_build`` to ``False``
    so the basic models are compiled again through the consts-cpp build path.
    The previous value is then restored, even if compilation raises.

    Note: ``generate_basic_tests`` writes the same keys into ``data`` on every
    call, so this run overwrites any entries from an earlier call.
    """
    backup_consts_asm_cfg: bool = (
        torch._inductor.config.aot_inductor.use_consts_asm_build
    )
    torch._inductor.config.aot_inductor.use_consts_asm_build = False
    try:
        # Test consts cpp build again.
        generate_basic_tests()
    finally:
        # Restore the global inductor config even on failure, so later
        # generators do not silently inherit the modified setting.
        torch._inductor.config.aot_inductor.use_consts_asm_build = (
            backup_consts_asm_cfg
        )
def generate_large_tests():
    """Compile a 4096x4096 ``Net`` on CUDA with and without runtime constant folding.

    For each setting of ``aot_inductor.use_runtime_constant_folding``, stores
    the ``aot_compile`` shared-library path and the
    ``aoti_compile_and_package`` ``.pt2`` path in ``large_data``. Keys get a
    ``_use_runtime_constant_folding`` suffix when the option is enabled. The
    shared input, eager reference output and weights are stored under
    unsuffixed keys.

    Returns immediately (storing nothing) when CUDA is unavailable, matching
    ``generate_test_with_additional_tensors``.
    """
    # CUDA-only test: skip instead of crashing on CPU-only machines.
    if not torch.cuda.is_available():
        return

    device = "cuda"
    model = Net(device, size=4096).to(device=device)
    x = torch.randn((4096, 4096), device=device)
    with torch.no_grad():
        ref_output = model(x)

    torch._dynamo.reset()
    for use_runtime_constant_folding in [True, False]:
        with torch.no_grad():
            model_so_path = aot_compile(
                model,
                (x,),
                options={
                    "aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding
                },
            )
            # Also store a .pt2 file using the aoti_compile_and_package API
            pt2_package_path = torch._inductor.aoti_compile_and_package(
                torch.export.export(
                    model,
                    (x,),
                ),
                inductor_configs={
                    "aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding
                },
            )

        suffix = "_use_runtime_constant_folding" if use_runtime_constant_folding else ""
        large_data.update(
            {
                f"model_so_path{suffix}": model_so_path,
                f"pt2_package_path{suffix}": pt2_package_path,
                "inputs": [x],
                "outputs": [ref_output],
                "w_pre": model.w_pre,
                "w_add": model.w_add,
            }
        )
def generate_cuda_alloc_test():
    """Compile a 4096x4096 ``Net`` on CUDA with ``weight_use_caching_allocator`` enabled.

    Compiles with ``aot_inductor.weight_use_caching_allocator=True`` and
    stores the shared-library path, input, eager reference output and weights
    in ``cuda_alloc_data``.

    Returns immediately (storing nothing) when CUDA is unavailable, matching
    ``generate_test_with_additional_tensors``.
    """
    # CUDA-only test: skip instead of crashing on CPU-only machines.
    if not torch.cuda.is_available():
        return

    device = "cuda"
    model = Net(device, size=4096).to(device=device)
    x = torch.randn((4096, 4096), device=device)
    with torch.no_grad():
        ref_output = model(x)

    torch._dynamo.reset()
    with torch.no_grad():
        model_so_path = aot_compile(
            model,
            (x,),
            options={"aot_inductor.weight_use_caching_allocator": True},
        )

    cuda_alloc_data.update(
        {
            "model_so_path": model_so_path,
            "inputs": [x],
            "outputs": [ref_output],
            "w_pre": model.w_pre,
            "w_add": model.w_add,
        }
    )
# AOTI model which will create additional tensors during autograd.
def generate_test_with_additional_tensors():
    """Compile ``NetWithTensorConstants`` on CUDA and record its artifacts.

    Returns immediately when CUDA is unavailable. Otherwise it builds both an
    ``aot_compile`` shared library and an ``aoti_compile_and_package`` ``.pt2``
    archive. The two paths, the inputs, the eager reference output and the
    model's ``w`` tensor are stored in ``data_with_tensor_constants``.
    """
    if not torch.cuda.is_available():
        return

    model = NetWithTensorConstants()
    # Two (30, 1) CUDA inputs, drawn in order: x first, then y.
    x, y = (torch.randn((30, 1), device="cuda") for _ in range(2))
    example_inputs = (x, y)

    with torch.no_grad():
        ref_output = model(*example_inputs)

    torch._dynamo.reset()
    with torch.no_grad():
        model_so_path = aot_compile(model, example_inputs)
        # Also produce a .pt2 package via aoti_compile_and_package.
        exported_program = torch.export.export(model, example_inputs)
        pt2_package_path = torch._inductor.aoti_compile_and_package(exported_program)

    data_with_tensor_constants.update(
        model_so_path=model_so_path,
        pt2_package_path=pt2_package_path,
        inputs=list(example_inputs),
        outputs=[ref_output],
        w=model.w,
    )
# Run every generator in sequence. Order matters: they all draw from the RNG
# seeded by torch.manual_seed at the top of the file, so reordering would
# change the generated weights and inputs.
for _generate in (
    generate_basic_tests,
    generate_basic_tests_consts_cpp,
    generate_large_tests,
    generate_test_with_additional_tensors,
    generate_cuda_alloc_test,
):
    _generate()
del _generate
# Use this to communicate tensors to the cpp code
class Serializer(torch.nn.Module):
    """Expose each entry of ``data`` as an attribute on a module.

    The instance is passed to ``torch.jit.script(...).save(...)``, so the
    stored values (paths, tensors, lists of tensors) can be read back from the
    saved TorchScript archive by the C++ code.
    """

    def __init__(self, data):
        super().__init__()
        for name, value in data.items():
            setattr(self, name, value)
# Script each result dict through Serializer and save it under a fixed
# filename for the C++ side to load. Order matches the original saves.
for _payload, _filename in (
    (data, "data.pt"),
    (large_data, "large_data.pt"),
    (data_with_tensor_constants, "data_with_tensor_constants.pt"),
    (cuda_alloc_data, "cuda_alloc_data.pt"),
):
    torch.jit.script(Serializer(_payload)).save(_filename)
del _payload, _filename
|