@@ -281,6 +281,46 @@ def test_full_optimization(self):
281281 self .assertEqual (count , 3 )
282282 self .assertEqual (len (model .graph ), 5 )
283283
284+ def test_full_optimization_more_complex (self ):
285+ import onnx .helper as oh
286+ import onnx .numpy_helper as onh
287+
288+ model_proto = oh .make_model (
289+ oh .make_graph (
290+ [
291+ oh .make_node ("Shape" , ["x" ], ["n" ], start = 0 , end = 1 ),
292+ oh .make_node ("Shape" , ["x" ], ["b" ], start = 1 , end = 2 ),
293+ oh .make_node ("Concat" , ["n" , "b" ], ["shape" ], axis = 0 ),
294+ oh .make_node ("Add" , ["shape" , "one" ], ["shape1" ]),
295+ oh .make_node ("Sub" , ["shape1" , "one" ], ["shape2" ]),
296+ oh .make_node ("Expand" , ["x" , "shape2" ], ["expanded" ]),
297+ oh .make_node ("Add" , ["expanded" , "y1" ], ["z1" ]),
298+ oh .make_node ("Add" , ["expanded" , "y2" ], ["z2" ]),
299+ oh .make_node ("Add" , ["expanded" , "y3" ], ["z3" ]),
300+ oh .make_node ("Add" , ["z1" , "z2" ], ["z12" ]),
301+ oh .make_node ("Add" , ["z12" , "z3" ], ["z" ]),
302+ ],
303+ "test" ,
304+ [
305+ oh .make_tensor_value_info ("x" , onnx .TensorProto .FLOAT , ["N" , 1 ]),
306+ oh .make_tensor_value_info ("y1" , onnx .TensorProto .FLOAT , [1 , "B" ]),
307+ oh .make_tensor_value_info ("y2" , onnx .TensorProto .FLOAT , [1 , "B" ]),
308+ oh .make_tensor_value_info ("y3" , onnx .TensorProto .FLOAT , [1 , "B" ]),
309+ ],
310+ [
311+ oh .make_tensor_value_info ("z" , onnx .TensorProto .FLOAT , ["N" , "B" ]),
312+ ],
313+ [onh .from_array (np .array ([1 ], dtype = np .int64 ), "one" )]
314+ ),
315+ ir_version = 11 ,
316+ opset_imports = [oh .make_opsetid ("" , 20 )],
317+ )
318+ onnx .checker .check_model (model_proto )
319+ model = ir .serde .deserialize_model (model_proto )
320+ count = mod .expand_before_binary_op_rules .apply_to_model (model )
321+ self .assertEqual (count , 3 )
322+ self .assertEqual (len (model .graph ), 5 )
323+
284324
285325if __name__ == "__main__" :
286326 unittest .main ()
0 commit comments