Skip to content

Commit b66859b

Browse files
committed
[torchlib] Add missing dtype parameter to aten_mean_dim
The ATen schema for mean.dim documents dtype as an optional parameter, but aten_mean_dim and aten_mean_dim_complex did not accept it. This causes a TypeError when PyTorch lowers mean.dim with an explicit dtype (e.g. from GlobalAveragePooling2D in Keras). Add dtype: int = -1 to both functions, following the same pattern used by aten_sum_dim_IntList. Fixes #2884
1 parent 1ef0ec9 commit b66859b

1 file changed

Lines changed: 12 additions & 2 deletions

File tree

  • onnxscript/function_libs/torch_lib/ops

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6231,19 +6231,23 @@ def aten_mean_complex(self: TReal) -> TReal:
62316231

62326232

62336233
@torch_op("aten::mean.dim", trace_only=True)
6234-
def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
6234+
def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False, dtype: int = -1) -> TReal:
62356235
"""mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
62366236

62376237
if len(self.shape) == 0:
62386238
result = self
62396239
else:
62406240
dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
62416241
result = op.ReduceMean(self, dims, keepdims=keepdim)
6242+
6243+
if dtype != -1 and dtype is not None:
6244+
result = op.Cast(result, to=dtype)
6245+
62426246
return result
62436247

62446248

62456249
@torch_op("aten::mean.dim", trace_only=True, complex=True)
6246-
def aten_mean_dim_complex(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
6250+
def aten_mean_dim_complex(self: TReal, dim: INT64, keepdim: bool = False, dtype: int = -1) -> TReal:
62476251
"""mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
62486252

62496253
if len(self.shape) == 1:
@@ -6254,6 +6258,12 @@ def aten_mean_dim_complex(self: TReal, dim: INT64, keepdim: bool = False) -> TRe
62546258
dim = op.Where(op.Less(dim, zero), op.Sub(dim, one), dim)
62556259
dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
62566260
result = op.ReduceMean(self, dims, keepdims=keepdim)
6261+
6262+
if dtype != -1 and dtype is not None:
6263+
raise NotImplementedError(
6264+
"support for the dtype argument is not implemented for complex tensors"
6265+
)
6266+
62576267
return result
62586268

62596269

0 commit comments

Comments
 (0)