1414import onnxscript .onnx_types as ot
1515from onnxscript import ir
1616from onnxscript .onnx_opset import opset18
17+ from onnxscript .rewriter import MatchingTracer , testing
18+ from onnxscript .rewriter import pattern as orp
1719from onnxscript .rewriter .rules .common import _basic_rules
1820
1921FLOAT = onnx .TensorProto .FLOAT
@@ -29,6 +31,10 @@ def _make_model(*args, **kwargs) -> ir.Model:
2931 return ir .serde .deserialize_model (onnx .helper .make_model (* args , ** kwargs ))
3032
3133
34+ def clone_model (model : ir .Model ) -> ir .Model :
35+ return ir .from_proto (ir .to_proto (model ))
36+
37+
3238class BasicRulesTest (unittest .TestCase ):
3339 def _get_random_inputs (self , model : onnx .ModelProto ) -> dict [str , Any ]:
3440 feeds : dict [str , Any ] = {}
@@ -318,65 +324,6 @@ def test_unsqueeze_unsqueeze_rule(self, _: str, model: ir.Model):
318324 self .assertEqual (["Constant" , "Unsqueeze" ], [n .op_type for n in model .graph ])
319325 self ._check_model (model_proto , rewritten_model )
320326
321- @parameterized .parameterized .expand (
322- [
323- (
324- "double_reshape_1" ,
325- _make_model (
326- onnx .helper .make_graph (
327- [
328- onnx .helper .make_node ("Reshape" , ["X" , "shape_" ], ["Xu" ]),
329- onnx .helper .make_node ("Reshape" , ["Xu" , "shape" ], ["Y" ]),
330- ],
331- "name" ,
332- [onnx .helper .make_tensor_value_info ("X" , FLOAT , [3 , 4 , 5 ])],
333- [onnx .helper .make_tensor_value_info ("Y" , FLOAT , [5 , 4 , 3 ])],
334- [
335- onnx .numpy_helper .from_array (
336- np .array ([4 , 5 , 3 ], dtype = np .int64 ), name = "shape_"
337- ),
338- onnx .numpy_helper .from_array (
339- np .array ([5 , 4 , 3 ], dtype = np .int64 ), name = "shape"
340- ),
341- ],
342- ),
343- opset_imports = [onnx .helper .make_opsetid ("" , 18 )],
344- ),
345- ),
346- (
347- "double_reshape_2" ,
348- _make_model (
349- onnx .helper .make_graph (
350- [
351- onnx .helper .make_node ("Reshape" , ["X" , "shape_" ], ["Xu" ]),
352- onnx .helper .make_node ("Reshape" , ["Xu" , "shape" ], ["Y" ]),
353- ],
354- "name" ,
355- [onnx .helper .make_tensor_value_info ("X" , FLOAT , [3 , 4 , 5 ])],
356- [onnx .helper .make_tensor_value_info ("Y" , FLOAT , [5 , 4 , 3 ])],
357- [
358- onnx .numpy_helper .from_array (
359- np .array ([- 1 ], dtype = np .int64 ), name = "shape_"
360- ),
361- onnx .numpy_helper .from_array (
362- np .array ([5 , 4 , 3 ], dtype = np .int64 ), name = "shape"
363- ),
364- ],
365- ),
366- opset_imports = [onnx .helper .make_opsetid ("" , 18 )],
367- ),
368- ),
369- ]
370- )
371- def test_reshape_reshape_rule (self , _ : str , model : ir .Model ):
372- rule_set = _basic_rules .basic_optimization_rules ()
373- model_proto = ir .serde .serialize_model (model )
374- rule_set .apply_to_model (model )
375- rewritten_model = ir .serde .serialize_model (model )
376-
377- self .assertEqual (["Reshape" ], [n .op_type for n in model .graph ])
378- self ._check_model (model_proto , rewritten_model )
379-
380327 @classmethod
381328 def _slices_split_models (cls ):
382329 models = [
@@ -465,5 +412,74 @@ def model3(X: ot.FLOAT[1, 1]):
465412 check (model3 , 0 )
466413
467414
415+ class ReshapeReshapeTest (unittest .TestCase ):
416+ @staticmethod
417+ def create_model (input_shape , shape1 , shape2 ):
418+ def _convert_shape (shape , name ):
419+ if isinstance (shape , np .ndarray ):
420+ shape = tape .initializer (ir .Tensor (shape , name = name ))
421+ elif isinstance (shape , (list , tuple )):
422+ shape = ir .Input (name , ir .Shape (shape ), ir .TensorType (ir .DataType .INT64 ))
423+ tape .graph_like .inputs .append (shape )
424+ else :
425+ raise TypeError (f"Unsupported type { type (shape )} for shape." )
426+ return shape
427+
428+ x = ir .Input ("X" , ir .Shape (input_shape ), ir .TensorType (ir .DataType .FLOAT ))
429+ y = ir .Input ("Y" , type = ir .TensorType (ir .DataType .FLOAT ))
430+ tape = ir .tape .Tape (ir .Graph ([x ], [y ], nodes = [], opset_imports = {"" : 20 }))
431+
432+ # Build the graph.
433+ reshape = tape .op ("Reshape" , inputs = [x , _convert_shape (shape1 , "shape_" )])
434+ tape .op ("Reshape" , inputs = [reshape , _convert_shape (shape2 , "shape" )], output = y )
435+ model = ir .Model (tape .graph_like , ir_version = 10 )
436+ return model
437+
438+ @parameterized .parameterized .expand (
439+ [
440+ ((3 , 4 , 5 ), [4 , 5 , 3 ], [5 , 4 , 3 ]),
441+ ((3 , 4 , 5 ), [4 , 5 , 3 ], [5 , 4 , 3 ]),
442+ ]
443+ )
444+ def test_reshape_reshape_rule (self , input_shape , shape1 , shape2 ):
445+ model = self .create_model (
446+ input_shape , np .array (shape1 , dtype = "int64" ), np .array (shape2 , dtype = "int64" )
447+ )
448+ updated_model = clone_model (model )
449+
450+ # check rewrite approach.
451+ count = _basic_rules .reshape_reshape_rule .apply_to_model (updated_model )
452+ self .assertEqual (count , 1 )
453+ self .assertEqual (["Reshape" ], [n .op_type for n in updated_model .graph ])
454+
455+ # Check inference.
456+ inputs = np .random .default_rng (10 ).random (input_shape , dtype = "float32" )
457+ testing .assert_numerically_equal (model , updated_model , (inputs ,), atol = 0 , rtol = 0 )
458+
459+ @parameterized .parameterized .expand (
460+ [
461+ ((2 ,), np .array ([1 , 6 ], dtype = "int64" ), "ignored is not a constant" ),
462+ (np .array ([1 , 6 ], dtype = "int64" ), (3 ,), "is not a constant" ),
463+ (
464+ np .array ([1 , 6 ], dtype = "int64" ),
465+ np .array ([0 , 6 ], dtype = "int64" ),
466+ "non-positive values" ,
467+ ),
468+ ]
469+ )
470+ def test_unsupported_reshape_reshape (self , shape1 , shape2 , error_msg ):
471+ model = self .create_model ((1 , 2 , 3 ), shape1 , shape2 )
472+
473+ # Check rewrite approach.
474+ tracer = MatchingTracer ()
475+ count = _basic_rules .reshape_reshape_rule .apply_to_model (model , tracer = tracer )
476+ self .assertEqual (count , 0 )
477+
478+ # Check that the error message is the expected one
479+ tracer_match = tracer .best_matches_map [_basic_rules .reshape_reshape_rule ][0 ]
480+ self .assertEqual (tracer_match .status .value , orp .MatchStatus .CONDITION_FAILED )
481+ self .assertRegex (tracer_match .match_result .reason , error_msg )
482+
483+
468484if __name__ == "__main__" :
469485 unittest .main (verbosity = 2 )
0 commit comments