diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 6698a2ccdb..95fbe39811 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3814,11 +3814,15 @@ def aten_gather( else: return op.Expand(self, op.Shape(index)) - if len(index.shape) == 0: - return op.Identity(self) + is_scalar_index = len(index.shape) == 0 + if is_scalar_index: + index = op.Unsqueeze(index, [0]) index = op.Cast(index, to=INT64.dtype) result = op.GatherElements(self, index, axis=dim) + + if is_scalar_index: + result = op.Squeeze(result, [0]) return result