File: test_python_ir.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (102 lines) | stat: -rw-r--r-- 3,333 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Owner(s): ["oncall: jit"]

import unittest

import numpy as np

import torch
from torch.testing import FileCheck
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.jit_utils import JitTestCase


# These tests are collected by the aggregated JIT suite; executing this module
# as a script is an error that points the user at the proper entry point.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, "
        "use:\n\n\tpython test/test_jit.py TESTNAME\n\ninstead."
    )


class TestPythonIr(JitTestCase):
    def test_param_strides(self):
        def trace_me(arg):
            return arg

        t = torch.zeros(1, 3, 16, 16)
        traced = torch.jit.trace(trace_me, t)
        value = list(traced.graph.param_node().outputs())[0]
        real_strides = list(t.stride())
        type_strides = value.type().strides()
        self.assertEqual(real_strides, type_strides)

    def test_permute_inputs_binding(self):
        @torch.jit.script
        def foo(i, j, k):
            pass

        g = foo.graph

        idxs = []
        for i, inp in enumerate(g.inputs()):
            inp.setDebugName(f"inp{i}")
            idxs.append(i)

        permuted_idxs = list(np.random.permutation(idxs))
        g.permuteInputs(permuted_idxs)
        for i, inp in enumerate(g.inputs()):
            self.assertEqual(f"inp{permuted_idxs[i]}", inp.debugName())

    @unittest.skipIf(IS_MACOS, "Failing on MacOS only")
    def test_python_ir_utils(self):
        @torch.jit.script
        def foo(inp):
            x = inp + 1
            y = x / 2
            z = y * y
            return z

        add_node = foo.graph.findNode("aten::add")
        div_node = foo.graph.findNode("aten::div")

        with foo.graph.insert_point_guard(add_node):
            with foo.graph.insert_point_guard(div_node):
                foo.graph.insertConstant("goodbye")
            foo.graph.insertConstant("hello")
        with foo.graph.insert_point_guard(foo.graph.findNode("aten::mul")):
            foo.graph.insertConstant("hello")
        FileCheck().check("hello").check("goodbye").check("hello").run(foo.graph)

        self.assertTrue(add_node.matches(add_node.schema()))
        self.assertFalse(add_node.matches(div_node.schema()))

    def test_python_ir_utils_graph(self):
        @torch.jit.script
        def unrolled_mul(x: torch.Tensor, y: int):
            out = x
            for _ in range(y - 1):
                out = out + x
            return out

        @torch.jit.script
        def foo(x):
            return x * 4

        g = foo.graph
        muls = g.findAllNodes("aten::mul")
        scalar_muls = filter(
            lambda x: x.matches("aten::mul(Tensor self, Scalar other) -> Tensor"), muls
        )
        mul_constant_int = filter(
            lambda x: isinstance(list(x.inputs())[1].toIValue(), int), scalar_muls
        )
        for mul in mul_constant_int:
            with g.insert_point_guard(mul):
                outputs = g.insertGraph(unrolled_mul.graph, list(mul.inputs()))
                assert len(outputs) == len(list(mul.outputs()))
                for new_out, old_out in zip(outputs, g.outputs()):
                    old_out.replaceAllUsesWith(new_out)
                mul.destroy()

        FileCheck().check_not("aten::mul").check("aten::add").run(foo.graph)
        self.assertEqual(foo(torch.ones([2, 2])), torch.ones([2, 2]) * 4)