File: test.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (115 lines) | stat: -rw-r--r-- 3,441 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
from torch._export import aot_compile
from torch.export import Dim


# Fix the global RNG seed so the random weights and inputs generated below
# (and therefore the recorded reference outputs) are reproducible.
torch.manual_seed(seed=1337)


class Net(torch.nn.Module):
    """Tiny model computing ``x @ (relu(w_pre.T) + w_add)``.

    ``w_pre`` and ``w_add`` are random 4x4 tensors stored as plain module
    attributes (they are not registered as parameters or buffers).
    """

    def __init__(self, device):
        super().__init__()
        # Both right-hand tensors are created (w_pre first, then w_add)
        # before either attribute is assigned; RNG consumption order matches
        # creating them one after the other.
        self.w_pre, self.w_add = (
            torch.randn(4, 4, device=device),
            torch.randn(4, 4, device=device),
        )

    def forward(self, x):
        """Return ``x`` multiplied by the weight derived from ``w_pre``/``w_add``."""
        weight = (
            torch.nn.functional.relu(torch.transpose(self.w_pre, 0, 1))
            + self.w_add
        )
        return torch.matmul(x, weight)


class NetWithTensorConstants(torch.nn.Module):
    """Model with a plain tensor attribute ``w`` and a constant row index.

    ``forward`` multiplies ``w`` element-wise with both inputs and then
    gathers rows 0-15 and 17 of the product.

    Args:
        device: Device on which ``w`` is allocated. Defaults to ``"cuda"``,
            which preserves the original behaviour; pass ``"cpu"`` to build
            the model on machines without a GPU (e.g. to check it eagerly).
    """

    def __init__(self, device="cuda") -> None:
        super().__init__()
        # Plain tensor attribute (not a registered parameter or buffer).
        self.w = torch.randn(30, 1, device=device)

    def forward(self, x, y):
        z = self.w * x * y
        # Row 16 is skipped. Presumably this is deliberate so that the gather
        # cannot be expressed as a contiguous slice -- confirm before
        # changing, since the recorded reference outputs depend on it.
        return z[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17]]


# Key -> value stores filled by generate_basic_tests() and
# generate_test_with_additional_tensors(), then serialized for the C++ side.
data, data_with_tensor_constants = dict(), dict()


# Basic AOTI model test generation.
def generate_basic_tests():
    """Compile ``Net`` with AOTInductor for each supported configuration.

    The configurations are "cpu" (plus "cuda" when CUDA is available), each
    crossed with runtime constant folding on/off. Runtime constant folding is
    skipped for "cpu". For every configuration, the compiled shared-library
    path, the sample input, the eager reference output and the model's
    ``w_pre``/``w_add`` tensors are stored in the module-level ``data`` dict.
    Each key is suffixed with the device name, plus
    ``_use_runtime_constant_folding`` when folding is enabled.
    """
    if torch.cuda.is_available():
        devices = ["cpu", "cuda"]
    else:
        devices = ["cpu"]

    for device in devices:
        for use_runtime_constant_folding in (True, False):
            # Runtime constant folding is not exercised in CPU mode.
            if device == "cpu" and use_runtime_constant_folding:
                continue

            model = Net(device).to(device=device)
            example_x = torch.randn((4, 4), device=device)
            with torch.no_grad():
                expected = model(example_x)

            torch._dynamo.reset()
            compile_options = {
                "aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding
            }
            with torch.no_grad():
                # Dimension 0 of ``x`` is compiled as dynamic in [1, 1024].
                batch_dim = Dim("dim0_x", min=1, max=1024)
                so_path = aot_compile(
                    model,
                    (example_x,),
                    dynamic_shapes={"x": {0: batch_dim}},
                    options=compile_options,
                )

            suffix = (
                f"{device}_use_runtime_constant_folding"
                if use_runtime_constant_folding
                else device
            )
            data[f"model_so_path_{suffix}"] = so_path
            data[f"inputs_{suffix}"] = [example_x]
            data[f"outputs_{suffix}"] = [expected]
            data[f"w_pre_{suffix}"] = model.w_pre
            data[f"w_add_{suffix}"] = model.w_add


# AOTI model which will create additional tensors during autograd.
def generate_test_with_additional_tensors():
    if not torch.cuda.is_available():
        return

    model = NetWithTensorConstants()
    x = torch.randn((30, 1), device="cuda")
    y = torch.randn((30, 1), device="cuda")
    with torch.no_grad():
        ref_output = model(x, y)

    torch._dynamo.reset()
    with torch.no_grad():
        model_so_path = aot_compile(model, (x, y))

    data_with_tensor_constants.update(
        {
            "model_so_path": model_so_path,
            "inputs": [x, y],
            "outputs": [ref_output],
            "w": model.w,
        }
    )


# Populate ``data`` and ``data_with_tensor_constants`` at module load time.
# The CUDA variants are only generated when torch.cuda.is_available().
generate_basic_tests()
generate_test_with_additional_tensors()


# Use this to communicate tensors to the cpp code
class Serializer(torch.nn.Module):
    """Module whose attributes mirror the entries of a dict.

    Every ``name -> value`` pair in ``data`` is attached to the module with
    ``setattr``, so scripting and saving the module writes those values to a
    file that the C++ code can load.
    """

    def __init__(self, data):
        super().__init__()
        for name, value in data.items():
            setattr(self, name, value)


# Script a dict-backed Serializer for each store and save it, using paths
# relative to the current working directory. data_with_tensor_constants is
# still saved, but empty, when CUDA is unavailable.
torch.jit.script(Serializer(data)).save("data.pt")
torch.jit.script(Serializer(data_with_tensor_constants)).save(
    "data_with_tensor_constants.pt"
)