2020
2121
2222def call_onnx_api (
23- func : Callable [[onnx .ModelProto ], onnx .ModelProto ],
24- model : ir .Model ,
25- merge_func : Callable [[ir .Model , onnx .ModelProto ], tuple [ir .Model , bool ]],
26- ) -> tuple [ir .Model , bool ]:
23+ func : Callable [[onnx .ModelProto ], onnx .ModelProto ], model : ir .Model
24+ ) -> onnx .ModelProto :
2725 """Call an ONNX C API function by temporarily removing initializers.
26+
2827 This is necessary because the ONNX C API does not support large models
29- with initializers that have large tensor values.
28+ with initializers that have large tensor values. The input model is left
29+ unchanged no matter the call succeeds or not.
3030
3131 Args:
3232 func: Partially applied function that takes a model proto and returns a model proto.
3333 model: The IR model to pass to the API function.
34- merge_func: Function that merges IR model with information from the model proto.
3534
3635 Returns:
37- A tuple containing the modified model and a boolean indicating whether the model was modified .
36+ The resulting ModelProto that contains the result of the API call .
3837 """
3938
4039 # Store the original initializer values so they can be restored
4140 initializer_values = tuple (model .graph .initializers .values ())
4241 tensors = {v .name : v .const_value for v in initializer_values }
4342 original_inputs_len = len (model .graph .inputs )
44- initializer_names = {v .name for v in initializer_values }
4543
4644 # Turn the initializers into inputs and clear the initializers
4745 # to limit the model size
@@ -64,20 +62,15 @@ def call_onnx_api(
6462 try :
6563 proto = ir .serde .serialize_model (model )
6664 result_proto = func (proto )
67- except Exception : # pylint: disable=broad-exception-caught
68- logger .warning ("Call to %s failed. The model is not modified" , func , exc_info = True )
69- return (model , False )
7065 finally :
7166 # Restore the original initializer values so the model is unchanged
7267 for initializer in initializer_values :
73- if initializer .name in initializer_names :
74- initializer .const_value = tensors [initializer .name ]
75- model .graph .register_initializer (initializer )
68+ initializer .const_value = tensors [initializer .name ]
69+ model .graph .register_initializer (initializer )
7670
7771 # Restore the original inputs
7872 inputs = model .graph .inputs [:original_inputs_len ]
7973 model .graph .inputs .clear ()
8074 model .graph .inputs .extend (inputs )
8175
82- # Merge the result with the original model
83- return merge_func (model , result_proto )
76+ return result_proto
0 commit comments