1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
|
# Owner(s): ["module: onnx"]
"""Test the support on onnxscript in PyTorch-ONNX converter."""
import io
from typing import List
import onnx
import onnxscript
from onnxscript.onnx_types import FLOAT
import torch
from torch.onnx._internal import jit_utils
from torch.testing._internal import common_utils
class TestONNXScriptExport(common_utils.TestCase):
    """Export models whose ops are overridden by onnxscript functions.

    Each test registers one or more onnxscript functions as custom symbolic
    functions, exports a model through ``torch.onnx.export`` and inspects the
    ``functions`` field of the resulting ONNX model proto.
    """

    # The opset version is pinned to 15 because:
    # 1. local functions are supported from opset 15 onwards
    # 2. onnx-script requires users to determine the opset in the local function
    opset_version = 15

    def _register_custom_symbolic(self, symbolic_name, symbolic_fn):
        """Register ``symbolic_fn`` for ``symbolic_name`` and undo it on teardown.

        The custom symbolic registry is process-wide, so every registration is
        paired with an ``unregister_custom_op_symbolic`` cleanup; otherwise an
        override installed by one test would still be active in later tests.

        Args:
            symbolic_name: Qualified operator name, e.g. ``"aten::selu"``.
            symbolic_fn: Symbolic function taking the graph context followed
                by the operator arguments.
        """
        torch.onnx.register_custom_op_symbolic(
            symbolic_name=symbolic_name,
            symbolic_fn=symbolic_fn,
            opset_version=self.opset_version,
        )
        self.addCleanup(
            torch.onnx.unregister_custom_op_symbolic,
            symbolic_name,
            self.opset_version,
        )

    def test_onnxscript_registration_with_multiple_models(self):
        """Export two models and check each carries only its own function."""
        from onnxscript.onnx_opset import opset15 as op

        # 1. Register Selu onnxscript function as custom Op
        custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)

        @onnxscript.script(custom_opset)
        def Selu(X):
            # default value is not supported by onnxscript
            alpha = 1.67326  # auto wrapped as Constants
            gamma = 1.0507
            alphaX = op.CastLike(alpha, X)
            gammaX = op.CastLike(gamma, X)
            neg = gammaX * (alphaX * op.Exp(X) - alphaX)
            pos = gammaX * X
            zero = op.CastLike(0, X)
            return op.Where(X <= zero, neg, pos)

        def custom_selu(g: jit_utils.GraphContext, X):
            # setType propagates the input type to the custom op output.
            return g.onnxscript_op(Selu, X).setType(X.type())

        self._register_custom_symbolic("aten::selu", custom_selu)

        # 2. Register layer_norm onnxscript function as custom Op
        @onnxscript.script(custom_opset)
        def layer_norm(
            X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float
        ):
            mean = op.ReduceMean(X, axes=axes)
            D = X - mean  # op.Sub(X, mean)
            DD = D * D  # op.Mul(D, D)
            var = op.ReduceMean(DD, axes=axes)
            vareps = var + eps  # op.Add(var, eps)
            stddev = op.Sqrt(vareps)
            invstddev = op.Reciprocal(stddev)
            normalized = D * invstddev  # op.Mul(D, invstddev)
            normalizedw = op.CastLike(
                normalized, weight
            )  # Type issue if missing this Op
            normalizedscaled = normalizedw * weight  # op.Mul(normalized, weight)
            return normalizedscaled + bias

        @torch.onnx.symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
        def custom_layer_norm(
            g, input, normalized_shape, weight, bias, eps, cudnn_enable
        ):
            # comprehension is not supported by onnxscript, so the axes are
            # computed here and handed to the function as an attribute.
            axes = [-i for i in range(len(normalized_shape), 0, -1)]
            return g.onnxscript_op(
                layer_norm, input, weight, bias, axes_i=axes, eps_f=eps
            ).setType(input.type())

        self._register_custom_symbolic("aten::layer_norm", custom_layer_norm)

        # 3. export two models
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        model_selu = torch.nn.SELU()
        selu_onnx = io.BytesIO()
        torch.onnx.export(model_selu, x, selu_onnx, opset_version=self.opset_version)

        N, C = 3, 4
        y = torch.randn(N, C)
        model_layer_norm = torch.nn.LayerNorm(C)
        layer_norm_onnx = io.BytesIO()
        torch.onnx.export(
            model_layer_norm, y, layer_norm_onnx, opset_version=self.opset_version
        )

        # 4. test on models: each proto holds exactly one local function, and
        # it is the one belonging to the exported model.
        selu_proto = onnx.load(io.BytesIO(selu_onnx.getvalue()))
        layer_norm_proto = onnx.load(io.BytesIO(layer_norm_onnx.getvalue()))

        self.assertEqual(len(selu_proto.functions), 1)
        self.assertEqual(len(layer_norm_proto.functions), 1)
        self.assertEqual(selu_proto.functions[0].name, "Selu")
        self.assertEqual(layer_norm_proto.functions[0].name, "layer_norm")

    def test_loop_registration(self):
        """Export a scripted model that calls the custom op inside loop/if."""

        # Control flow is tested for the _find_onnxscript_op function in
        # torch/onnx/utils.py, which has recursive logic to go through every
        # node with a subgraph in the model proto.
        class NestedLoopsModel(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.selu = torch.nn.SELU()

            @torch.jit.script_method
            def forward(self, x):
                y = x
                for i in range(x.size(3)):
                    if i == 0:
                        y = self.selu(x)
                    else:
                        y += i
                return y

        model = NestedLoopsModel()
        inputs = torch.zeros(1, 2, 3, 4)

        from onnxscript.onnx_opset import opset15 as op

        custom_opset = onnxscript.values.Opset(domain="onnx-script", version=2)

        @onnxscript.script(custom_opset)
        def Selu(X):
            alpha = 1.6732632423543772848170429916717
            gamma = 1.0507009873554804934193349852946
            alphaX = op.CastLike(alpha, X)
            gammaX = op.CastLike(gamma, X)
            neg = gammaX * (alphaX * op.Exp(X) - alphaX)
            pos = gammaX * X
            zero = op.CastLike(0, X)
            return op.Where(X <= zero, neg, pos)

        def custom_selu(g, X):
            # domain of the Op should be aligned with onnx-script
            # setType API is required for custom Op to support
            # torchscript shape type inference
            return g.onnxscript_op(Selu, X).setType(X.type())

        self._register_custom_symbolic("aten::selu", custom_selu)

        saved_model = io.BytesIO()
        torch.onnx.export(
            torch.jit.script(model),
            inputs,
            f=saved_model,
            opset_version=self.opset_version,
        )
        loop_selu_proto = onnx.load(io.BytesIO(saved_model.getvalue()))
        self.assertEqual(len(loop_selu_proto.functions), 1)
|