1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
|
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest
import numpy as np
from onnxscript import script
from onnxscript.onnx_opset import opset15 as op
from onnxscript.onnx_types import FLOAT, INT64
from tests.common import testutils
class LoopOpTest(testutils.TestBase):
    """Tests for ``for _ in range(...)`` loops written inside ``@script`` functions."""

    def test_loop(self):
        """Basic loop test."""

        @script()
        def sumprod(x: FLOAT["N"], N: INT64) -> (FLOAT["N"], FLOAT["N"]):  # noqa: F821
            sum = op.Identity(x)
            prod = op.Identity(x)
            for _ in range(N):
                sum = sum + x
                prod = prod * x
            return sum, prod

        self.validate(sumprod)
        # Use float32 so the eager input actually matches the FLOAT annotation on ``x``
        # (a bare ``np.array([2])`` is int64 and would not exercise the declared type).
        x = np.array([2], dtype=np.float32)
        M = 3
        # Three iterations starting from x = 2:
        #   sum:  2 -> 4 -> 6 -> 8
        #   prod: 2 -> 4 -> 8 -> 16
        # Result locals avoid shadowing the builtin ``sum`` in the test body.
        sum_result, prod_result = sumprod(x, M)
        # Compare element-wise; ``assertEqual`` on ndarrays depends on the truth value
        # of a comparison array, which is only defined for single-element arrays.
        np.testing.assert_array_equal(sum_result, np.array([8], dtype=np.float32))
        np.testing.assert_array_equal(prod_result, np.array([16], dtype=np.float32))

    def test_loop_bound(self):
        """Test with an expression for loop bound."""

        @script()
        def sumprod(x: FLOAT["N"], N: INT64) -> (FLOAT["N"], FLOAT["N"]):  # noqa: F821
            sum = op.Identity(x)
            prod = op.Identity(x)
            # The trip count is an arithmetic expression, not a bare input name.
            for _ in range(2 * N + 1):
                sum = sum + x
                prod = prod * x
            return sum, prod

        # Only ``self.validate`` is exercised here; the function is not called eagerly.
        self.validate(sumprod)
if __name__ == "__main__":
    # Discover and run the TestCase classes in this module when executed directly;
    # "__main__" is unittest.main's default module, spelled out for clarity.
    unittest.main(module="__main__")
|