Skip to content

Commit 12234f8

Browse files
authored
[torchlib] Add missing dtype parameter to aten_mean_dim (#2885)
Fixes #2884 `aten_mean_dim` and `aten_mean_dim_complex` are missing the `dtype` keyword argument from their signatures, even though the ATen schema documents it (`ScalarType? dtype=None`). This causes a `TypeError` when PyTorch lowers `aten::mean.dim` with an explicit `dtype` - which happens for any model using `GlobalAveragePooling2D` (Keras/PyTorch). - Add `dtype: int = -1` to `aten_mean_dim`, with `op.Cast` when dtype is specified - Add `dtype: int = -1` to `aten_mean_dim_complex`, raising `NotImplementedError` for complex tensors Follows the same pattern used by `aten_sum_dim_IntList` and `aten_sum_dim_IntList_complex`.
1 parent 1ef0ec9 commit 12234f8

File tree

1 file changed

+14
-2
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+14
-2
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6231,19 +6231,25 @@ def aten_mean_complex(self: TReal) -> TReal:
62316231

62326232

62336233
@torch_op("aten::mean.dim", trace_only=True)
6234-
def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
6234+
def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False, dtype: int = -1) -> TReal:
62356235
"""mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
62366236

62376237
if len(self.shape) == 0:
62386238
result = self
62396239
else:
62406240
dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
62416241
result = op.ReduceMean(self, dims, keepdims=keepdim)
6242+
6243+
if dtype != -1 and dtype is not None:
6244+
result = op.Cast(result, to=dtype)
6245+
62426246
return result
62436247

62446248

62456249
@torch_op("aten::mean.dim", trace_only=True, complex=True)
6246-
def aten_mean_dim_complex(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
6250+
def aten_mean_dim_complex(
6251+
self: TReal, dim: INT64, keepdim: bool = False, dtype: int = -1
6252+
) -> TReal:
62476253
"""mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
62486254

62496255
if len(self.shape) == 1:
@@ -6254,6 +6260,12 @@ def aten_mean_dim_complex(self: TReal, dim: INT64, keepdim: bool = False) -> TRe
62546260
dim = op.Where(op.Less(dim, zero), op.Sub(dim, one), dim)
62556261
dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
62566262
result = op.ReduceMean(self, dims, keepdims=keepdim)
6263+
6264+
if dtype != -1 and dtype is not None:
6265+
raise NotImplementedError(
6266+
"support for the dtype argument is not implemented for complex tensors"
6267+
)
6268+
62576269
return result
62586270

62596271

0 commit comments

Comments
 (0)