File tree Expand file tree Collapse file tree 1 file changed +19
-5
lines changed
Expand file tree Collapse file tree 1 file changed +19
-5
lines changed Original file line number Diff line number Diff line change @@ -899,12 +899,26 @@ def num_outputs(self) -> int:
899899 return len (self ._outputs )
900900
901901 def commute (self ) -> Sequence [GraphPattern ]:
902+ # List all commutative elementwise (binary) operators for which we
903+ # consider swapping the inputs
904+ COMMUTATIVE_OPS = {
905+ ("" , "Add" , "" ),
906+ ("" , "Mul" , "" ),
907+ ("" , "And" , "" ),
908+ ("" , "Or" , "" ),
909+ ("" , "Xor" , "" ),
910+ ("" , "BitwiseAnd" , "" ),
911+ ("" , "BitwiseOr" , "" ),
912+ ("" , "BitwiseXor" , "" ),
913+ ("" , "Equal" , "" ),
914+ ("" , "Max" , "" ),
915+ ("" , "Mean" , "" ),
916+ ("" , "Min" , "" ),
917+ ("" , "Sum" , "" ),
918+ }
919+
902920 def commute_node (node : NodePattern ) -> Iterable [bool ]:
903- if node .op_identifier () == ("" , "Add" , "" ) or node .op_identifier () == (
904- "" ,
905- "Mul" ,
906- "" ,
907- ):
921+ if node .op_identifier () in COMMUTATIVE_OPS :
908922 # Try with and without swapping inputs.
909923 return [False , True ]
910924 # No swapping of inputs
You can’t perform that action at this time.
0 commit comments