# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Features included:
# - Overloaded operators such as <=, +, / (the expanded Selu sketch below
#   makes this translation explicit)
# - Nested expressions
# - Type annotations for attributes (to distinguish inputs from attributes)
# - Importing (predefined) opsets
#
# Features not yet covered:
# - Auto-cast for constants/attributes (explicit CastLike ops must be included)
# - Default values for attributes (these exist in the ONNX opschema
#   definitions, but must still be passed explicitly here)
# - Default values for inputs
# - Element-type annotations for tensors
# `script` compiles each decorated Python function below into an ONNX function;
# `op` gives access to the standard ONNX opset-18 operators.
from onnxscript import script
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import BOOL, FLOAT

@script()
def Relu(X):
    zero = op.CastLike(0, X)
    return op.Max(X, zero)

@script()
def Selu(X, alpha: float, gamma: float):
    alphaX = op.CastLike(alpha, X)
    gammaX = op.CastLike(gamma, X)
    neg = gammaX * (alphaX * op.Exp(X) - alphaX)
    pos = gammaX * X
    zero = op.CastLike(0, X)
    return op.Where(X <= zero, neg, pos)
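
# To make the operator overloading explicit, the following sketch spells out
# Selu using the underlying ONNX ops that the Python operators translate to;
# the name SeluExplicit is illustrative and not part of the original set of
# definitions.
@script()
def SeluExplicit(X, alpha: float, gamma: float):
    alphaX = op.CastLike(alpha, X)
    gammaX = op.CastLike(gamma, X)
    # gammaX * (alphaX * op.Exp(X) - alphaX) expands to Mul/Sub calls:
    neg = op.Mul(gammaX, op.Sub(op.Mul(alphaX, op.Exp(X)), alphaX))
    pos = op.Mul(gammaX, X)
    zero = op.CastLike(0, X)
    # X <= zero expands to LessOrEqual:
    return op.Where(op.LessOrEqual(X, zero), neg, pos)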

@script()
def Elu(X, alpha: float):
    alphaX = op.CastLike(alpha, X)
    zero = op.CastLike(0, X)
    one = op.CastLike(1, X)
    return op.Where(X < zero, alphaX * (op.Exp(X) - one), X)

@script()
def ThresholdedRelu(X, alpha: float):
    zero = op.CastLike(0, X)
    alphaX = op.CastLike(alpha, X)
    return op.Where(X > alphaX, X, zero)

@script()
def LeakyRelu(X, alpha: float):
    zero = op.CastLike(0, X)
    alphaX = op.CastLike(alpha, X)
    return op.Where(X < zero, alphaX * X, X)

@script()
def PRelu(X, slope):
    zero = op.CastLike(0, X)
    return op.Where(X < zero, slope * X, X)

@script()
def HardSigmoid(X, alpha: float, beta: float):
    zero = op.CastLike(0, X)
    one = op.CastLike(1, X)
    alphaX = op.CastLike(alpha, X)
    betaX = op.CastLike(beta, X)
    return op.Max(zero, op.Min(one, alphaX * X + betaX))

@script()
def Shrink(x, bias: float, lambd: float):
    zero = op.CastLike(0, x)
    return op.Where(x < -lambd, x + bias, op.Where(x > lambd, x - bias, zero))

@script()
def Softplus(X):
    one = op.CastLike(1, X)
    return op.Log(op.Exp(X) + one)

@script()
def Softsign(X):
    one = op.CastLike(1, X)
    return X / (one + op.Abs(X))

# Clip takes two optional inputs, min and max; OptionalHasElement checks
# whether each bound was actually provided before applying it.
@script()
def Clip(input: FLOAT[...], min: FLOAT = None, max: FLOAT = None) -> FLOAT[...]:
    result = input
    if op.OptionalHasElement(min):
        result = op.Where(result < min, min, result)
    if op.OptionalHasElement(max):
        result = op.Where(result > max, max, result)
    return result

# An alternative formulation of Clip, with min and max as required inputs:
# @script()
# def Clip(input, min, max):
#     return op.Where(input < min, min, op.Where(input > max, max, input))
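
# For intuition only: the optional-input Clip above clamps `input` to the
# interval [min, max], where either bound may be absent. A plain-numpy
# reference sketch (the helper name clip_reference is hypothetical):
#
# import numpy as np
# def clip_reference(x, lo=None, hi=None):
#     if lo is not None:
#         x = np.maximum(x, lo)  # replace values below lo with lo
#     if hi is not None:
#         x = np.minimum(x, hi)  # replace values above hi with hi
#     return x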

@script()
def OptionalHasElement(input: FLOAT[...]) -> BOOL:
    return op.OptionalHasElement(input)
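

if __name__ == "__main__":
    # A small demonstration sketch, assuming this onnxscript version supports
    # eager evaluation of @script() functions on numpy arrays (with attribute
    # values passed as keyword arguments) and OnnxFunction.to_model_proto().
    import numpy as np

    x = np.array([-2.0, -0.5, 0.0, 1.5], dtype=np.float32)
    print("Relu:", Relu(x))
    print("LeakyRelu:", LeakyRelu(x, alpha=0.1))

    # Each decorated function can also be exported as an ONNX model proto.
    model = Relu.to_model_proto()
    print("Relu exported as:", type(model).__name__)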