Skip to content

Commit 7a2ac23

Browse files
authored
Merge branch 'main' into fix-2219-new
2 parents a5cc47e + 9dbf685 commit 7a2ac23

10 files changed

Lines changed: 632 additions & 13 deletions

File tree

.lintrunner.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ include_patterns = [
3939
exclude_patterns = [
4040
'tests/**', # Skip linting test files for speed
4141
# FIXME: Fix typing annotations in these files
42+
'examples/custom_op_expansion.py',
4243
'onnxscript/converter_test.py',
4344
'onnxscript/converter.py',
4445
'onnxscript/evaluator_test.py',

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.5.6
1+
0.5.7

examples/custom_op_expansion.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
# ruff: noqa
4+
5+
"""A utility and an example showing how onnxscript functions can be used to define function expansions
6+
and be used with the inliner to replace calls to the custom function with an expanded subgraph.
7+
This is useful to perform certain classes of graph surgery easily.
8+
"""
9+
10+
import onnx
11+
12+
import onnxscript
13+
import onnxscript.utils.replace as replace
14+
15+
script = onnxscript.script
16+
FLOAT = onnxscript.FLOAT
17+
op = onnxscript.values.opset22
18+
local = onnxscript.values.Opset("local", 1)
19+
20+
21+
# Example Model: Actual models can come from ModelBuilder or Exporter or any other source.
22+
# Models can contain calls to custom operations (from a custom domain like 'local' here or
23+
# even "com.microsoft" etc.)
24+
@script()
25+
def model_script(X: FLOAT["N"], Y: FLOAT["N"]) -> FLOAT["N"]:
26+
DoubleX = op.Add(X, X)
27+
YSquare = op.Mul(Y, Y)
28+
# Example call to a custom operation
29+
Temp1 = local.CustomOp1(DoubleX, YSquare)
30+
# Another call to a custom operation with an attribute
31+
Temp2 = local.CustomOp2(Temp1, alp=0.9)
32+
return Temp2
33+
34+
35+
# Define expansions for custom operations as onnxscript functions
36+
@script(opset=local)
37+
def CustomOp1(X: FLOAT["N"], Y: FLOAT["N"]) -> FLOAT["N"]:
38+
Temp1 = op.Sub(X, Y)
39+
return op.Div(Temp1, X)
40+
41+
42+
@script(opset=local)
43+
def CustomOp2(X: FLOAT["N"], alp: float) -> FLOAT["N"]:
44+
Temp2 = op.Elu(X, alpha=alp)
45+
return op.Mul(Temp2, Temp2)
46+
47+
48+
# Now, we can replace the custom operations in the model with their expansions:
49+
50+
functions = [CustomOp1.to_function_proto(), CustomOp2.to_function_proto()]
51+
52+
model = model_script.to_model_proto()
53+
54+
print("Original Model with custom operations:")
55+
print(onnx.printer.to_text(model))
56+
57+
58+
updated_model = replace.replace_functions(model, functions)
59+
60+
print("\nUpdated Model after replacing custom operations with their expansions:")
61+
print(onnx.printer.to_text(updated_model))

0 commit comments

Comments
 (0)