File: compile_model.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (95 lines) | stat: -rw-r--r-- 2,378 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
from torch.export import Dim


# Shared library that provides the custom op/class used to load the
# AOT-compiled model.
# NOTE(review): this is a bare file name, so it presumably has to be
# resolvable from the working directory or the loader search path -- confirm.
AOTI_CUSTOM_OP_LIB = "libaoti_custom_class.so"
# Loaded at import time; MyAOTIModule below looks up
# torch.classes.aoti.MyAOTIClass, which is presumably registered by this
# library.
torch.classes.load_library(AOTI_CUSTOM_OP_LIB)


class TensorSerializer(torch.nn.Module):
    """
    Module whose only job is to carry the entries of ``data`` as attributes.

    Every key of ``data`` becomes an attribute name on the module and the
    corresponding value is stored under it, so the whole mapping can be
    scripted and saved as a single module.

    Args:
        data: dict mapping attribute names (str) to the values to attach.
    """

    def __init__(self, data):
        super().__init__()
        for attr_name, attr_value in data.items():
            setattr(self, attr_name, attr_value)


class SimpleModule(torch.nn.Module):
    """
    Minimal model used as the compilation target: Linear(4 -> 6) then ReLU.
    """

    def __init__(self) -> None:
        super().__init__()
        # Registration order (fc, then relu) is kept so the module's
        # children and state_dict layout stay the same.
        self.fc = torch.nn.Linear(4, 6)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Apply the linear layer to ``x`` and clamp the result with ReLU."""
        return self.relu(self.fc(x))


class MyAOTIModule(torch.nn.Module):
    """
    nn.Module wrapper whose forward delegates to a MyAOTIClass instance.

    Args:
        lib_path: path of the AOT-compiled library handed to MyAOTIClass.
        device: device argument handed to MyAOTIClass.
    """

    def __init__(self, lib_path, device):
        super().__init__()
        aoti_class = torch.classes.aoti.MyAOTIClass
        self.aoti_custom_op = aoti_class(lib_path, device)

    def forward(self, *x):
        """
        Pass all positional inputs (as one tuple) to the custom op's
        ``forward`` and return its outputs as a tuple.
        """
        return tuple(self.aoti_custom_op.forward(x))


def make_script_module(lib_path, device, *inputs):
    """
    Build a MyAOTIModule for ``lib_path``/``device`` and trace it.

    Args:
        lib_path: path of the AOT-compiled library to wrap.
        device: device argument forwarded to MyAOTIModule.
        *inputs: example inputs, used both for a trial call and for tracing.

    Returns:
        The result of ``torch.jit.trace`` on the wrapper module.
    """
    wrapper = MyAOTIModule(lib_path, device)
    # Run the wrapper once eagerly so a broken library fails here,
    # before tracing starts.
    wrapper(*inputs)
    traced = torch.jit.trace(wrapper, inputs)
    return traced


def compile_model(device, data):
    """
    AOT-compile SimpleModule for ``device`` and record sample inputs/outputs.

    The compiled library is wrapped in a traced script module that is saved
    to ``script_model_{device}.pt`` (relative path). The sample input and
    the eager reference output are stored in ``data``.

    Args:
        device: device the module and the sample input are placed on; also
            used in the saved file name and in the keys written to ``data``.
        data: dict updated in place with ``inputs_{device}`` (list of sample
            inputs) and ``outputs_{device}`` (list with the reference output).
    """
    model = SimpleModule().to(device)
    sample = torch.randn((4, 4), device=device)
    example_args = (sample,)

    # Dimension 0 of "x" is exported as a dynamic batch dimension (1..1024).
    shape_spec = {"x": {0: Dim("batch", min=1, max=1024)}}

    with torch.no_grad():
        # aot-compile the module into a .so; so_path points at it
        so_path = torch._export.aot_compile(
            model, example_args, dynamic_shapes=shape_spec
        )

    traced = make_script_module(so_path, device, *example_args)
    traced.save(f"script_model_{device}.pt")

    # Reference output from the eager module on the same sample input.
    with torch.no_grad():
        expected = model(*example_args)

    data[f"inputs_{device}"] = list(example_args)
    data[f"outputs_{device}"] = [expected]


def main():
    """
    Compile the model for every available device and save the sample data.

    "cpu" is always processed; "cuda" is added when
    ``torch.cuda.is_available()`` is true. The collected inputs/outputs are
    wrapped in a TensorSerializer, scripted, and saved to ``script_data.pt``.
    """
    data = {}

    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")

    for device in devices:
        compile_model(device, data)

    serializer = TensorSerializer(data)
    torch.jit.script(serializer).save("script_data.pt")


# Run the compile-and-save flow only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()