Skip to content

Commit 555f81e

Browse files
authored
[Rewriter] Extend list of supported commutative operations (#2741)
Signed-off-by: Christoph Berganski <christoph.berganski@gmail.com>
1 parent 3f197a2 commit 555f81e

File tree

1 file changed

+19
-5
lines changed

1 file changed

+19
-5
lines changed

onnxscript/rewriter/_pattern_ir.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -899,12 +899,26 @@ def num_outputs(self) -> int:
899899
return len(self._outputs)
900900

901901
def commute(self) -> Sequence[GraphPattern]:
902+
# List all commutative elementwise (binary) operators for which we
903+
# consider swapping the inputs
904+
COMMUTATIVE_OPS = {
905+
("", "Add", ""),
906+
("", "Mul", ""),
907+
("", "And", ""),
908+
("", "Or", ""),
909+
("", "Xor", ""),
910+
("", "BitwiseAnd", ""),
911+
("", "BitwiseOr", ""),
912+
("", "BitwiseXor", ""),
913+
("", "Equal", ""),
914+
("", "Max", ""),
915+
("", "Mean", ""),
916+
("", "Min", ""),
917+
("", "Sum", ""),
918+
}
919+
902920
def commute_node(node: NodePattern) -> Iterable[bool]:
903-
if node.op_identifier() == ("", "Add", "") or node.op_identifier() == (
904-
"",
905-
"Mul",
906-
"",
907-
):
921+
if node.op_identifier() in COMMUTATIVE_OPS:
908922
# Try with and without swapping inputs.
909923
return [False, True]
910924
# No swapping of inputs

0 commit comments

Comments
 (0)