Skip to content
2 changes: 2 additions & 0 deletions onnxscript/rewriter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_collapse_slices,
_fuse_batchnorm,
_fuse_pad_into_conv,
_fuse_reduce_arg_to_topk,
_fuse_relus_clips,
_min_max_to_clip,
_no_op,
Expand All @@ -61,6 +62,7 @@
*_fuse_pad_into_conv.rules,
*_fuse_batchnorm.rules,
*_remove_optional_bias.rules,
*_fuse_reduce_arg_to_topk.rules,
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
)


Expand Down
272 changes: 272 additions & 0 deletions onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
# Copyright (c) Microsoft Corporation.
Comment thread Fixed
Comment thread Fixed
# Licensed under the MIT License.
"""Fuses Reduce{Max,Min} and Arg{Max,Min} patterns into TopK.

Supported transformations:
- ReduceMax(X, axes=[axis], keepdims=k) + ArgMax(X, axis=axis, keepdims=k) → TopK(X, k=1, axis=axis, largest=1) [+ Squeeze if k=0]
- ReduceMin(X, axes=[axis], keepdims=k) + ArgMin(X, axis=axis, keepdims=k) → TopK(X, k=1, axis=axis, largest=0) [+ Squeeze if k=0]

Supports both ONNX opset versions:
- Opset 13-17: Reduce{Max,Min} with axes as an attribute
- Opset 18+: Reduce{Max,Min} with axes as a second input

Constraints:
- Both nodes must operate on the same input X.
- Both nodes must target the same axis.
- Both nodes must have the same keepdims attribute value.
- The Reduce node must operate on a single axis (len(axes) == 1).
- For opset 18+, the Reduce node's axes input must be a constant.
"""

from __future__ import annotations

from abc import abstractmethod

import numpy as np
import onnx_ir as ir

from onnxscript.rewriter._basics import MatchResult
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet


