Skip to content

Commit f1cfcef

Browse files
committed
Update
1 parent 208ed94 commit f1cfcef

File tree

3 files changed

+24
-34
lines changed

3 files changed

+24
-34
lines changed

onnxscript/ir/passes/common/_c_api_utils.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,26 @@
2020

2121

2222
def call_onnx_api(
23-
func: Callable[[onnx.ModelProto], onnx.ModelProto],
24-
model: ir.Model,
25-
merge_func: Callable[[ir.Model, onnx.ModelProto], tuple[ir.Model, bool]],
26-
) -> tuple[ir.Model, bool]:
23+
func: Callable[[onnx.ModelProto], onnx.ModelProto], model: ir.Model
24+
) -> onnx.ModelProto:
2725
"""Call an ONNX C API function by temporarily removing initializers.
26+
2827
This is necessary because the ONNX C API does not support large models
29-
with initializers that have large tensor values.
28+
with initializers that have large tensor values. The input model is left
29+
unchanged no matter the call succeeds or not.
3030
3131
Args:
3232
func: Partially applied function that takes a model proto and returns a model proto.
3333
model: The IR model to pass to the API function.
34-
merge_func: Function that merges IR model with information from the model proto.
3534
3635
Returns:
37-
A tuple containing the modified model and a boolean indicating whether the model was modified.
36+
The resulting ModelProto that contains the result of the API call.
3837
"""
3938

4039
# Store the original initializer values so they can be restored
4140
initializer_values = tuple(model.graph.initializers.values())
4241
tensors = {v.name: v.const_value for v in initializer_values}
4342
original_inputs_len = len(model.graph.inputs)
44-
initializer_names = {v.name for v in initializer_values}
4543

4644
# Turn the initializers into inputs and clear the initializers
4745
# to limit the model size
@@ -64,20 +62,15 @@ def call_onnx_api(
6462
try:
6563
proto = ir.serde.serialize_model(model)
6664
result_proto = func(proto)
67-
except Exception: # pylint: disable=broad-exception-caught
68-
logger.warning("Call to %s failed. The model is not modified", func, exc_info=True)
69-
return (model, False)
7065
finally:
7166
# Restore the original initializer values so the model is unchanged
7267
for initializer in initializer_values:
73-
if initializer.name in initializer_names:
74-
initializer.const_value = tensors[initializer.name]
75-
model.graph.register_initializer(initializer)
68+
initializer.const_value = tensors[initializer.name]
69+
model.graph.register_initializer(initializer)
7670

7771
# Restore the original inputs
7872
inputs = model.graph.inputs[:original_inputs_len]
7973
model.graph.inputs.clear()
8074
model.graph.inputs.extend(inputs)
8175

82-
# Merge the result with the original model
83-
return merge_func(model, result_proto)
76+
return result_proto

onnxscript/ir/passes/common/debugging.py renamed to onnxscript/ir/passes/common/onnx_checker.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,6 @@ def _partial_check_model(proto: onnx.ModelProto) -> onnx.ModelProto:
4949
)
5050
return proto
5151

52-
_c_api_utils.call_onnx_api(
53-
func=_partial_check_model,
54-
model=model,
55-
# Since we do not modify the model. merge_func is not used but provided for completeness
56-
merge_func=lambda m, proto: (m, False),
57-
)
52+
_c_api_utils.call_onnx_api(func=_partial_check_model, model=model)
5853
# The model is not modified
5954
return ir.passes.PassResult(model, False)

onnxscript/ir/passes/common/shape_inference.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020
logger = logging.getLogger(__name__)
2121

2222

23-
def _merge_func(model: ir.Model, inferred_proto: onnx.ModelProto) -> tuple[ir.Model, bool]:
24-
"""Merge the inferred model with the original model.
23+
def _merge_func(model: ir.Model, inferred_proto: onnx.ModelProto) -> bool:
24+
"""Merge the shape inferred model with the original model.
2525
2626
Args:
2727
model: The original IR model.
28-
inferred_proto: The inferred ONNX model.
28+
inferred_proto: The ONNX model with shapes and types inferred.
2929
3030
Returns:
3131
A tuple containing the modified model and a boolean indicating whether the model was modified.
@@ -48,7 +48,7 @@ def _merge_func(model: ir.Model, inferred_proto: onnx.ModelProto) -> tuple[ir.Mo
4848
logger.warning(
4949
"Value %s not found in inferred graph %s", name, inferred_graph.name
5050
)
51-
return model, modified
51+
return modified
5252

5353

5454
class ShapeInferencePass(ir.passes.InPlacePass):
@@ -59,6 +59,8 @@ def __init__(
5959
) -> None:
6060
"""Initialize the shape inference pass.
6161
62+
If inference fails, the model is left unchanged.
63+
6264
Args:
6365
check_type: If True, check the types of the inputs and outputs.
6466
strict_mode: If True, use strict mode for shape inference.
@@ -76,14 +78,14 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
7678
strict_mode=self.strict_mode,
7779
data_prop=self.data_prop,
7880
)
79-
80-
inferred_model, modified = _c_api_utils.call_onnx_api(
81-
onnx_infer_shapes,
82-
model,
83-
merge_func=_merge_func,
84-
)
85-
86-
return ir.passes.PassResult(inferred_model, modified=modified)
81+
try:
82+
inferred_model_proto = _c_api_utils.call_onnx_api(onnx_infer_shapes, model)
83+
except Exception as e:
84+
logger.warning("Shape inference failed: %s. Model is left unchanged", exc_info=e)
85+
return ir.passes.PassResult(model, False)
86+
87+
modified = _merge_func(model, inferred_model_proto)
88+
return ir.passes.PassResult(model, modified=modified)
8789

8890

8991
def infer_shapes(

0 commit comments

Comments
 (0)