forked from microsoft/onnxscript
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreplace.py
More file actions
29 lines (23 loc) · 1.21 KB
/
replace.py
File metadata and controls
29 lines (23 loc) · 1.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""A utility function to replace custom operations in a model with their expansions"""
from typing import Sequence
import onnx_ir as ir
import onnx_ir.passes.common as common_passes
def replace_functions(irmodel: ir.Model, irfunctions: Sequence[ir.Function]) -> None:
    """Replace custom operations in a model with their expansions, in place.

    Every function in ``irfunctions`` is registered as a model-local function of
    ``irmodel`` under its ``identifier()``. The model is then run through
    onnx_ir's ``InlinePass`` so calls to those functions are expanded, followed
    by ``RemoveUnusedOpsetsPass``.

    Args:
        irmodel: An ``ir.Model`` possibly containing calls to custom operations.
            It must not already define any model-local functions. The model is
            modified in place; nothing is returned.
        irfunctions: A sequence of ``ir.Function`` defining the expansions for
            the custom operations. If two functions share the same
            ``identifier()``, the later one overwrites the earlier one.

    Raises:
        ValueError: If ``irmodel`` already has model-local functions.
    """
    model_functions = irmodel.functions
    if model_functions:
        # Since we use inlining, check that there are no model-local functions.
        raise ValueError("Input model cannot have model-local functions.")
    for func in irfunctions:
        model_functions[func.identifier()] = func
    # The PassResult objects returned by the passes are discarded: this assumes
    # both passes modify ``irmodel`` in place (consistent with the ``-> None``
    # contract above) -- NOTE(review): confirm against the onnx_ir pass docs.
    # TODO(rama): Ideally, we should provide users more control over renaming strategy for inlined values.
    common_passes.InlinePass()(irmodel)
    common_passes.RemoveUnusedOpsetsPass()(irmodel)