File: erfgelu.py

package info (click to toggle)
onnxscript 0.2.0%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 12,384 kB
  • sloc: python: 75,957; sh: 41; makefile: 6
file content (161 lines) | stat: -rw-r--r-- 4,554 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Onnx Pattern Rewriting.

This script shows how to define a rewriting rule based on patterns.

First a dummy model with a GELU activation
==========================================
"""

import math

import onnx

import onnxscript
from onnxscript import FLOAT, ir, opset18, script
from onnxscript.rewriter import pattern


# Example model: a GELU activation applied to X + Y, spelled out as the
# explicit erf-based formula 0.5 * (a * (erf(a / sqrt(2)) + 1)) so that the
# rewrite rules defined further down have something to match.
@script()
def original_model(X: FLOAT[64, 128], Y: FLOAT[64, 128]) -> FLOAT[64, 128]:
    # a = X + Y is the value the activation is applied to.
    input_add = opset18.Add(X, Y)
    # erf(a / sqrt(2))
    sqrt2 = opset18.Constant(value_float=math.sqrt(2))
    erf = opset18.Erf(input_add / sqrt2)
    # erf(...) + 1
    add_const = opset18.Constant(value_float=1.0)
    plus_one = erf + add_const
    # a * (erf(...) + 1)
    mul1 = input_add * plus_one
    # Final scaling with the constant on the LEFT: 0.5 * (...).
    mul_const = opset18.Constant(value_float=0.5)
    result = mul_const * mul1
    return result


# Build the model from the script function and validate it with the ONNX
# checker; `_model` is the input of the "Rewrite-Simple" step at the bottom of
# this file.
_model = original_model.to_model_proto()
onnx.checker.check_model(_model)


####################################
# Model demonstrating multiple patterns and variations of GELU activation
# =======================================================================


# Example model containing two erf-based GELU computations, one on X and one
# on Y, whose results are added. The two computations are identical except for
# the operand order of the final multiplication by 0.5.
@script()
def commute_model(X: FLOAT[64, 128], Y: FLOAT[64, 128]) -> FLOAT[64, 128]:
    # Create first GELU variant
    sqrt2_v1 = opset18.Constant(value_float=math.sqrt(2))
    erf_v1 = opset18.Erf(X / sqrt2_v1)
    add_const_v1 = opset18.Constant(value_float=1.0)
    plus_one_v1 = erf_v1 + add_const_v1
    mul1_v1 = X * plus_one_v1
    mul_const_v1 = opset18.Constant(value_float=0.5)
    # Constant on the left: 0.5 * (...), the form written in erf_gelu_pattern.
    gelu1 = mul_const_v1 * mul1_v1

    # Create second GELU variant
    sqrt2_v2 = opset18.Constant(value_float=math.sqrt(2))
    erf_v2 = opset18.Erf(Y / sqrt2_v2)
    add_const_v2 = opset18.Constant(value_float=1.0)
    plus_one_v2 = erf_v2 + add_const_v2
    mul1_v2 = Y * plus_one_v2
    mul_const_v2 = opset18.Constant(value_float=0.5)
    # Constant on the right: (...) * 0.5, the form written in
    # erf_gelu_pattern_2.
    gelu2 = mul1_v2 * mul_const_v2

    # Add both GELU functions
    result = opset18.Add(gelu1, gelu2)
    return result


# NOTE: this rebinds the name `commute_model` from the script function to the
# value returned by to_model_proto(); every later use of `commute_model` in
# this file refers to the built model, not to the function.
commute_model = commute_model.to_model_proto()
onnx.checker.check_model(commute_model)


####################################
# The target pattern
# =====================


def erf_gelu_pattern(op, x):
    """Target pattern: erf-based GELU with the 0.5 factor on the left.

    Builds ``0.5 * (x * (Erf(x / sqrt(2)) + 1.0))`` from ``x`` using
    ``op.Erf`` and the ordinary arithmetic operators.

    Args:
        op: Object providing the ``Erf`` operator.
        x: Input value of the activation.

    Returns:
        The result of the outermost multiplication, ``0.5 * (...)``.
    """
    erf_term = op.Erf(x / math.sqrt(2))
    inner = x * (erf_term + 1.0)
    return 0.5 * inner


def erf_gelu_pattern_2(op, x):
    """Target pattern: erf-based GELU with the 0.5 factor on the right.

    Builds ``(x * (Erf(x / sqrt(2)) + 1.0)) * 0.5`` from ``x`` using
    ``op.Erf`` and the ordinary arithmetic operators; it differs from
    ``erf_gelu_pattern`` only in the operand order of the last multiplication.

    Args:
        op: Object providing the ``Erf`` operator.
        x: Input value of the activation.

    Returns:
        The result of the outermost multiplication, ``(...) * 0.5``.
    """
    erf_term = op.Erf(x / math.sqrt(2))
    inner = x * (erf_term + 1.0)
    return inner * 0.5


####################################
# The replacement pattern
# =======================


def gelu(op, x: ir.Value):
    """Replacement pattern: a single ``Gelu`` op in the ``com.microsoft`` domain.

    Args:
        op: Object providing the ``Gelu`` operator.
        x: The value the matched GELU computation was applied to.

    Returns:
        Whatever ``op.Gelu`` returns for ``x``.
    """
    fused = op.Gelu(x, _domain="com.microsoft")
    return fused


####################################
# Create Rewrite Rule and Apply to Model
# ======================================


def apply_rewrite(model):
    """Rewrite ``model`` with a single rule: ``erf_gelu_pattern`` -> ``gelu``.

    Args:
        model: The model to rewrite.

    Returns:
        The value returned by ``onnxscript.rewriter.rewrite``.
    """
    # Target pattern first, replacement second.
    gelu_rule = pattern.RewriteRule(erf_gelu_pattern, gelu)
    return onnxscript.rewriter.rewrite(model, pattern_rewrite_rules=[gelu_rule])


def apply_rewrite_with_ruleset(model):
    """Rewrite ``model`` with one rule per GELU operand order.

    One rule is created for ``erf_gelu_pattern`` and one for
    ``erf_gelu_pattern_2`` (in that order), both replaced by ``gelu``; the two
    rules are bundled into a ``RewriteRuleSet``.

    Args:
        model: The model to rewrite.

    Returns:
        The value returned by ``onnxscript.rewriter.rewrite``.
    """
    targets = (erf_gelu_pattern, erf_gelu_pattern_2)
    # Each rule pairs a target pattern with the shared replacement.
    rules = [pattern.RewriteRule(target, gelu) for target in targets]
    rule_set = pattern.RewriteRuleSet(rules)
    # Alternative way of passing multiple rules: hand the plain list of rules
    # to `pattern_rewrite_rules` instead of a RewriteRuleSet.
    return onnxscript.rewriter.rewrite(model, pattern_rewrite_rules=rule_set)


def apply_rewrite_with_commute(model):
    """Rewrite ``model`` with one rule wrapped in a ``commute=True`` rule set.

    Only ``erf_gelu_pattern`` is registered; the rule set is created with
    ``commute=True``.
    NOTE(review): presumably this makes the rule set also try commuted operand
    orders of the pattern -- confirm against the onnxscript rewriter docs.

    Args:
        model: The model to rewrite.

    Returns:
        The value returned by ``onnxscript.rewriter.rewrite``.
    """
    gelu_rule = pattern.RewriteRule(erf_gelu_pattern, gelu)
    commuting_rules = pattern.RewriteRuleSet([gelu_rule], commute=True)
    return onnxscript.rewriter.rewrite(
        model, pattern_rewrite_rules=commuting_rules
    )


# Rewrite-Simple
# Apply the single-rule rewrite to the model built from original_model and
# validate the result with the ONNX checker.
model_with_rewrite = apply_rewrite(_model)
onnx.checker.check_model(model_with_rewrite)

# Rewrite-Single-Patterns
# Incorrect number of rewrites
# NOTE(review): the single rule only describes the `0.5 * (...)` operand
# order, while commute_model's second GELU is written as `(...) * 0.5`, so
# that variant is presumably left un-rewritten here -- confirm by inspecting
# the resulting model.
model_with_single_rewrite_ruleset = apply_rewrite(commute_model)
onnx.checker.check_model(model_with_single_rewrite_ruleset)

# Rewrite-Multiple-Patterns-RuleSet
# Same input model, now with one explicit rule per operand order.
model_with_rewrite_ruleset = apply_rewrite_with_ruleset(commute_model)
onnx.checker.check_model(model_with_rewrite_ruleset)

# Rewrite-Multiple-Patterns-Commute
# Same input model, now with a single rule in a commute=True rule set.
model_with_rewrite_commute = apply_rewrite_with_commute(commute_model)
onnx.checker.check_model(model_with_rewrite_commute)