1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
|
# Owner(s): ["oncall: jit"]
import unittest
import numpy as np
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
    # Running this module on its own is rejected; the message names the
    # supported way to invoke these tests instead.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
class TestPythonIr(JitTestCase):
def test_param_strides(self):
def trace_me(arg):
return arg
t = torch.zeros(1, 3, 16, 16)
traced = torch.jit.trace(trace_me, t)
value = list(traced.graph.param_node().outputs())[0]
real_strides = list(t.stride())
type_strides = value.type().strides()
self.assertEqual(real_strides, type_strides)
def test_permute_inputs_binding(self):
@torch.jit.script
def foo(i, j, k):
pass
g = foo.graph
idxs = []
for i, inp in enumerate(g.inputs()):
inp.setDebugName(f"inp{i}")
idxs.append(i)
permuted_idxs = list(np.random.permutation(idxs))
g.permuteInputs(permuted_idxs)
for i, inp in enumerate(g.inputs()):
self.assertEqual(f"inp{permuted_idxs[i]}", inp.debugName())
@unittest.skipIf(IS_MACOS, "Failing on MacOS only")
def test_python_ir_utils(self):
@torch.jit.script
def foo(inp):
x = inp + 1
y = x / 2
z = y * y
return z
add_node = foo.graph.findNode("aten::add")
div_node = foo.graph.findNode("aten::div")
with foo.graph.insert_point_guard(add_node):
with foo.graph.insert_point_guard(div_node):
foo.graph.insertConstant("goodbye")
foo.graph.insertConstant("hello")
with foo.graph.insert_point_guard(foo.graph.findNode("aten::mul")):
foo.graph.insertConstant("hello")
FileCheck().check("hello").check("goodbye").check("hello").run(foo.graph)
self.assertTrue(add_node.matches(add_node.schema()))
self.assertFalse(add_node.matches(div_node.schema()))
def test_python_ir_utils_graph(self):
@torch.jit.script
def unrolled_mul(x: torch.Tensor, y: int):
out = x
for _ in range(y - 1):
out = out + x
return out
@torch.jit.script
def foo(x):
return x * 4
g = foo.graph
muls = g.findAllNodes("aten::mul")
scalar_muls = filter(
lambda x: x.matches("aten::mul(Tensor self, Scalar other) -> Tensor"), muls
)
mul_constant_int = filter(
lambda x: isinstance(list(x.inputs())[1].toIValue(), int), scalar_muls
)
for mul in mul_constant_int:
with g.insert_point_guard(mul):
outputs = g.insertGraph(unrolled_mul.graph, list(mul.inputs()))
assert len(outputs) == len(list(mul.outputs()))
for new_out, old_out in zip(outputs, g.outputs()):
old_out.replaceAllUsesWith(new_out)
mul.destroy()
FileCheck().check_not("aten::mul").check("aten::add").run(foo.graph)
self.assertEqual(foo(torch.ones([2, 2])), torch.ones([2, 2]) * 4)
|