Skip to content

[torchlib] Add missing dtype parameter to aten_mean_dim#2885

Merged
justinchuby merged 2 commits intomicrosoft:mainfrom
linusjuni:fix/aten-mean-dim-dtype
Apr 10, 2026
Merged

[torchlib] Add missing dtype parameter to aten_mean_dim#2885
justinchuby merged 2 commits intomicrosoft:mainfrom
linusjuni:fix/aten-mean-dim-dtype

Conversation

@linusjuni
Copy link
Copy Markdown
Contributor

Fixes #2884

aten_mean_dim and aten_mean_dim_complex are missing the dtype keyword argument from their signatures, even though the ATen schema documents it (ScalarType? dtype=None). This causes a TypeError when PyTorch lowers aten::mean.dim with an explicit dtype - which happens for any model using GlobalAveragePooling2D (Keras/PyTorch).

  • Add dtype: int = -1 to aten_mean_dim, with op.Cast when dtype is specified
  • Add dtype: int = -1 to aten_mean_dim_complex, raising NotImplementedError for complex tensors

Follows the same pattern used by aten_sum_dim_IntList and aten_sum_dim_IntList_complex.

The ATen schema for mean.dim documents dtype as an optional parameter,
but aten_mean_dim and aten_mean_dim_complex did not accept it. This
causes a TypeError when PyTorch lowers mean.dim with an explicit dtype
(e.g. from GlobalAveragePooling2D in Keras).

Add dtype: int = -1 to both functions, following the same pattern used
by aten_sum_dim_IntList.

Fixes microsoft#2884
@linusjuni
Copy link
Copy Markdown
Contributor Author

Hey! This is my first contribution to onnxscript. We ran into this while exporting Keras models to ONNX at work - GlobalAveragePooling2D lowers through aten::mean.dim with an explicit dtype, which hits the missing parameter. Happy to adjust anything if needed😄

Copy link
Copy Markdown
Collaborator

@justinchuby justinchuby left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Apr 10, 2026
@justinchuby justinchuby enabled auto-merge (squash) April 10, 2026 15:16
@codecov
Copy link
Copy Markdown

codecov Bot commented Apr 10, 2026

Codecov Report

❌ Patch coverage is 33.33333% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.03%. Comparing base (1ef0ec9) to head (abcd2b6).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
onnxscript/function_libs/torch_lib/ops/core.py 33.33% 2 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2885      +/-   ##
==========================================
- Coverage   72.04%   72.03%   -0.01%     
==========================================
  Files         239      239              
  Lines       29305    29309       +4     
  Branches     2880     2882       +2     
==========================================
  Hits        21112    21112              
- Misses       7216     7218       +2     
- Partials      977      979       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@justinchuby justinchuby disabled auto-merge April 10, 2026 15:39
@justinchuby
Copy link
Copy Markdown
Collaborator

@linusjuni could you fix lint?

@linusjuni
Copy link
Copy Markdown
Contributor Author

@microsoft-github-policy-service agree

@justinchuby justinchuby enabled auto-merge (squash) April 10, 2026 15:55
@justinchuby justinchuby merged commit 12234f8 into microsoft:main Apr 10, 2026
28 of 32 checks passed
justinchuby pushed a commit that referenced this pull request Apr 17, 2026
Fixes #2884

`aten_mean_dim` and `aten_mean_dim_complex` are missing the `dtype`
keyword argument from their signatures, even though the ATen schema
documents it (`ScalarType? dtype=None`). This causes a `TypeError` when
PyTorch lowers `aten::mean.dim` with an explicit `dtype` - which happens
for any model using `GlobalAveragePooling2D` (Keras/PyTorch).

- Add `dtype: int = -1` to `aten_mean_dim`, with `op.Cast` when dtype is
specified
- Add `dtype: int = -1` to `aten_mean_dim_complex`, raising
`NotImplementedError` for complex tensors

Follows the same pattern used by `aten_sum_dim_IntList` and
`aten_sum_dim_IntList_complex`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

module: torchlib Related to the torch/aten function lib in development

Projects

Development

Successfully merging this pull request may close these issues.

aten_mean_dim missing dtype parameter

2 participants