# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest
import onnxscript.testing
from onnxscript import script
from onnxscript.onnx_opset import opset15 as op
from tests.common import testutils
class IfOpTest(testutils.TestBase):
    def test_no_else(self):
        """Basic test for if-then without else.

        Builds three scripted functions and checks them pairwise with
        ``onnxscript.testing.assert_isomorphic_function``:

        * ``if1`` -- an ``if`` with no ``else`` branch;
        * ``if2`` -- the same code with an explicit ``else`` that re-binds
          ``result`` to ``op.Identity(result)``;
        * ``if3`` -- ``if2`` rewritten with a distinct name per assignment
          (SSA renaming).
        """
        # TODO: pass default opset as parameter to @script; the scripted
        # functions below currently reference the module-level ``op``
        # (opset15) directly.

        @script()
        def if1(cond, x, y):
            result = op.Identity(y)
            if cond:
                result = op.Identity(x)
            return result

        # Expected equivalent of if1: the missing else branch is treated as
        # keeping the value ``result`` had before the ``if``.
        @script()
        def if2(cond, x, y):
            result = op.Identity(y)
            if cond:
                result = op.Identity(x)
            else:
                result = op.Identity(result)
            return result

        # Expected equivalent of if2 once every assignment gets its own name.
        @script()
        def if3(cond, x, y):
            result1 = op.Identity(y)
            if cond:
                result2 = op.Identity(x)
            else:
                result2 = op.Identity(result1)
            return result2

        # Checked in order; the first mismatch raises and ends the test.
        expected_equivalences = ((if1, if2), (if2, if3))
        for actual, expected in expected_equivalences:
            onnxscript.testing.assert_isomorphic_function(actual, expected)
if __name__ == "__main__":
    # Allow running this test module directly as a script via unittest's CLI.
    unittest.main()