1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
|
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from onnxscript import script
from onnxscript.onnx_opset import opset15 as op
from onnxscript.onnx_types import FLOAT
@script()
def gemmgelu(
    A: FLOAT["M", "K"],  # noqa: F821
    W: FLOAT["K", "N"],  # noqa: F821
    Bias: FLOAT["N"],  # noqa: F821
) -> FLOAT["M", "N"]:  # noqa: F821
    """Compute GELU(A @ W + Bias) with the tanh approximation of GELU.

    GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))

    The tanh argument is evaluated in the factored form
    x * (sqrt(2/pi) + (0.044715 * sqrt(2/pi)) * x**2), so the cubic term
    costs only two multiplications. Both coefficients are rounded to six
    significant digits.
    """
    # Scalar constants of the tanh-based GELU approximation.
    half = op.Constant(value_float=0.5)
    sqrt_2_over_pi = op.Constant(value_float=0.797885)  # ~= sqrt(2 / pi)
    cubic_coeff = op.Constant(value_float=0.035677)  # ~= 0.044715 * sqrt(2 / pi)
    unit = op.Constant(value_float=1.0)

    # Affine part: (A @ W) + Bias, with Bias of shape [N] broadcast over rows.
    x = op.Add(op.MatMul(A, W), Bias)

    # inner = x * (sqrt(2/pi) + cubic_coeff * (x * x)); the grouping below is
    # kept exactly so the floating-point rounding matches the unfused form.
    inner = op.Mul(
        x,
        op.Add(sqrt_2_over_pi, op.Mul(cubic_coeff, op.Mul(x, x))),
    )

    # 0.5 * (x * (1 + tanh(inner)))
    return op.Mul(half, op.Mul(x, op.Add(unit, op.Tanh(inner))))
|