Skip to content

Commit a38ddf4

Browse files
committed
update typing
1 parent 6a3cb3c commit a38ddf4

2 files changed

Lines changed: 7 additions & 9 deletions

File tree

onnxscript/ir/passes/common/_c_api_utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from __future__ import annotations
66

77
import logging
8-
from typing import TYPE_CHECKING, Callable
8+
from typing import TYPE_CHECKING, Callable, TypeVar
99

1010
from onnxscript import ir
1111

@@ -17,19 +17,18 @@
1717
# Temporarily remove initializers larger than this size to keep model size down
1818
# for the onnx.shape_inference call because it needs to serialize the model
1919
_BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB
20+
_R = TypeVar("_R")
2021

2122

22-
def call_onnx_api(
23-
func: Callable[[onnx.ModelProto], onnx.ModelProto], model: ir.Model
24-
) -> onnx.ModelProto:
23+
def call_onnx_api(func: Callable[[onnx.ModelProto], _R], model: ir.Model) -> _R:
2524
"""Call an ONNX C API function by temporarily removing initializers.
2625
2726
This is necessary because the ONNX C API does not support large models
2827
with initializers that have large tensor values. The input model is left
2928
unchanged no matter the call succeeds or not.
3029
3130
Args:
32-
func: Partially applied function that takes a model proto and returns a model proto.
31+
func: Partially applied function that takes a model proto and returns anything.
3332
model: The IR model to pass to the API function.
3433
3534
Returns:
@@ -61,7 +60,7 @@ def call_onnx_api(
6160

6261
try:
6362
proto = ir.serde.serialize_model(model)
64-
result_proto = func(proto)
63+
result = func(proto)
6564
finally:
6665
# Restore the original initializer values so the model is unchanged
6766
for initializer in initializer_values:
@@ -73,4 +72,4 @@ def call_onnx_api(
7372
model.graph.inputs.clear()
7473
model.graph.inputs.extend(inputs)
7574

76-
return result_proto
75+
return result

onnxscript/ir/passes/common/onnx_checker.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,14 @@ def __init__(
3939
def call(self, model: ir.Model) -> ir.passes.PassResult:
4040
"""Run the onnx checker on the model."""
4141

42-
def _partial_check_model(proto: onnx.ModelProto) -> onnx.ModelProto:
42+
def _partial_check_model(proto: onnx.ModelProto) -> None:
4343
"""Partial function to check the model."""
4444
onnx.checker.check_model(
4545
proto,
4646
full_check=self.full_check,
4747
skip_opset_compatibility_check=self.skip_opset_compatibility_check,
4848
check_custom_domain=self.check_custom_domain,
4949
)
50-
return proto
5150

5251
_c_api_utils.call_onnx_api(func=_partial_check_model, model=model)
5352
# The model is not modified

0 commit comments

Comments
 (0)