55from __future__ import annotations
66
77import logging
8- from typing import TYPE_CHECKING , Callable
8+ from typing import TYPE_CHECKING , Callable , TypeVar
99
1010from onnxscript import ir
1111
1717# Temporarily remove initializers larger than this size to keep model size down
1818# for the onnx.shape_inference call because it needs to serialize the model
1919_BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB
20+ _R = TypeVar ("_R" )
2021
2122
22- def call_onnx_api (
23- func : Callable [[onnx .ModelProto ], onnx .ModelProto ], model : ir .Model
24- ) -> onnx .ModelProto :
23+ def call_onnx_api (func : Callable [[onnx .ModelProto ], _R ], model : ir .Model ) -> _R :
2524 """Call an ONNX C API function by temporarily removing initializers.
2625
2726 This is necessary because the ONNX C API does not support large models
2827 with initializers that have large tensor values. The input model is left
2928 unchanged no matter the call succeeds or not.
3029
3130 Args:
32- func: Partially applied function that takes a model proto and returns a model proto .
31+ func: Partially applied function that takes a model proto and returns anything .
3332 model: The IR model to pass to the API function.
3433
3534 Returns:
@@ -61,7 +60,7 @@ def call_onnx_api(
6160
6261 try :
6362 proto = ir .serde .serialize_model (model )
64- result_proto = func (proto )
63+ result = func (proto )
6564 finally :
6665 # Restore the original initializer values so the model is unchanged
6766 for initializer in initializer_values :
@@ -73,4 +72,4 @@ def call_onnx_api(
7372 model .graph .inputs .clear ()
7473 model .graph .inputs .extend (inputs )
7574
76- return result_proto
75+ return result
0 commit comments