1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
|
import unittest
import benchmark_cpp_extension # noqa: F401
import torch
class TestConsumeOp(unittest.TestCase):
    """Tests for ``torch.ops.operator_benchmark._consume`` under ``torch.jit.trace``.

    Each test traces a function that calls ``_consume`` ``iters`` times in a
    Python loop and overwrites ``result`` on every iteration. The tests check
    two things:

    * the traced graph contains the inner ATen op once per iteration;
    * the traced function returns the same value as eager execution.

    NOTE(review): the per-iteration count only holds if ``_consume`` keeps the
    otherwise-unused intermediate results from being removed as dead code.
    That appears to be the op's purpose, but its implementation is not visible
    here.
    """

    def test_jit_consume_op(self):
        """Single-tensor input: expect one ``aten::sum`` per loop iteration."""
        iters = 6

        def foo(x):
            for _ in range(iters):
                result = torch.ops.operator_benchmark._consume(torch.sum(x))
            return result

        # Pass the example inputs as an explicit 1-tuple (one positional arg).
        traced = torch.jit.trace(foo, (torch.rand(2, 2),))
        occurrence = str(traced.graph).count("aten::sum")

        x = torch.rand(2, 2)
        value = traced(x)
        self.assertEqual(value, torch.sum(x))
        self.assertEqual(occurrence, iters)

    def test_jit_consume_op_for_list_input(self):
        """List-of-tensors input: expect one ``aten::chunk`` per loop iteration."""
        iters = 6

        def foo(x):
            for _ in range(iters):
                result = torch.ops.operator_benchmark._consume(torch.chunk(x, 2))
            return result

        # Pass the example inputs as an explicit 1-tuple (one positional arg).
        traced = torch.jit.trace(foo, (torch.rand(2, 2),))
        occurrence = str(traced.graph).count("aten::chunk")

        x = torch.rand(2, 2)
        value = traced(x)
        expected = torch.chunk(x, 2)
        # zip() stops at the shorter sequence, so compare lengths first;
        # otherwise missing or extra chunks would go unnoticed.
        self.assertEqual(len(value), len(expected))
        self.assertTrue(
            all(torch.allclose(t1, t2) for t1, t2 in zip(value, expected))
        )
        self.assertEqual(occurrence, iters)
|