1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
|
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Test cases for automatic introduction of CastLike around constants:
from onnxscript import script
from onnxscript.onnx_opset import opset15 as op
from onnxscript.onnx_types import BOOL, FLOAT
# Test case: integer literal as the RIGHT operand of `+` on a FLOAT tensor.
# Per the file header, this exercises automatic introduction of CastLike
# around constants; `inc_right_expanded` is the hand-written counterpart.
#
# Args:    A - FLOAT tensor.
# Returns: A + 1, annotated as a FLOAT tensor.
@script(default_opset=op)
def inc_right(A: FLOAT[...]) -> FLOAT[...]:
    # The bare literal 1 is the constant under test; no explicit cast here.
    return A + 1
# Explicit counterpart of `inc_right`: the constant 1 is wrapped in CastLike
# by hand so it takes the element type of A before the addition.
# NOTE(review): presumably this is the form the converter is expected to
# produce for `inc_right` - confirm against the test that consumes this file.
# Unlike `inc_right`, no default_opset is passed to @script() here.
#
# Args:    A - FLOAT tensor.
# Returns: A + CastLike(1, A), annotated as a FLOAT tensor.
@script()
def inc_right_expanded(A: FLOAT[...]) -> FLOAT[...]:
    return A + op.CastLike(1, A)
# Test case: integer literal as the LEFT operand of `+` on a FLOAT tensor.
# Mirror image of `inc_right`; `inc_left_expanded` is the hand-written
# counterpart with the CastLike spelled out.
#
# Args:    A - FLOAT tensor.
# Returns: 1 + A, annotated as a FLOAT tensor.
@script(default_opset=op)
def inc_left(A: FLOAT[...]) -> FLOAT[...]:
    # The bare literal 1 is the constant under test; no explicit cast here.
    return 1 + A
# Explicit counterpart of `inc_left`: the constant 1 on the left of `+` is
# wrapped in CastLike by hand so it takes the element type of A.
# NOTE(review): presumably the expected expansion of `inc_left` - confirm
# against the test that consumes this file.
# Unlike `inc_left`, no default_opset is passed to @script() here.
#
# Args:    A - FLOAT tensor.
# Returns: CastLike(1, A) + A, annotated as a FLOAT tensor.
@script()
def inc_left_expanded(A: FLOAT[...]) -> FLOAT[...]:
    return op.CastLike(1, A) + A
# Test case: integer literal 0 as the RIGHT operand of `==` against a FLOAT
# tensor. `cmp_zero_right_expanded` is the hand-written counterpart with the
# CastLike spelled out.
#
# Args:    A - FLOAT tensor.
# Returns: A == 0, annotated as a BOOL tensor.
@script(default_opset=op)
def cmp_zero_right(A: FLOAT[...]) -> BOOL[...]:
    # The bare literal 0 is the constant under test; no explicit cast here.
    return A == 0
# Explicit counterpart of `cmp_zero_right`: the constant 0 is wrapped in
# CastLike by hand so it takes the element type of A before the comparison.
# NOTE(review): presumably the expected expansion of `cmp_zero_right` -
# confirm against the test that consumes this file.
# Unlike `cmp_zero_right`, no default_opset is passed to @script() here.
#
# Args:    A - FLOAT tensor.
# Returns: A == CastLike(0, A), annotated as a BOOL tensor.
@script()
def cmp_zero_right_expanded(A: FLOAT[...]) -> BOOL[...]:
    return A == op.CastLike(0, A)
# Test case: NEGATIVE integer literal as the right operand of `==` against a
# FLOAT tensor. Note that the constant is -11, not zero, despite "zero" in
# the function name. `cmp_zero_mright_expanded` is the hand-written
# counterpart with the CastLike spelled out.
#
# Args:    A - FLOAT tensor.
# Returns: A == -11, annotated as a BOOL tensor.
@script(default_opset=op)
def cmp_zero_mright(A: FLOAT[...]) -> BOOL[...]:
    # Unary minus applied to a literal is the constant form under test.
    return A == -11
# Explicit counterpart of `cmp_zero_mright`: the negative constant -11 is
# passed as a whole to CastLike so it takes the element type of A before the
# comparison.
# NOTE(review): presumably the expected expansion of `cmp_zero_mright` -
# confirm against the test that consumes this file.
# Unlike `cmp_zero_mright`, no default_opset is passed to @script() here.
#
# Args:    A - FLOAT tensor.
# Returns: A == CastLike(-11, A), annotated as a BOOL tensor.
@script()
def cmp_zero_mright_expanded(A: FLOAT[...]) -> BOOL[...]:
    return A == op.CastLike(-11, A)
# Test case: integer literal 0 as the LEFT operand of `==` against a FLOAT
# tensor. Mirror image of `cmp_zero_right`; `cmp_zero_left_expanded` is the
# hand-written counterpart with the CastLike spelled out.
#
# Args:    A - FLOAT tensor.
# Returns: 0 == A, annotated as a BOOL tensor.
@script(default_opset=op)
def cmp_zero_left(A: FLOAT[...]) -> BOOL[...]:
    # Constant-first operand order is deliberate: it is the case under test.
    return 0 == A
# Explicit counterpart of `cmp_zero_left`: the constant 0 on the left of
# `==` is wrapped in CastLike by hand so it takes the element type of A.
# NOTE(review): presumably the expected expansion of `cmp_zero_left` -
# confirm against the test that consumes this file.
# Unlike `cmp_zero_left`, no default_opset is passed to @script() here.
#
# Args:    A - FLOAT tensor.
# Returns: CastLike(0, A) == A, annotated as a BOOL tensor.
@script()
def cmp_zero_left_expanded(A: FLOAT[...]) -> BOOL[...]:
    return op.CastLike(0, A) == A
# Test case: integer literal as the RIGHT operand (divisor) of `/` on a
# FLOAT tensor. `div_right_expanded` is the hand-written counterpart with
# the CastLike spelled out.
#
# Args:    A - FLOAT tensor.
# Returns: A / 2, annotated as a FLOAT tensor.
@script(default_opset=op)
def div_right(A: FLOAT[...]) -> FLOAT[...]:
    # The bare literal 2 is the constant under test; no explicit cast here.
    return A / 2
# Explicit counterpart of `div_right`: the divisor 2 is wrapped in CastLike
# by hand so it takes the element type of A before the division.
# NOTE(review): presumably the expected expansion of `div_right` - confirm
# against the test that consumes this file.
# Unlike `div_right`, no default_opset is passed to @script() here.
#
# Args:    A - FLOAT tensor.
# Returns: A / CastLike(2, A), annotated as a FLOAT tensor.
@script()
def div_right_expanded(A: FLOAT[...]) -> FLOAT[...]:
    return A / op.CastLike(2, A)
# Test case: parenthesised NEGATIVE integer literal as the divisor of `/` on
# a FLOAT tensor. `div_minus_right_expanded` is the hand-written counterpart
# with the CastLike spelled out.
#
# Args:    A - FLOAT tensor.
# Returns: A / (-2), annotated as a FLOAT tensor.
@script(default_opset=op)
def div_minus_right(A: FLOAT[...]) -> FLOAT[...]:
    # The parentheses around -2 are part of the case as written; keep them.
    return A / (-2)
# Explicit counterpart of `div_minus_right`: the negative divisor -2 is
# passed as a whole to CastLike so it takes the element type of A before the
# division.
# NOTE(review): presumably the expected expansion of `div_minus_right` -
# confirm against the test that consumes this file.
# Unlike `div_minus_right`, no default_opset is passed to @script() here.
#
# Args:    A - FLOAT tensor.
# Returns: A / CastLike(-2, A), annotated as a FLOAT tensor.
@script()
def div_minus_right_expanded(A: FLOAT[...]) -> FLOAT[...]:
    return A / op.CastLike(-2, A)
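# Disabled test case: doubly-negated constant as the divisor of `/`.
# NOTE(review): the reason this case is commented out is not recorded here.
# @script()
# def div_minus_minus_right(A: FLOAT[...]) -> FLOAT[...]:
#     return A / (-(-2))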
# Test case: integer literal passed directly as an operator ARGUMENT rather
# than as an operand of a Python infix operator. Here the constant 2 is the
# second argument of op.Where (the first of its two value inputs), alongside
# the FLOAT tensor A. `where_left_expanded` is the hand-written counterpart.
#
# Args:    C - BOOL tensor used as the Where condition.
#          A - FLOAT tensor.
# Returns: op.Where(C, 2, A), annotated as a FLOAT tensor.
@script()
def where_left(C: BOOL[...], A: FLOAT[...]) -> FLOAT[...]:
    # The bare literal 2 is the constant under test; no explicit cast here.
    return op.Where(C, 2, A)
# Explicit counterpart of `where_left`: the constant 2 in the second
# argument of op.Where is wrapped in CastLike by hand so it takes the
# element type of A.
# NOTE(review): presumably the expected expansion of `where_left` - confirm
# against the test that consumes this file.
#
# Args:    C - BOOL tensor used as the Where condition.
#          A - FLOAT tensor.
# Returns: op.Where(C, CastLike(2, A), A), annotated as a FLOAT tensor.
@script()
def where_left_expanded(C: BOOL[...], A: FLOAT[...]) -> FLOAT[...]:
    return op.Where(C, op.CastLike(2, A), A)
# Test case: integer literal passed directly as an operator ARGUMENT rather
# than as an operand of a Python infix operator. Here the constant 3 is the
# third argument of op.Where (the second of its two value inputs), after the
# FLOAT tensor A. `where_right_expanded` is the hand-written counterpart.
#
# Args:    C - BOOL tensor used as the Where condition.
#          A - FLOAT tensor.
# Returns: op.Where(C, A, 3), annotated as a FLOAT tensor.
@script()
def where_right(C: BOOL[...], A: FLOAT[...]) -> FLOAT[...]:
    # The bare literal 3 is the constant under test; no explicit cast here.
    return op.Where(C, A, 3)
# Explicit counterpart of `where_right`: the constant 3 in the third
# argument of op.Where is wrapped in CastLike by hand so it takes the
# element type of A.
# NOTE(review): presumably the expected expansion of `where_right` - confirm
# against the test that consumes this file.
#
# Args:    C - BOOL tensor used as the Where condition.
#          A - FLOAT tensor.
# Returns: op.Where(C, A, CastLike(3, A)), annotated as a FLOAT tensor.
@script()
def where_right_expanded(C: BOOL[...], A: FLOAT[...]) -> FLOAT[...]:
    return op.Where(C, A, op.CastLike(3, A))
|