File tree Expand file tree Collapse file tree 2 files changed +12
-9
lines changed
onnxscript/ir/passes/common Expand file tree Collapse file tree 2 files changed +12
-9
lines changed Original file line number Diff line number Diff line change @@ -69,7 +69,9 @@ def test_check_invalid_model(self):
6969 ir_version = 10 ,
7070 )
7171
72- with self .assertRaisesRegex (Exception , "Field 'name' of 'graph' is required to be non-empty" ):
72+ with self .assertRaisesRegex (
73+ Exception , "Field 'name' of 'graph' is required to be non-empty"
74+ ):
7375 onnx_checker .CheckerPass ()(model )
7476
7577
Original file line number Diff line number Diff line change 99 "infer_shapes" ,
1010]
1111
12- import functools
1312import logging
1413
1514import onnx
@@ -72,14 +71,16 @@ def __init__(
7271 self .data_prop = data_prop
7372
7473 def call (self , model : ir .Model ) -> ir .passes .PassResult :
75- onnx_infer_shapes = functools .partial (
76- onnx .shape_inference .infer_shapes ,
77- check_type = self .check_type ,
78- strict_mode = self .strict_mode ,
79- data_prop = self .data_prop ,
80- )
74+ def partial_infer_shapes (proto : onnx .ModelProto ) -> onnx .ModelProto :
75+ onnx .shape_inference .infer_shapes (
76+ proto ,
77+ check_type = self .check_type ,
78+ strict_mode = self .strict_mode ,
79+ data_prop = self .data_prop ,
80+ )
81+
8182 try :
82- inferred_model_proto = _c_api_utils .call_onnx_api (onnx_infer_shapes , model )
83+ inferred_model_proto = _c_api_utils .call_onnx_api (partial_infer_shapes , model )
8384 except Exception as e :
8485 logger .warning ("Shape inference failed: %s. Model is left unchanged" , exc_info = e )
8586 return ir .passes .PassResult (model , False )
You can’t perform that action at this time.
0 commit comments