Skip to content

Commit a056869

Browse files
committed
def
1 parent 17022bb commit a056869

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

onnxscript/ir/passes/common/onnx_checker_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ def test_check_invalid_model(self):
6969
ir_version=10,
7070
)
7171

72-
with self.assertRaisesRegex(Exception, "Field 'name' of 'graph' is required to be non-empty"):
72+
with self.assertRaisesRegex(
73+
Exception, "Field 'name' of 'graph' is required to be non-empty"
74+
):
7375
onnx_checker.CheckerPass()(model)
7476

7577

onnxscript/ir/passes/common/shape_inference.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
"infer_shapes",
1010
]
1111

12-
import functools
1312
import logging
1413

1514
import onnx
@@ -72,14 +71,16 @@ def __init__(
7271
self.data_prop = data_prop
7372

7473
def call(self, model: ir.Model) -> ir.passes.PassResult:
75-
onnx_infer_shapes = functools.partial(
76-
onnx.shape_inference.infer_shapes,
77-
check_type=self.check_type,
78-
strict_mode=self.strict_mode,
79-
data_prop=self.data_prop,
80-
)
74+
def partial_infer_shapes(proto: onnx.ModelProto) -> onnx.ModelProto:
75+
onnx.shape_inference.infer_shapes(
76+
proto,
77+
check_type=self.check_type,
78+
strict_mode=self.strict_mode,
79+
data_prop=self.data_prop,
80+
)
81+
8182
try:
82-
inferred_model_proto = _c_api_utils.call_onnx_api(onnx_infer_shapes, model)
83+
inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model)
8384
except Exception as e:
8485
logger.warning("Shape inference failed: %s. Model is left unchanged", exc_info=e)
8586
return ir.passes.PassResult(model, False)

0 commit comments

Comments
 (0)