class _FuseReduceArgToTopKBase(RewriteRuleClassBase):
    """Base class for fusing Reduce{Max,Min} + Arg{Max,Min} into TopK.

    This base class contains the common logic for checking and rewriting patterns where
    a Reduce operation and its corresponding Arg operation can be replaced with a single
    TopK operation.

    NOTE: Both Reduce forms are handled: opset 13-17 (axes given as an attribute) and
    opset 18+ (axes given as a constant second input). For the opset 18+ form the axes
    input must be a compile-time constant, otherwise the match is rejected.

    Subclasses must implement:
    - pattern(): Define the specific Reduce and Arg operations to match
    - largest: Property returning 1 for Max operations, 0 for Min operations
    """

    @property
    @abstractmethod
    def largest(self) -> int:
        """Return 1 for Max operations (largest elements), 0 for Min operations (smallest elements)."""

    @staticmethod
    def _normalize_axis(axis: int, rank: int | None) -> int:
        """Normalize a potentially negative axis to a positive axis index.

        Args:
            axis: The axis to normalize (can be negative).
            rank: The rank of the tensor, or None if unknown.

        Returns:
            The normalized axis (non-negative if rank is known and axis was negative).
        """
        if rank is not None and axis < 0:
            return axis + rank
        return axis

    def check(self, context, reduce_val, arg_idx, **_) -> MatchResult:
        """Check if Reduce and Arg operations can be safely fused into TopK.

        Conditions:
        - Both nodes must have the same keepdims attribute.
        - The Reduce node must operate on a single axis.
        - Both nodes must operate on the same axis.
        - The Arg node must not use select_last_index=1 (TopK doesn't support this).

        Args:
            context: The rewrite context (unused).
            reduce_val: The output of the Reduce operation (ReduceMax/ReduceMin).
            arg_idx: The output of the Arg operation (ArgMax/ArgMin).

        Returns:
            MatchResult: Success if the pattern can be fused, Failure otherwise.
        """
        del context
        check_result = MatchResult()

        reduce_node = reduce_val.producer()
        arg_node = arg_idx.producer()

        # Step 1: Get keepdims attribute from both nodes
        reduce_keepdims_attr = reduce_node.attributes.get("keepdims")
        arg_keepdims_attr = arg_node.attributes.get("keepdims")

        # ONNX default: keepdims = 1 for both Reduce and Arg operations
        reduce_keepdims = (
            reduce_keepdims_attr.as_int() if reduce_keepdims_attr is not None else 1
        )
        arg_keepdims = arg_keepdims_attr.as_int() if arg_keepdims_attr is not None else 1

        # Step 2: Check if keepdims match
        if reduce_keepdims != arg_keepdims:
            return check_result.fail(
                f"keepdims mismatch: {reduce_node.op_type} has {reduce_keepdims}, "
                f"{arg_node.op_type} has {arg_keepdims}."
            )

        # Step 3: Get axes from Reduce operation
        # In opset 18+, axes is an input; in opset 13-17, it's an attribute
        reduce_axes_attr = reduce_node.attributes.get("axes")

        if reduce_axes_attr is not None:
            # Opset 13-17: axes is an attribute
            try:
                axes_list = list(reduce_axes_attr.as_ints())
            except Exception:
                return check_result.fail(f"Cannot parse {reduce_node.op_type} axes attribute.")
        elif len(reduce_node.inputs) >= 2 and reduce_node.inputs[1] is not None:
            # Opset 18+: axes is the second input; it must be a constant so the
            # target axis is known at rewrite time.
            axes_input = reduce_node.inputs[1]
            axes_tensor = ir.convenience.get_const_tensor(axes_input)
            if axes_tensor is None:
                return check_result.fail(
                    f"{reduce_node.op_type} axes input is not a constant."
                )
            try:
                axes_array = axes_tensor.numpy()
                # A 0-d constant still denotes a single axis; wrap it in a list.
                axes_list = axes_array.tolist() if axes_array.ndim > 0 else [int(axes_array)]
            except Exception:
                return check_result.fail(f"Cannot parse {reduce_node.op_type} axes input.")
        else:
            return check_result.fail(
                f"{reduce_node.op_type} axes not found (neither attribute nor input)."
            )

        # Step 4: Check that Reduce operates on exactly one axis
        if len(axes_list) != 1:
            return check_result.fail(
                f"{reduce_node.op_type} must operate on a single axis, got {len(axes_list)} axes."
            )

        reduce_axis = axes_list[0]

        # Step 5: Get axis from Arg operation
        # ONNX default: axis = 0 for ArgMax/ArgMin
        arg_axis_attr = arg_node.attributes.get("axis")
        arg_axis = arg_axis_attr.as_int() if arg_axis_attr is not None else 0

        # Step 6: Check select_last_index attribute (if present)
        # TopK always returns the first occurrence in case of ties
        select_last_index_attr = arg_node.attributes.get("select_last_index")
        if select_last_index_attr is not None and select_last_index_attr.as_int() != 0:
            return check_result.fail(
                f"{arg_node.op_type} has select_last_index=1, which is not supported by TopK."
            )

        # Step 7: Normalize axes if rank is known (handle negative indices).
        # If rank is unknown, differently-signed but equivalent axes (e.g. -1 vs
        # rank-1) cannot be proven equal, so the match conservatively fails below.
        input_x = reduce_node.inputs[0]
        rank = len(input_x.shape) if input_x.shape is not None else None

        if self._normalize_axis(reduce_axis, rank) != self._normalize_axis(arg_axis, rank):
            return check_result.fail(
                f"Axis mismatch: {reduce_node.op_type} operates on axis {reduce_axis}, "
                f"{arg_node.op_type} operates on axis {arg_axis}."
            )

        return check_result

    def rewrite(self, op, x, reduce_val, arg_idx):
        """Rewrite the matched pattern with TopK (and optionally Squeeze).

        Args:
            op: The operation builder.
            x: The input to both Reduce and Arg operations.
            reduce_val: The output of the Reduce operation.
            arg_idx: The output of the Arg operation.

        Returns:
            Tuple of (values, indices) matching the original outputs.
        """
        # Step 1: Get the nodes
        arg_node = arg_idx.producer()

        # Step 2: Extract necessary attributes with ONNX default values
        axis_attr = arg_node.attributes.get("axis")
        keepdims_attr = arg_node.attributes.get("keepdims")

        axis = axis_attr.as_int() if axis_attr is not None else 0
        keepdims = keepdims_attr.as_int() if keepdims_attr is not None else 1

        # Step 2b: Normalize axis (convert negative to positive) if rank is known
        if axis < 0 and x.shape is not None:
            axis = len(x.shape) + axis

        # Step 3: Create K constant (k=1: the single max/min element along the axis)
        k_constant = op.Constant(value=ir.tensor(np.array([1], dtype=np.int64)))

        # Step 4: Create TopK node
        topk_values, topk_indices = op.TopK(
            x,
            k_constant,
            axis=axis,
            largest=self.largest,
            sorted=1,
            _outputs=2,
        )

        # Step 5: Handle keepdims=0 case
        if keepdims == 0:
            # TopK output always keeps the dimension (just makes it size 1)
            # We need to squeeze it to match the original Reduce/Arg behavior
            axes_constant = op.Constant(value=ir.tensor(np.array([axis], dtype=np.int64)))

            new_values = op.Squeeze(topk_values, axes_constant)
            new_indices = op.Squeeze(topk_indices, axes_constant)
        else:
            new_values = topk_values
            new_indices = topk_indices

        return new_values, new_indices


