1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
|
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from onnx import TensorProto
from onnx.helper import make_tensor
from onnxscript import script
from onnxscript.onnx_opset import opset15 as op
from onnxscript.onnx_types import FLOAT, INT64
@script()
def maxsum(A: FLOAT["N"], B: FLOAT["N"]) -> FLOAT["N"]:
    # Return whichever of A and B has the larger element sum.
    # On a tie (sum1 == sum2) A is returned, because only a strict `<`
    # selects B.
    # ReduceSum is called without axes; per the ONNX spec that reduces over
    # every axis, so each sum is presumably a one-element tensor.
    sum1 = op.ReduceSum(A)
    sum2 = op.ReduceSum(B)
    # Both branches assign `result`, so it is defined after the if-statement.
    if sum1 < sum2:
        result = op.Identity(B)
    else:
        result = op.Identity(A)
    return result
# Test inference of the inputs/outputs of the then/else blocks when each branch uses an intermediate variable:
@script()
def maxsum2(A: FLOAT["N"], B: FLOAT["N"]) -> FLOAT["N"]:
    # Same result as `maxsum` (larger-sum input, A on a tie).
    # Each branch first copies its choice into the branch-local `temp` and
    # then into `result`.
    # Presumably this checks that the converter treats `temp` as internal to
    # each branch and exports only `result` from the if-statement.
    sum1 = op.ReduceSum(A)
    sum2 = op.ReduceSum(B)
    if sum1 < sum2:
        temp = op.Identity(B)
        result = op.Identity(temp)
    else:
        temp = op.Identity(A)
        result = op.Identity(temp)
    return result
# Test a variable that is reassigned in only one branch (the if-statement has no else branch).
@script()
def maxsum3(A: FLOAT["N"], B: FLOAT["N"]) -> FLOAT["N"]:
    # Same result as `maxsum`, written without an else branch.
    sum1 = op.ReduceSum(A)
    sum2 = op.ReduceSum(B)
    # `result` defaults to A before the if-statement.
    result = op.Identity(A)
    # `result` is overwritten only in the then-branch.
    # Presumably the converter must carry its prior value (A) through the
    # missing else branch.
    if sum1 < sum2:
        result = op.Identity(B)
    return result
@script()
def check_equal(x: FLOAT[None, None], axis: INT64[1]) -> FLOAT[None, None]:
    # Exercise `==` in an if-condition.
    # Transpose the 2-D input x when axis equals 0; otherwise return x
    # unchanged.
    # One-element INT64 constant holding 0, compared against `axis`.
    zero = op.Constant(value=make_tensor("zero", TensorProto.INT64, [1], [0]))
    if axis == zero:
        # perm=[1, 0] swaps the two dimensions of x.
        result = op.Transpose(x, perm=[1, 0])
    else:
        result = op.Identity(x)
    return result
@script()
def check_less_or_equal(x: FLOAT[None, None], axis: INT64[1]) -> FLOAT[None, None]:
    # Exercise `<=` in an if-condition.
    # Transpose the 2-D input x when axis <= 0; otherwise return x unchanged.
    # One-element INT64 constant holding 0, compared against `axis`.
    zero = op.Constant(value=make_tensor("zero", TensorProto.INT64, [1], [0]))
    if axis <= zero:
        # perm=[1, 0] swaps the two dimensions of x.
        result = op.Transpose(x, perm=[1, 0])
    else:
        result = op.Identity(x)
    return result
@script()
def check_greater(x: FLOAT[None, None], axis: INT64[1]) -> FLOAT[None, None]:
    # Exercise `>` in an if-condition.
    # Transpose the 2-D input x when axis > 0; otherwise return x unchanged.
    # One-element INT64 constant holding 0, compared against `axis`.
    zero = op.Constant(value=make_tensor("zero", TensorProto.INT64, [1], [0]))
    if axis > zero:
        # perm=[1, 0] swaps the two dimensions of x.
        result = op.Transpose(x, perm=[1, 0])
    else:
        result = op.Identity(x)
    return result
@script()
def check_greater_or_equal(x: FLOAT[None, None], axis: INT64[1]) -> FLOAT[None, None]:
    # Exercise `>=` in an if-condition.
    # Transpose the 2-D input x when axis >= 0; otherwise return x unchanged.
    # One-element INT64 constant holding 0, compared against `axis`.
    zero = op.Constant(value=make_tensor("zero", TensorProto.INT64, [1], [0]))
    if axis >= zero:
        # perm=[1, 0] swaps the two dimensions of x.
        result = op.Transpose(x, perm=[1, 0])
    else:
        result = op.Identity(x)
    return result
@script()
def check_not(x: FLOAT[None, None], axis: INT64[1]) -> FLOAT[None, None]:
    # Exercise `not` in an if-condition.
    # Transpose the 2-D input x when NOT (axis >= 0), i.e. when axis < 0;
    # otherwise return x unchanged.
    # One-element INT64 constant holding 0, compared against `axis`.
    zero = op.Constant(value=make_tensor("zero", TensorProto.INT64, [1], [0]))
    # Presumably written as `not (...)` rather than `axis < zero` so the
    # fixture covers the negation operator.
    if not (axis >= zero):
        # perm=[1, 0] swaps the two dimensions of x.
        result = op.Transpose(x, perm=[1, 0])
    else:
        result = op.Identity(x)
    return result
@script()
def check_different(x: FLOAT[None, None], axis: INT64[1]) -> FLOAT[None, None]:
    # Exercise `!=` in an if-condition.
    # Transpose the 2-D input x when axis is not 0; otherwise return x
    # unchanged.
    # One-element INT64 constant holding 0, compared against `axis`.
    zero = op.Constant(value=make_tensor("zero", TensorProto.INT64, [1], [0]))
    if axis != zero:
        # perm=[1, 0] swaps the two dimensions of x.
        result = op.Transpose(x, perm=[1, 0])
    else:
        result = op.Identity(x)
    return result
|