1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
|
import torch
from torch.export import Dim
# custom op that loads the aot-compiled model
AOTI_CUSTOM_OP_LIB = "libaoti_custom_class.so"
torch.classes.load_library(AOTI_CUSTOM_OP_LIB)
class TensorSerializer(torch.nn.Module):
def __init__(self, data):
super().__init__()
for key in data:
setattr(self, key, data[key])
class SimpleModule(torch.nn.Module):
"""
a simple module to be compiled
"""
def __init__(self) -> None:
super().__init__()
self.fc = torch.nn.Linear(4, 6)
self.relu = torch.nn.ReLU()
def forward(self, x):
a = self.fc(x)
b = self.relu(a)
return b
class MyAOTIModule(torch.nn.Module):
"""
a wrapper nn.Module that instantiates its forward method
on MyAOTIClass
"""
def __init__(self, lib_path, device):
super().__init__()
self.aoti_custom_op = torch.classes.aoti.MyAOTIClass(
lib_path,
device,
)
def forward(self, *x):
outputs = self.aoti_custom_op.forward(x)
return tuple(outputs)
def make_script_module(lib_path, device, *inputs):
m = MyAOTIModule(lib_path, device)
# sanity check
m(*inputs)
return torch.jit.trace(m, inputs)
def compile_model(device, data):
module = SimpleModule().to(device)
x = torch.randn((4, 4), device=device)
inputs = (x,)
# make batch dimension
batch_dim = Dim("batch", min=1, max=1024)
dynamic_shapes = {
"x": {0: batch_dim},
}
with torch.no_grad():
# aot-compile the module into a .so pointed by lib_path
lib_path = torch._export.aot_compile(
module, inputs, dynamic_shapes=dynamic_shapes
)
script_module = make_script_module(lib_path, device, *inputs)
aoti_script_model = f"script_model_{device}.pt"
script_module.save(aoti_script_model)
# save sample inputs and ref output
with torch.no_grad():
ref_output = module(*inputs)
data.update(
{
f"inputs_{device}": list(inputs),
f"outputs_{device}": [ref_output],
}
)
def main():
data = {}
for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]:
compile_model(device, data)
torch.jit.script(TensorSerializer(data)).save("script_data.pt")
if __name__ == "__main__":
main()
|