1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
|
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Onnx Pattern Rewriting.
This script shows how to define a rewriting rule based on patterns.
The objective is to replace a sequence of nodes in an ONNX model with
another, more efficient sequence of nodes.
First a dummy model
===================
"""
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from onnxscript import ir
from onnxscript.rewriter import pattern
def get_rotary_model(bad_model=False):
    """Build the small ONNX model that the rest of this example rewrites.

    The graph chains ``Unsqueeze -> Cast -> MatMul -> Transpose``, feeds the
    transposed value twice into a ``com.microsoft`` concat node, and finishes
    with a ``Sin -> Cast`` branch and a ``Cos -> Cast`` branch that both read
    the first concat output.

    Args:
        bad_model: When true, the concat node is emitted with the op type
            ``ConcatTrainingBad`` instead of ``ConcatTraining``.

    Returns:
        The model returned by ``oh.make_model``.
    """
    # ``TensorProto.FLOAT`` is the integer 1, the target type of every Cast below.
    float_type = onnx.TensorProto.FLOAT

    if bad_model:
        concat_op_type = "ConcatTrainingBad"
    else:
        concat_op_type = "ConcatTraining"

    graph_inputs = [
        oh.make_tensor_value_info(name, elem_type, shape=[])
        for name, elem_type in (
            ("x", onnx.TensorProto.INT64),
            ("pos_ids", onnx.TensorProto.FLOAT),
            ("axis", onnx.TensorProto.INT64),
        )
    ]

    concat_node = oh.make_node(
        concat_op_type,
        ["_onx_transpose0", "_onx_transpose0"],
        ["_onx_concattraining0", "_onx_concattraining1"],
        domain="com.microsoft",
    )
    graph_nodes = [
        oh.make_node("Unsqueeze", ["x", "axis"], ["_onx_unsqueeze0"]),
        oh.make_node("Cast", ["_onx_unsqueeze0"], ["_onx_cast0"], to=float_type),
        oh.make_node("MatMul", ["pos_ids", "_onx_cast0"], ["_onx_matmul0"]),
        oh.make_node("Transpose", ["_onx_matmul0"], ["_onx_transpose0"]),
        concat_node,
        oh.make_node("Sin", ["_onx_concattraining0"], ["_onx_sin0"]),
        oh.make_node("Cast", ["_onx_sin0"], ["_onx_cast02"], to=float_type),
        oh.make_node("Cos", ["_onx_concattraining0"], ["_onx_cos0"]),
        oh.make_node("Cast", ["_onx_cos0"], ["_onx_cast03"], to=float_type),
    ]

    # Both graph outputs are declared with an undefined element type and an
    # empty shape list.
    graph_outputs = [
        oh.make_tensor_value_info(name, onnx.TensorProto.UNDEFINED, [])
        for name in ("_onx_cast02", "_onx_cast03")
    ]

    graph = oh.make_graph(graph_nodes, "experiment", graph_inputs, graph_outputs)
    opsets = [
        oh.make_opsetid("", 18),
        oh.make_opsetid("com.microsoft", 18),
    ]
    return oh.make_model(graph, opset_imports=opsets)
# Build the default model (concat node of type ``ConcatTraining``) and load it
# into the onnxscript IR; the rewrite rule below is applied to ``ir_model``.
model = get_rotary_model()
ir_model = ir.serde.deserialize_model(model)
####################################
# The rewriting pattern
# =====================
def rotary_match_pattern(op, x, pos_ids, axis):
    """Describe the subgraph the rule should look for.

    The calls below spell out the same op sequence that ``get_rotary_model``
    emits: ``Unsqueeze -> Cast -> MatMul -> Transpose``, a ``com.microsoft``
    ``ConcatTraining`` fed twice with the transposed value, then a
    ``Sin -> Cast`` and a ``Cos -> Cast`` branch.

    Args:
        op: Builder object handed in by the rewriter; every ``op.<Name>(...)``
            call declares one node of the pattern.
        x: Pattern input fed to ``Unsqueeze``.
        pos_ids: Pattern input used as the left operand of ``MatMul``.
        axis: Pattern input used as the axes operand of ``Unsqueeze``.

    Returns:
        The outputs of the two final ``Cast`` nodes (sin branch, cos branch).
    """
    float_type = onnx.TensorProto.FLOAT

    unsqueezed_as_float = op.Cast(op.Unsqueeze(x, axis), to=float_type)
    product = op.MatMul(pos_ids, unsqueezed_as_float)
    transposed = op.Transpose(product)

    # ConcatTraining has two outputs; only the first is consumed downstream.
    concatenated, _unused_length = op.ConcatTraining(
        transposed, transposed, domain="com.microsoft", outputs=2
    )

    sin_branch = op.Cast(op.Sin(concatenated), to=float_type)
    cos_branch = op.Cast(op.Cos(concatenated), to=float_type)
    return sin_branch, cos_branch
def rotary_apply_pattern(op, x, pos_ids, axis):
    """Describe the nodes that replace a matched subgraph.

    The replacement is a single ``com.microsoft`` ``RotaryEmbedding`` node with
    two outputs, fed with ``x``, ``pos_ids`` and two constant caches.

    Args:
        op: Builder object handed in by the rewriter; every ``op.<Name>(...)``
            call declares one node of the replacement.
        x: Value bound to the pattern input ``x``.
        pos_ids: Value bound to the pattern input ``pos_ids``.
        axis: Value bound to the pattern input ``axis``; not used here.

    Returns:
        The two outputs of the ``RotaryEmbedding`` node.
    """

    def random_cache():
        # The caches hold random float16 data of shape (256, 256).
        values = np.random.rand(256, 256).astype(np.float16)
        return op.Constant(value=onh.from_array(values))

    # Same creation order as the inputs of RotaryEmbedding: cos first, then sin.
    cos_cache = random_cache()
    sin_cache = random_cache()

    first, second = op.RotaryEmbedding(
        x, pos_ids, cos_cache, sin_cache, domain="com.microsoft", outputs=2
    )
    return first, second
###########################
# The rule
# ========
#
# The rule is easy to create.
# The verbosity is kept at 0 here; it is only raised to 10 further down, when a
# failed match has to be investigated.
rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=0)
##########################
# Let's apply it.
rule.apply_to_model(ir_model)
########################
# And finally, we can generate the model.
rewritten_model = ir.serde.serialize_model(ir_model)
########################
# Let's see what it looks like.
for node in rewritten_model.graph.node:
    print(f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}")
#############################
# What if it fails?
# =================
#
# The same rule is applied to a model whose concat node has the op type
# ``ConcatTrainingBad``, which the pattern does not contain.
model = get_rotary_model(bad_model=True)
ir_model = ir.serde.deserialize_model(model)
rule.apply_to_model(ir_model)
rewritten_model = ir.serde.serialize_model(ir_model)
print([n.op_type for n in rewritten_model.graph.node])
################################
# The match did not happen.
# Let's increase the verbosity.
rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=10)
rule.apply_to_model(ir_model)
# TODO(rama): Update the following, the trace-printed looks different now.
######################################
# The logs show every time the algorithm rejected a pattern.
# We can see the following:
#
# ::
#
# [OnnxGenericPattern.match] NONE - line: 673:onnxscript.rewriter.generic_pattern, op_type=Cast
# --hint--: BACKWARD: different node types
# --pattern
# ConcatTraining(transpose, transpose) -> (output, length)
# -- model
# ConcatTrainingBad(_onx_transpose0, _onx_transpose0) -> (_onx_concattraining0, _onx_concattraining1)
# iteration=1
# --marked-- #2
# Cast(_onx_cos0) ~ Cast(cos) [140186194226496-140186194222320]
# Cos(_onx_concattraining0) ~ Cos(output) [140186194230816-140186194223472]
# len(stacked)=0:[]
#
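# The match was rejected at line 673 of `generic_pattern.py`.
# The hint says that, while comparing two nodes in the backward direction,
# the node types did not match: the pattern expects ``ConcatTraining`` but
# the model contains ``ConcatTrainingBad``.
# The trace also lists, under ``--marked--``, the two node pairs
# (``Cast`` and ``Cos``) that had actually been matched.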
|