diff --git a/onnxscript/utils/replace.py b/onnxscript/utils/replace.py index d3af1a37a0..cab32226e7 100644 --- a/onnxscript/utils/replace.py +++ b/onnxscript/utils/replace.py @@ -4,14 +4,11 @@ from typing import Sequence -import onnx import onnx_ir as ir import onnx_ir.passes.common as common_passes -def replace_functions( - model: onnx.ModelProto, functions: Sequence[onnx.FunctionProto] -) -> onnx.ModelProto: +def replace_functions(irmodel: ir.Model, irfunctions: Sequence[ir.Function]) -> None: """A utility function to replace custom operations in a model with their expansions: Args: model: An ONNX ModelProto possibly containing calls to custom operations. @@ -20,8 +17,6 @@ def replace_functions( Returns: An updated ModelProto with custom operations replaced by their expansions. """ - irmodel = ir.from_proto(model) - irfunctions = [ir.from_proto(func) for func in functions] model_functions = irmodel.functions if len(model_functions) != 0: # Since we use inlining, check that there are no model-local functions. @@ -32,4 +27,3 @@ def replace_functions( # TODO (rama): Ideally, we should provide users more control over renaming strategy for inlined values. common_passes.InlinePass()(irmodel) common_passes.RemoveUnusedOpsetsPass()(irmodel) - return ir.to_proto(irmodel)