File: indirect_assert_helper.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (79 lines) | stat: -rw-r--r-- 1,680 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import sys

import torch
from torch.testing._internal.inductor_utils import GPU_TYPE


def first_arg(x, y):
    """Return ``x[y]``: index the leading dimension of ``x`` with ``y``."""
    selected = x[y]
    return selected


def second_arg(x, y):
    """Return ``x[:, y]``: keep the whole first dimension, index the second with ``y``."""
    selected = x[:, y]
    return selected


def same_pm_one(x, y):
    """Return ``x[y + 1, y - 1]``.

    The first dimension is indexed with ``y`` shifted up by one and the
    second dimension with ``y`` shifted down by one.
    """
    rows = y + 1
    cols = y - 1
    return x[rows, cols]


def same_pp_one(x, y):
    """Return ``x[y + 1, y + 1]``.

    Both the first and the second dimension are indexed with ``y`` shifted
    up by one; the shift is computed separately for each dimension.
    """
    rows = y + 1
    cols = y + 1
    return x[rows, cols]


def store(x, y, z):
    """Write ``z`` into ``x`` in place at positions ``[y + 1, y + 1]``.

    Nothing is returned; ``x`` is modified.
    """
    rows = y + 1
    cols = y + 1
    x[rows, cols] = z


def upper1(x):
    """Return ``x`` indexed along its first dimension by ``arange(4)``.

    The index tensor ``[0, 1, 2, 3]`` is created on the same device as ``x``.
    """
    indices = torch.arange(4, device=x.device)
    gathered = x[indices]
    return gathered


def lower1(x):
    """Return ``x`` indexed along its first dimension by a 0-d int64 tensor holding -4.

    The index tensor is built with ``x.new_full`` so it is derived from ``x``.
    """
    index = x.new_full((), -4, dtype=torch.int64)
    picked = x[index]
    return picked


def upper2(x):
    """Return ``x`` indexed along its first dimension by a 0-d int64 tensor holding 4.

    The index tensor is built with ``x.new_full`` so it is derived from ``x``.
    """
    index = x.new_full((), 4, dtype=torch.int64)
    picked = x[index]
    return picked


def lower2(x):
    """Return ``x`` indexed along its first dimension by ``0 - 4``.

    A 0-d int64 zero tensor is built with ``x.new_zeros`` and 4 is
    subtracted from it to form the index.
    """
    zero = x.new_zeros((), dtype=torch.int64)
    index = zero - 4
    return x[index]


if __name__ == "__main__":
    # Command line: <fn_name> <dims> <dyn_shape> <one_size>
    # Unpacking fails with ValueError unless exactly four arguments are given.
    _, fn_name, dims, dyn_shape, one_size = sys.argv

    # Names of every callable defined in this module; these are the only
    # accepted values for <fn_name>.
    fns = [
        name
        for name, obj in vars().items()
        if callable(obj) and obj.__module__ == __name__
    ]

    bool_words = ("True", "False")

    assert fn_name in fns
    assert one_size in bool_words
    one_size = one_size == "True"
    assert dims in ("2", "3")

    # Input is 3x2, with a trailing dimension of 4 when three dims are requested.
    shape_x = [3, 2]
    if dims == "3":
        shape_x.append(4)
    if one_size:
        assert (
            fn_name == "first_arg"
        ), "only first_arg can be tested for a special case of 1-size tensor"
        shape_x[0] = 1
    assert dyn_shape in bool_words
    dynamic_shapes = dyn_shape == "True"

    # x has a leading dimension of at most 3 while y holds the indices 0..3.
    x = torch.randn(shape_x, device=GPU_TYPE)
    y = torch.arange(4, device=GPU_TYPE)
    fn = torch.compile(vars()[fn_name], dynamic=dynamic_shapes)

    # Pick the argument list that matches the selected function's signature.
    if fn_name in ("upper1", "upper2", "lower1", "lower2"):
        call_args = (x,)
    elif fn_name == "store":
        # One value per index in y, plus any dimensions of x beyond the first two.
        z = torch.randn((y.numel(), *x.shape[2:]), device=GPU_TYPE)
        call_args = (x, y, z)
    else:
        call_args = (x, y)
    fn(*call_args)