-
Notifications
You must be signed in to change notification settings - Fork 109
Optimize aten::min/max.dim with TopK op #2780
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
d15e6ac
39baa38
312ef9e
dde8d14
72fe1db
301635e
81fa713
264aed2
09eefce
1c7a579
826a1f7
e84dc89
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,272 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
|
|
||
| # Licensed under the MIT License. | ||
| """Fuses Reduce{Max,Min} and Arg{Max,Min} patterns into TopK. | ||
|
|
||
| Supported transformations: | ||
| - ReduceMax(X, axes=[axis], keepdims=k) + ArgMax(X, axis=axis, keepdims=k) → TopK(X, k=1, axis=axis, largest=1) [+ Squeeze if k=0] | ||
| - ReduceMin(X, axes=[axis], keepdims=k) + ArgMin(X, axis=axis, keepdims=k) → TopK(X, k=1, axis=axis, largest=0) [+ Squeeze if k=0] | ||
|
|
||
| Supports both ONNX opset versions: | ||
| - Opset 13-17: Reduce{Max,Min} with axes as an attribute | ||
| - Opset 18+: Reduce{Max,Min} with axes as a second input | ||
|
|
||
| Constraints: | ||
| - Both nodes must operate on the same input X. | ||
| - Both nodes must target the same axis. | ||
| - Both nodes must have the same keepdims attribute value. | ||
| - The Reduce node must operate on a single axis (len(axes) == 1). | ||
| - For opset 18+, the Reduce node's axes input must be a constant. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from abc import abstractmethod | ||
|
|
||
| import numpy as np | ||
| import onnx_ir as ir | ||
|
|
||
| from onnxscript.rewriter._basics import MatchResult | ||
| from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet | ||
|
|
||
|
|
||
class _FuseReduceArgToTopKBase(RewriteRuleClassBase):
    """Base class for fusing Reduce{Max,Min} + Arg{Max,Min} into TopK.

    This base class contains the common logic for checking and rewriting patterns where
    a Reduce operation and its corresponding Arg operation can be replaced with a single
    TopK operation.

    Subclasses must implement:
        - pattern(): Define the specific Reduce and Arg operations to match
        - largest: Property returning 1 for Max operations, 0 for Min operations
    """

    @property
    @abstractmethod
    def largest(self) -> int:
        """Return 1 for Max operations (largest elements), 0 for Min operations (smallest elements)."""

    @staticmethod
    def _normalize_axis(axis: int, rank: int | None) -> int:
        """Normalize a potentially negative axis to a positive axis index.

        Args:
            axis: The axis to normalize (can be negative).
            rank: The rank of the tensor, or None if unknown.

        Returns:
            The normalized axis (non-negative if rank is known and axis was negative).
        """
        if rank is not None and axis < 0:
            return axis + rank
        return axis

    def check(self, context, reduce_val, arg_idx, **_) -> MatchResult:
        """Check if Reduce and Arg operations can be safely fused into TopK.

        Conditions:
            - Both nodes must have the same keepdims attribute.
            - The Reduce node must operate on a single axis.
            - Both nodes must operate on the same axis.
            - The Arg node must not use select_last_index=1 (TopK doesn't support this).

        Args:
            context: The rewrite context (unused).
            reduce_val: The output of the Reduce operation (ReduceMax/ReduceMin).
            arg_idx: The output of the Arg operation (ArgMax/ArgMin).

        Returns:
            MatchResult: Success if the pattern can be fused, Failure otherwise.
        """
        del context
        check_result = MatchResult()

        reduce_node = reduce_val.producer()
        arg_node = arg_idx.producer()

        # Step 1: Get keepdims attribute from both nodes.
        # ONNX default: keepdims = 1 for both Reduce and Arg operations.
        reduce_keepdims_attr = reduce_node.attributes.get("keepdims")
        arg_keepdims_attr = arg_node.attributes.get("keepdims")
        reduce_keepdims = (
            reduce_keepdims_attr.as_int() if reduce_keepdims_attr is not None else 1
        )
        arg_keepdims = arg_keepdims_attr.as_int() if arg_keepdims_attr is not None else 1

        # Step 2: Check if keepdims match
        if reduce_keepdims != arg_keepdims:
            return check_result.fail(
                f"keepdims mismatch: {reduce_node.op_type} has {reduce_keepdims}, "
                f"{arg_node.op_type} has {arg_keepdims}."
            )

        # Step 3: Get axes from Reduce operation.
        # In opset 18+, axes is an input; in opset 13-17, it's an attribute.
        reduce_axes_attr = reduce_node.attributes.get("axes")

        if reduce_axes_attr is not None:
            # Opset 13-17: axes is an attribute
            try:
                axes_list = [int(a) for a in reduce_axes_attr.as_ints()]
            except Exception:  # defensive: any parse failure just rejects the match
                return check_result.fail(f"Cannot parse {reduce_node.op_type} axes attribute.")
        elif len(reduce_node.inputs) >= 2 and reduce_node.inputs[1] is not None:
            # Opset 18+: axes is the second input, which must be a constant
            axes_input = reduce_node.inputs[1]
            axes_tensor = ir.convenience.get_const_tensor(axes_input)
            if axes_tensor is None:
                return check_result.fail(
                    f"{reduce_node.op_type} axes input is not a constant."
                )
            try:
                # reshape(-1) flattens both 0-d and 1-d (and any malformed higher-rank)
                # constants into a flat sequence; int() yields plain Python ints so the
                # single-axis check and comparisons below cannot see nested lists.
                axes_list = [int(a) for a in axes_tensor.numpy().reshape(-1)]
            except Exception:  # defensive: any parse failure just rejects the match
                return check_result.fail(f"Cannot parse {reduce_node.op_type} axes input.")
        else:
            return check_result.fail(
                f"{reduce_node.op_type} axes not found (neither attribute nor input)."
            )

        # Step 4: Check that Reduce operates on exactly one axis
        if len(axes_list) != 1:
            return check_result.fail(
                f"{reduce_node.op_type} must operate on a single axis, got {len(axes_list)} axes."
            )

        reduce_axis = axes_list[0]

        # Step 5: Get axis from Arg operation.
        # ONNX default: axis = 0 for ArgMax/ArgMin.
        arg_axis_attr = arg_node.attributes.get("axis")
        arg_axis = arg_axis_attr.as_int() if arg_axis_attr is not None else 0

        # Step 6: Check select_last_index attribute (if present).
        # TopK always returns the first occurrence in case of ties.
        select_last_index_attr = arg_node.attributes.get("select_last_index")
        if select_last_index_attr is not None and select_last_index_attr.as_int() != 0:
            return check_result.fail(
                f"{arg_node.op_type} has select_last_index=1, which is not supported by TopK."
            )

        # Step 7: Normalize axes if rank is known (handle negative indices)
        input_x = reduce_node.inputs[0]
        rank = len(input_x.shape) if input_x.shape is not None else None

        if self._normalize_axis(reduce_axis, rank) != self._normalize_axis(arg_axis, rank):
            return check_result.fail(
                f"Axis mismatch: {reduce_node.op_type} operates on axis {reduce_axis}, "
                f"{arg_node.op_type} operates on axis {arg_axis}."
            )

        return check_result

    def rewrite(self, op, x, reduce_val, arg_idx):
        """Rewrite the matched pattern with TopK (and optionally Squeeze).

        Args:
            op: The operation builder.
            x: The input to both Reduce and Arg operations.
            reduce_val: The output of the Reduce operation.
            arg_idx: The output of the Arg operation.

        Returns:
            Tuple of (values, indices) matching the original outputs.
        """
        # Step 1: Get the Arg node; check() already verified its axis/keepdims
        # agree with the Reduce node, so reading them from either node is safe.
        arg_node = arg_idx.producer()

        # Step 2: Extract necessary attributes with ONNX default values
        axis_attr = arg_node.attributes.get("axis")
        keepdims_attr = arg_node.attributes.get("keepdims")

        axis = axis_attr.as_int() if axis_attr is not None else 0
        keepdims = keepdims_attr.as_int() if keepdims_attr is not None else 1

        # Step 2b: Normalize axis (convert negative to positive) if rank is known.
        # Uses the same helper as check() for consistency.
        rank = len(x.shape) if x.shape is not None else None
        axis = self._normalize_axis(axis, rank)

        # Step 3: Create K constant (k=1: only the single min/max element is needed)
        k_constant = op.Constant(value=ir.tensor(np.array([1], dtype=np.int64)))

        # Step 4: Create TopK node
        topk_values, topk_indices = op.TopK(
            x,
            k_constant,
            axis=axis,
            largest=self.largest,
            sorted=1,
            _outputs=2,
        )

        # Step 5: Handle keepdims=0 case.
        # TopK output always keeps the dimension (just makes it size 1), so we
        # need to squeeze it to match the original Reduce/Arg behavior.
        if keepdims == 0:
            axes_constant = op.Constant(value=ir.tensor(np.array([axis], dtype=np.int64)))
            new_values = op.Squeeze(topk_values, axes_constant)
            new_indices = op.Squeeze(topk_indices, axes_constant)
        else:
            new_values = topk_values
            new_indices = topk_indices

        return new_values, new_indices
|
|
||
|
|
||
class FuseReduceMaxArgMaxToTopK(_FuseReduceArgToTopKBase):
    """Fuses a ReduceMax/ArgMax pair on the same input into a single TopK.

    Transformation:
        ReduceMax(X, axes=[axis], keepdims=k) + ArgMax(X, axis=axis, keepdims=k)
        → TopK(X, k=1, axis=axis, largest=1) [+ Squeeze if k=0]

    With keepdims=0, a Squeeze is appended so the TopK outputs match the
    original Reduce/Arg output shapes.
    """

    @property
    def largest(self) -> int:
        # ReduceMax/ArgMax select maxima, so TopK must return the largest elements.
        return 1

    def pattern(self, op, x):
        """Match a ReduceMax and an ArgMax that both consume the same input.

        Note: For opset 18+, ReduceMax takes a second (axes) input; the pattern
        accepts it here, and check() verifies that it is a constant.
        """
        max_value = op.ReduceMax(x, _allow_other_inputs=True, _outputs=["reduce_val"])
        max_index = op.ArgMax(x, _outputs=["arg_idx"])
        return max_value, max_index
|
|
||
|
|
||
class FuseReduceMinArgMinToTopK(_FuseReduceArgToTopKBase):
    """Fuses a ReduceMin/ArgMin pair on the same input into a single TopK.

    Transformation:
        ReduceMin(X, axes=[axis], keepdims=k) + ArgMin(X, axis=axis, keepdims=k)
        → TopK(X, k=1, axis=axis, largest=0) [+ Squeeze if k=0]

    With keepdims=0, a Squeeze is appended so the TopK outputs match the
    original Reduce/Arg output shapes.
    """

    @property
    def largest(self) -> int:
        # ReduceMin/ArgMin select minima, so TopK must return the smallest elements.
        return 0

    def pattern(self, op, x):
        """Match a ReduceMin and an ArgMin that both consume the same input.

        Note: For opset 18+, ReduceMin takes a second (axes) input; the pattern
        accepts it here, and check() verifies that it is a constant.
        """
        min_value = op.ReduceMin(x, _allow_other_inputs=True, _outputs=["reduce_val"])
        min_index = op.ArgMin(x, _outputs=["arg_idx"])
        return min_value, min_index
|
|
||
|
|
||
# Individual rule instances, exposed so they can also be applied standalone.
reduce_max_argmax_to_topk_rule = FuseReduceMaxArgMaxToTopK().rule()
reduce_min_argmin_to_topk_rule = FuseReduceMinArgMinToTopK().rule()

# Combined rule set that applies both the max and min fusions.
rules = RewriteRuleSet(
    [
        reduce_max_argmax_to_topk_rule,
        reduce_min_argmin_to_topk_rule,
    ]
)
|
danielhumanmod marked this conversation as resolved.
Outdated
justinchuby marked this conversation as resolved.
Outdated
|
||
Uh oh!
There was an error while loading. Please reload this page.