import torch
from torch._export import aot_compile
from torch.export import Dim
torch.manual_seed(1337)


class Net(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.w_pre = torch.randn(4, 4, device=device)
        self.w_add = torch.randn(4, 4, device=device)

    def forward(self, x):
        w_transpose = torch.transpose(self.w_pre, 0, 1)
        w_relu = torch.nn.functional.relu(w_transpose)
        w = w_relu + self.w_add
        return torch.matmul(x, w)


class NetWithTensorConstants(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.w = torch.randn(30, 1, device="cuda")

    def forward(self, x, y):
        z = self.w * x * y
        # Indexing with a literal list makes AOTI lift the index into an
        # additional tensor constant of the compiled model.
        return z[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17]]


data = {}
data_with_tensor_constants = {}


# Basic AOTI model test generation.
def generate_basic_tests():
    for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]:
        for use_runtime_constant_folding in [True, False]:
            if device == "cpu" and use_runtime_constant_folding:
                # We do not test runtime constant folding in CPU mode.
                continue
            model = Net(device).to(device=device)
            x = torch.randn((4, 4), device=device)
            with torch.no_grad():
                ref_output = model(x)

            torch._dynamo.reset()
            with torch.no_grad():
                dim0_x = Dim("dim0_x", min=1, max=1024)
                dynamic_shapes = {"x": {0: dim0_x}}
                model_so_path = aot_compile(
                    model,
                    (x,),
                    dynamic_shapes=dynamic_shapes,
                    options={
                        "aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding
                    },
                )

            suffix = f"{device}"
            if use_runtime_constant_folding:
                suffix += "_use_runtime_constant_folding"
            data.update(
                {
                    f"model_so_path_{suffix}": model_so_path,
                    f"inputs_{suffix}": [x],
                    f"outputs_{suffix}": [ref_output],
                    f"w_pre_{suffix}": model.w_pre,
                    f"w_add_{suffix}": model.w_add,
                }
            )
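

# A minimal sanity-check sketch (not part of the original generation flow and
# never invoked by default): it assumes torch._export.aot_load is available in
# this PyTorch build and reloads one compiled .so to compare against the
# recorded eager output.
def _sanity_check_basic(device="cpu"):
    so_path = data.get(f"model_so_path_{device}")
    if so_path is None:
        return
    runner = torch._export.aot_load(so_path, device)
    x = data[f"inputs_{device}"][0]
    expected = data[f"outputs_{device}"][0]
    with torch.no_grad():
        actual = runner(x)
    assert torch.allclose(actual, expected)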


# AOTI model whose compilation creates additional tensor constants.
def generate_test_with_additional_tensors():
    if not torch.cuda.is_available():
        return

    model = NetWithTensorConstants()
    x = torch.randn((30, 1), device="cuda")
    y = torch.randn((30, 1), device="cuda")
    with torch.no_grad():
        ref_output = model(x, y)

    torch._dynamo.reset()
    with torch.no_grad():
        model_so_path = aot_compile(model, (x, y))

    data_with_tensor_constants.update(
        {
            "model_so_path": model_so_path,
            "inputs": [x, y],
            "outputs": [ref_output],
            "w": model.w,
        }
    )
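

# A similar sketch for the tensor-constants model (CUDA only); again it relies
# on torch._export.aot_load being available and is not called by default.
def _sanity_check_with_additional_tensors():
    if not data_with_tensor_constants:
        return
    runner = torch._export.aot_load(
        data_with_tensor_constants["model_so_path"], "cuda"
    )
    x, y = data_with_tensor_constants["inputs"]
    with torch.no_grad():
        actual = runner(x, y)
    assert torch.allclose(actual, data_with_tensor_constants["outputs"][0])
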
generate_basic_tests()
generate_test_with_additional_tensors()


# Use this to communicate tensors to the C++ code.
class Serializer(torch.nn.Module):
    def __init__(self, data):
        super().__init__()
        # Attach every entry (.so paths, inputs, outputs, weights) as a module
        # attribute so TorchScript serializes it and the C++ test can read it back.
        for key in data:
            setattr(self, key, data[key])


torch.jit.script(Serializer(data)).save("data.pt")
torch.jit.script(Serializer(data_with_tensor_constants)).save(
    "data_with_tensor_constants.pt"
)
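

# A minimal sketch of how the serialized bundle can be read back from Python,
# mirroring what the C++ consumer does with torch::jit::load; the attribute
# names below follow the keys written above and assume the CPU entries exist.
def _peek_serialized_data(path="data.pt", suffix="cpu"):
    bundle = torch.jit.load(path)
    return (
        getattr(bundle, f"model_so_path_{suffix}"),
        getattr(bundle, f"inputs_{suffix}"),
        getattr(bundle, f"outputs_{suffix}"),
    )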