Skip to content

Commit 6614c55

Browse files
Copilotjustinchuby
andcommitted
fix: correct scatter_reduce mean for include_self=False and lint
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
1 parent 588f9b0 commit 6614c55

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8728,14 +8728,23 @@ def aten_scatter_reduce(
87288728
total_count = op.Add(
87298729
op.ConstantOfShape(op.Shape(self), value=one_val), scatter_count
87308730
)
8731+
result = op.Div(total_sum, total_count)
87318732
else:
8732-
total_sum = scatter_sum
8733-
# Avoid division by zero: where count == 0, sum is also 0, so 0/1 = 0 is correct
8734-
total_count = op.Max(
8733+
# For positions with scattered values: mean = sum / count
8734+
# For positions with no scattered values: preserve self[i] (include_self=False
8735+
# means the initial self value is not part of the reduction, but it is the
8736+
# output for positions with no incoming values)
8737+
safe_count = op.Max(
87358738
scatter_count,
87368739
op.ConstantOfShape(op.Shape(scatter_count), value=one_val),
87378740
)
8738-
result = op.Div(total_sum, total_count)
8741+
mean_vals = op.Div(scatter_sum, safe_count)
8742+
# Where count == 0, keep original self value; otherwise use computed mean
8743+
no_scatter = op.Equal(
8744+
scatter_count,
8745+
op.ConstantOfShape(op.Shape(scatter_count), value=zero_val),
8746+
)
8747+
result = op.Where(no_scatter, self, mean_vals)
87398748
if self_is_scalar:
87408749
result = op.Squeeze(result)
87418750
return result

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def test_scatter_reduce_mean_include_self_true(self):
6262
"""Test scatter_reduce with reduce='mean' and include_self=True."""
6363

6464
class ScatterMeanIncludeSelfModel(torch.nn.Module):
65-
def forward(self, x: torch.Tensor, index: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
65+
def forward(
66+
self, x: torch.Tensor, index: torch.Tensor, src: torch.Tensor
67+
) -> torch.Tensor:
6668
x = x.clone()
6769
return x.scatter_reduce(0, index, src, reduce="mean", include_self=True)
6870

0 commit comments

Comments
 (0)