# pattern_rewriting.py -- forked from microsoft/onnxscript
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Onnx Pattern Rewriting.
This script shows how to define a rewriting rule based on patterns.
The objective is to replace some nodes in an ONNX model with another,
more efficient sequence of nodes.
First a dummy model
===================
"""
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from onnxscript import ir
from onnxscript.rewriter import pattern
def get_rotary_model(bad_model=False):
    """Build the small ONNX model used throughout this example.

    The graph chains ``Unsqueeze -> Cast -> MatMul -> Transpose``, feeds the
    transposed value twice into a two-output ``com.microsoft`` concat node,
    and ends with a ``Sin -> Cast`` branch and a ``Cos -> Cast`` branch, both
    reading the first concat output.

    Args:
        bad_model: When true, the concat node is emitted with op type
            ``ConcatTrainingBad`` instead of ``ConcatTraining``.

    Returns:
        The model returned by ``oh.make_model`` for a graph named
        ``"experiment"`` with opset 18 declared for both the default domain
        and ``com.microsoft``.
    """
    # TensorProto.FLOAT is the enum value 1 used for every Cast target below.
    float_type = onnx.TensorProto.FLOAT

    concat_op = "ConcatTraining"
    if bad_model:
        concat_op = "ConcatTrainingBad"

    # All graph inputs and outputs are declared with an empty shape.
    graph_inputs = [
        oh.make_tensor_value_info(name, elem_type, shape=[])
        for name, elem_type in (
            ("x", onnx.TensorProto.INT64),
            ("pos_ids", float_type),
            ("axis", onnx.TensorProto.INT64),
        )
    ]
    graph_outputs = [
        oh.make_tensor_value_info(name, onnx.TensorProto.UNDEFINED, [])
        for name in ("_onx_cast02", "_onx_cast03")
    ]

    graph_nodes = [
        oh.make_node("Unsqueeze", ["x", "axis"], ["_onx_unsqueeze0"]),
        oh.make_node("Cast", ["_onx_unsqueeze0"], ["_onx_cast0"], to=float_type),
        oh.make_node("MatMul", ["pos_ids", "_onx_cast0"], ["_onx_matmul0"]),
        oh.make_node("Transpose", ["_onx_matmul0"], ["_onx_transpose0"]),
        oh.make_node(
            concat_op,
            ["_onx_transpose0", "_onx_transpose0"],
            ["_onx_concattraining0", "_onx_concattraining1"],
            domain="com.microsoft",
        ),
        oh.make_node("Sin", ["_onx_concattraining0"], ["_onx_sin0"]),
        oh.make_node("Cast", ["_onx_sin0"], ["_onx_cast02"], to=float_type),
        oh.make_node("Cos", ["_onx_concattraining0"], ["_onx_cos0"]),
        oh.make_node("Cast", ["_onx_cos0"], ["_onx_cast03"], to=float_type),
    ]

    graph = oh.make_graph(graph_nodes, "experiment", graph_inputs, graph_outputs)
    opsets = [
        oh.make_opsetid("", 18),
        oh.make_opsetid("com.microsoft", 18),
    ]
    return oh.make_model(graph, opset_imports=opsets)
# Build the default model (``bad_model`` left at False) and hand the proto
# to ``ir.serde.deserialize_model``; the result is what the rule is applied to.
model = get_rotary_model()
ir_model = ir.serde.deserialize_model(model)
####################################
# The rewriting pattern
# =====================
def rotary_match_pattern(op, x, pos_ids, axis):
    """Describe the node sequence that the rewrite rule should look for.

    The ops are issued in the same order as the nodes created by
    ``get_rotary_model``: Unsqueeze, Cast, MatMul, Transpose, a two-output
    ``com.microsoft`` ConcatTraining fed twice with the transposed value,
    then Sin/Cast and Cos/Cast on the first concat output.

    Args:
        op: Object whose attributes (``op.Unsqueeze``, ``op.Cast``, ...) are
            called to create each node of the pattern.
        x: Value passed to Unsqueeze together with ``axis``.
        pos_ids: Left operand of the MatMul.
        axis: Second input of the Unsqueeze.

    Returns:
        A pair ``(cast_of_sin, cast_of_cos)``.
    """
    float_type = onnx.TensorProto.FLOAT

    expanded = op.Cast(op.Unsqueeze(x, axis), to=float_type)
    transposed = op.Transpose(op.MatMul(pos_ids, expanded))

    # Only the first of the two concat outputs is consumed by the pattern.
    concat, _second = op.ConcatTraining(
        transposed, transposed, domain="com.microsoft", outputs=2
    )

    sin_branch = op.Cast(op.Sin(concat), to=float_type)
    cos_branch = op.Cast(op.Cos(concat), to=float_type)
    return sin_branch, cos_branch
def rotary_apply_pattern(op, x, pos_ids, axis):
    """Build the nodes that replace a matched sub-graph.

    Two constants are created from random float16 arrays of shape
    ``(256, 256)`` and passed, together with ``x`` and ``pos_ids``, to a
    two-output ``com.microsoft`` RotaryEmbedding node.

    Args:
        op: Object whose attributes (``op.Constant``, ``op.RotaryEmbedding``)
            are called to create the replacement nodes.
        x: First input of RotaryEmbedding.
        pos_ids: Second input of RotaryEmbedding.
        axis: Accepted so the signature matches the match function; it is
            not used in the replacement.

    Returns:
        A pair holding the two RotaryEmbedding outputs.
    """
    cache_shape = (256, 256)

    # The cos cache is drawn before the sin cache, as both consume the
    # global NumPy random state.
    cos_values = np.random.rand(*cache_shape).astype(np.float16)
    cos_cache = op.Constant(value=onh.from_array(cos_values))

    sin_values = np.random.rand(*cache_shape).astype(np.float16)
    sin_cache = op.Constant(value=onh.from_array(sin_values))

    first, second = op.RotaryEmbedding(
        x, pos_ids, cos_cache, sin_cache, domain="com.microsoft", outputs=2
    )
    return first, second
###########################
# The rule
# ========
#
# The rule is easy to create.
# Pair the match function with its replacement function; ``verbose=10`` is
# forwarded to ``RewriteRule`` unchanged.
rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=10)
##########################
# Let's apply it.
rule.apply_to_model(ir_model)
########################
# And finally, we can generate the model.
rewritten_model = ir.serde.serialize_model(ir_model)
########################
# Let's see what it looks like.
# Print one line per node in the form ``op_type(inputs) -> outputs``.
for node in rewritten_model.graph.node:
    print(f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}")
#############################
# What if it fails?
# =================
# Rebuild the model with ``bad_model=True`` so the concat node is emitted as
# ``ConcatTrainingBad``, then run the same rule on it.
model = get_rotary_model(True)
ir_model = ir.serde.deserialize_model(model)
rule.apply_to_model(ir_model)
rewritten_model = ir.serde.serialize_model(ir_model)
# List the op types present in the graph after the rewrite attempt.
print([n.op_type for n in rewritten_model.graph.node])
################################
# The match did not happen.
# Let's rebuild the rule with verbose output and apply it again.
# NOTE(review): the first rule in this script is also created with
# ``verbose=10``, so the verbosity is not actually increased here --
# confirm the intended values.
rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=10)
rule.apply_to_model(ir_model)