File: cast_like.py

package info (click to toggle)
onnxscript 0.2.0+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 12,384 kB
  • sloc: python: 75,957; sh: 41; makefile: 6
file content (103 lines) | stat: -rw-r--r-- 2,175 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

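# Test cases for the automatic insertion of CastLike around numeric constants.
# Each function `f` mixes a bare Python int literal with a FLOAT tensor and
# relies on the converter to cast the literal; its companion `f_expanded`
# writes the equivalent op.CastLike call out explicitly, i.e. the expected
# translation of `f`.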

from onnxscript import script
from onnxscript.onnx_opset import opset15 as op
from onnxscript.onnx_types import BOOL, FLOAT


# Case: int constant on the RIGHT of `+`.
# The literal 1 is written without any cast; `inc_right_expanded` shows the
# explicit `A + op.CastLike(1, A)` form this is expected to match.
# Keep the body exactly as written: it is the converter input under test.
@script(default_opset=op)
def inc_right(A: FLOAT[...]) -> FLOAT[...]:
    return A + 1


# Expected translation of `inc_right`: the constant 1 is cast to A's element
# type via op.CastLike before the addition.
# NOTE(review): no default_opset is passed here, yet `+` needs one; presumably
# the converter infers it from the `op.` (opset15) reference in the body - confirm.
@script()
def inc_right_expanded(A: FLOAT[...]) -> FLOAT[...]:
    return A + op.CastLike(1, A)


# Case: int constant on the LEFT of `+` (mirror of `inc_right`).
# The cast should be inserted even though the tensor operand comes second;
# see `inc_left_expanded` for the explicit form.
@script(default_opset=op)
def inc_left(A: FLOAT[...]) -> FLOAT[...]:
    return 1 + A


# Expected translation of `inc_left`: the constant is cast like A and stays
# on the left-hand side, so operand order is preserved.
@script()
def inc_left_expanded(A: FLOAT[...]) -> FLOAT[...]:
    return op.CastLike(1, A) + A


# Case: comparison `==` with an int constant on the right.
# The result is a BOOL tensor, so the cast must follow the operand A rather
# than the function's output type; see `cmp_zero_right_expanded`.
@script(default_opset=op)
def cmp_zero_right(A: FLOAT[...]) -> BOOL[...]:
    return A == 0


# Expected translation of `cmp_zero_right`: 0 is cast to A's (FLOAT) type
# before the comparison; the comparison itself yields BOOL.
@script()
def cmp_zero_right_expanded(A: FLOAT[...]) -> BOOL[...]:
    return A == op.CastLike(0, A)


# Case: comparison with a NEGATIVE int constant on the right.
# Despite "zero" in the name, the constant is -11; the point is that a
# unary-minus literal must also be wrapped as a single constant.
# See `cmp_zero_mright_expanded`.
@script(default_opset=op)
def cmp_zero_mright(A: FLOAT[...]) -> BOOL[...]:
    return A == -11


# Expected translation of `cmp_zero_mright`: the whole negative literal -11
# is the CastLike input (not a cast of 11 followed by a negation).
@script()
def cmp_zero_mright_expanded(A: FLOAT[...]) -> BOOL[...]:
    return A == op.CastLike(-11, A)


# Case: comparison with the int constant on the LEFT (mirror of
# `cmp_zero_right`); see `cmp_zero_left_expanded`.
@script(default_opset=op)
def cmp_zero_left(A: FLOAT[...]) -> BOOL[...]:
    return 0 == A


# Expected translation of `cmp_zero_left`: the constant keeps its left-hand
# position and is cast like A.
@script()
def cmp_zero_left_expanded(A: FLOAT[...]) -> BOOL[...]:
    return op.CastLike(0, A) == A


# Case: division by an int constant on the right; see `div_right_expanded`.
@script(default_opset=op)
def div_right(A: FLOAT[...]) -> FLOAT[...]:
    return A / 2


# Expected translation of `div_right`: the divisor 2 is cast to A's type.
@script()
def div_right_expanded(A: FLOAT[...]) -> FLOAT[...]:
    return A / op.CastLike(2, A)


# Case: division by a parenthesized negative constant, (-2).
# Checks that the parentheses/unary minus are treated as one constant;
# see `div_minus_right_expanded`.
@script(default_opset=op)
def div_minus_right(A: FLOAT[...]) -> FLOAT[...]:
    return A / (-2)


# Expected translation of `div_minus_right`: -2 as a whole is the CastLike
# input.
@script()
def div_minus_right_expanded(A: FLOAT[...]) -> FLOAT[...]:
    return A / op.CastLike(-2, A)


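# Disabled case: a doubly negated constant, -(-2). The reason it is disabled
# is not recorded here - presumably nested unary minus on a literal is not
# (yet) folded into a single constant by the converter; confirm before
# re-enabling.
# @script()
# def div_minus_minus_right(A: FLOAT[...]) -> FLOAT[...]:
#     return A / (-(-2))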


# Case: int constant passed directly to an explicit op call.
# The constant 2 is the second argument of op.Where (the value taken where C is
# true, per the ONNX Where spec); it should be cast like the other value input
# A, not like the BOOL condition C. See `where_left_expanded`.
@script()
def where_left(C: BOOL[...], A: FLOAT[...]) -> FLOAT[...]:
    return op.Where(C, 2, A)


# Expected translation of `where_left`: 2 is cast like A (the FLOAT value
# input), while the condition C is passed through unchanged.
@script()
def where_left_expanded(C: BOOL[...], A: FLOAT[...]) -> FLOAT[...]:
    return op.Where(C, op.CastLike(2, A), A)


# Case: int constant as the third argument of op.Where (mirror of
# `where_left`); the cast target should again be A, not C.
# See `where_right_expanded`.
@script()
def where_right(C: BOOL[...], A: FLOAT[...]) -> FLOAT[...]:
    return op.Where(C, A, 3)


# Expected translation of `where_right`: 3 is cast like A in the third
# argument position.
@script()
def where_right_expanded(C: BOOL[...], A: FLOAT[...]) -> FLOAT[...]:
    return op.Where(C, A, op.CastLike(3, A))