1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
|
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest
import numpy as np
from tests.common import onnx_script_test_case
from tests.models import if_statement
class TestOnnxIf(onnx_script_test_case.OnnxScriptTestCase):
    """Tests for the ``maxsum`` model defined in ``tests.models.if_statement``."""

    def test_if(self):
        """Run ``if_statement.maxsum`` eagerly on two random float32 vectors.

        Only the eager path is exercised; the converter path is currently
        disabled because it fails during ONNX type inference (see the note
        inside the loop).
        """
        n = 8
        # Use a private RandomState instead of np.random.seed so this test does
        # not mutate NumPy's global RNG state. RandomState(0) yields the same
        # stream as np.random.seed(0), so `a` and `b` are unchanged.
        rng = np.random.RandomState(0)
        a = rng.rand(n).astype("float32")
        b = rng.rand(n).astype("float32")
        # FIXME(liqunfu): the expected values come from ONNX Runtime
        # evaluation; a NumPy reference implementation should provide them
        # instead.
        # NOTE(review): these values coincide with `a` (the first 8 draws from
        # seed 0); presumably maxsum selects `a` for these inputs — confirm
        # against if_statement.maxsum before deriving `expected` from inputs.
        expected = np.array(
            [
                0.5488135,
                0.71518934,
                0.60276335,
                0.5448832,
                0.4236548,
                0.6458941,
                0.4375872,
                0.891773,
            ],
            dtype=np.float32,
        )
        cases = [
            onnx_script_test_case.FunctionTestParams(
                if_statement.maxsum, [a, b], [expected]
            )
        ]
        for case in cases:
            # TODO: re-enable the converter check once it passes. It currently
            # fails with:
            #   FAIL : Node () Op (local_function) [TypeInferenceError]
            #   GraphProto attribute inferencing is not enabled
            #   in this InferenceContextImpl instance.
            # self.run_converter_test(case)
            self.run_eager_test(case)
if __name__ == "__main__":
    # Running this file directly executes the TestCase classes it defines;
    # module="__main__" is unittest.main's default, spelled out for clarity.
    unittest.main(module="__main__")
|