Skip to content

Commit ed31552

Browse files
committed
more complex case
1 parent 4ab5e34 commit ed31552

1 file changed

Lines changed: 40 additions & 0 deletions

File tree

onnxscript/rewriter/rules/common/_remove_expand_before_binary_op_test.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,46 @@ def test_full_optimization(self):
281281
self.assertEqual(count, 3)
282282
self.assertEqual(len(model.graph), 5)
283283

284+
def test_full_optimization_more_complex(self):
285+
import onnx.helper as oh
286+
import onnx.numpy_helper as onh
287+
288+
model_proto = oh.make_model(
289+
oh.make_graph(
290+
[
291+
oh.make_node("Shape", ["x"], ["n"], start=0, end=1),
292+
oh.make_node("Shape", ["x"], ["b"], start=1, end=2),
293+
oh.make_node("Concat", ["n", "b"], ["shape"], axis=0),
294+
oh.make_node("Add", ["shape", "one"], ["shape1"]),
295+
oh.make_node("Sub", ["shape1", "one"], ["shape2"]),
296+
oh.make_node("Expand", ["x", "shape2"], ["expanded"]),
297+
oh.make_node("Add", ["expanded", "y1"], ["z1"]),
298+
oh.make_node("Add", ["expanded", "y2"], ["z2"]),
299+
oh.make_node("Add", ["expanded", "y3"], ["z3"]),
300+
oh.make_node("Add", ["z1", "z2"], ["z12"]),
301+
oh.make_node("Add", ["z12", "z3"], ["z"]),
302+
],
303+
"test",
304+
[
305+
oh.make_tensor_value_info("x", onnx.TensorProto.FLOAT, ["N", 1]),
306+
oh.make_tensor_value_info("y1", onnx.TensorProto.FLOAT, [1, "B"]),
307+
oh.make_tensor_value_info("y2", onnx.TensorProto.FLOAT, [1, "B"]),
308+
oh.make_tensor_value_info("y3", onnx.TensorProto.FLOAT, [1, "B"]),
309+
],
310+
[
311+
oh.make_tensor_value_info("z", onnx.TensorProto.FLOAT, ["N", "B"]),
312+
],
313+
[onh.from_array(np.array([1], dtype=np.int64), "one")]
314+
),
315+
ir_version=11,
316+
opset_imports=[oh.make_opsetid("", 20)],
317+
)
318+
onnx.checker.check_model(model_proto)
319+
model = ir.serde.deserialize_model(model_proto)
320+
count = mod.expand_before_binary_op_rules.apply_to_model(model)
321+
self.assertEqual(count, 3)
322+
self.assertEqual(len(model.graph), 5)
323+
284324

285325
if __name__ == "__main__":
286326
unittest.main()

0 commit comments

Comments
 (0)