class FuseReduceMaxArgMaxToTopK(_FuseReduceArgToTopKBase):
    """Fuse a ReduceMax/ArgMax pair over the same axis into TopK(largest=1).

    Transformation:
        ReduceMax(X, axes=[axis], keepdims=k) + ArgMax(X, axis=axis, keepdims=k)
        → TopK(X, k=1, axis=axis, largest=1) [+ Squeeze if k=0]

    With keepdims=0, a Squeeze is appended so the fused outputs keep the
    original Reduce/Arg output shapes.
    """

    @property
    def largest(self) -> int:
        # largest=1 selects maxima, matching ReduceMax/ArgMax semantics.
        return 1

    def pattern(self, op, x):
        """Match ReduceMax and ArgMax both consuming the same input ``x``.

        ``_allow_other_inputs=True`` admits the opset 18+ form of ReduceMax,
        where axes arrive as a second input; check() verifies it is constant.
        """
        max_values = op.ReduceMax(x, _allow_other_inputs=True, _outputs=["reduce_val"])
        max_indices = op.ArgMax(x, _outputs=["arg_idx"])
        return max_values, max_indices


class FuseReduceMinArgMinToTopK(_FuseReduceArgToTopKBase):
    """Fuse a ReduceMin/ArgMin pair over the same axis into TopK(largest=0).

    Transformation:
        ReduceMin(X, axes=[axis], keepdims=k) + ArgMin(X, axis=axis, keepdims=k)
        → TopK(X, k=1, axis=axis, largest=0) [+ Squeeze if k=0]

    With keepdims=0, a Squeeze is appended so the fused outputs keep the
    original Reduce/Arg output shapes.
    """

    @property
    def largest(self) -> int:
        # largest=0 selects minima, matching ReduceMin/ArgMin semantics.
        return 0

    def pattern(self, op, x):
        """Match ReduceMin and ArgMin both consuming the same input ``x``.

        ``_allow_other_inputs=True`` admits the opset 18+ form of ReduceMin,
        where axes arrive as a second input; check() verifies it is constant.
        """
        min_values = op.ReduceMin(x, _allow_other_inputs=True, _outputs=["reduce_val"])
        min_indices = op.ArgMin(x, _outputs=["arg_idx"])
        return min_values, min_indices


# Instantiated rewrite rules, exposed individually for selective use.
reduce_max_argmax_to_topk_rule = FuseReduceMaxArgMaxToTopK().rule()
reduce_min_argmin_to_topk_rule = FuseReduceMinArgMinToTopK().rule()

# Combined rule set applied by the rewriter pipeline (see onnxscript/rewriter/__init__.py).
rules = RewriteRuleSet([reduce_max_argmax_to_topk_rule, reduce_min_argmin_to_topk_rule])
Comment thread
danielhumanmod marked this conversation as resolved.
Outdated
Comment thread
justinchuby marked this conversation as resolved.
Outdated
Loading
Loading