|
@torch_op("aten::gather", trace_only=True) |
|
def aten_gather( |
|
self: TReal, |
|
dim: int, |
|
index: TInt, |
|
sparse_grad: bool = False, |
|
) -> TReal: |
|
"""gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor""" |
|
|
|
if len(self.shape) == 0: |
|
if len(index.shape) == 0: |
|
return op.Identity(self) |
|
else: |
|
return op.Expand(self, op.Shape(index)) |
|
|
|
if len(index.shape) == 0: |
|
return op.Identity(self) |
|
|
|
index = op.Cast(index, to=INT64.dtype) |
|
result = op.GatherElements(self, index, axis=dim) |
|
return result |
When a scalar index is provided to aten::gather, the following behavior occurs:
>>> import torch
>>> x = torch.arange(3)
>>> index = torch.tensor(1)
>>> out = torch.ops.aten.gather.default(x, 0, index)
>>> out.shape
torch.Size([])
>>> out
tensor(1)
Therefore, when a 0-dimensional tensor is passed, we'll need to implement branching logic for squeeze/unsqueeze operations.
is_scalar_index = len(index.shape) == 0
if is_scalar_index:
index = op.Unsqueeze(index, [0])
index = op.Cast(index, to=INT64.dtype)
result = op.GatherElements(self, index, axis=dim)
if is_scalar_index:
result = op.Squeeze(result, [0])
onnxscript/onnxscript/function_libs/torch_lib/ops/core.py
Lines 3802 to 3822 in ea79022
When a scalar index is provided to aten::gather, the following behavior occurs:
Therefore, when a 0-dimensional tensor is passed, we'll need to implement branching logic for squeeze/unsqueeze operations.