Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 100 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7292,12 +7292,108 @@
return op.Tile(self_expanded, repeats)


@torch_op("aten::repeat_interleave.Scalar", trace_only=True)
def aten_repeat_interleave_int(
    self: TensorType, repeats: int, dim: Optional[int] = None
) -> TensorType:
    """repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None) -> Tensor

    The trick is to insert a new axis right after ``dim``, tile only that
    new axis, then flatten the two axes back together so that every element
    along ``dim`` appears ``repeats`` times consecutively.

    .. code-block:: python

        x = torch.tensor([[0, 1, 2], [3, 4, 5]])
        x.repeat_interleave(2, dim=1)

    is equivalent to:

    .. code-block:: python

        x = torch.tensor([[0, 1, 2], [3, 4, 5]])
        x.reshape((2, 3, 1)).repeat((1, 1, 2)).reshape((2, x.shape[1] * 2))

    Args:
        self: tensor to repeat.
        repeats: number of times each slice along ``dim`` is repeated.
        dim: axis along which to repeat; ``None`` is not supported yet.

    Raises:
        NotImplementedError: when ``dim`` is None (PyTorch would flatten
            the input first in that case).
    """
    if dim is None:
        raise NotImplementedError("No conversion available yet when dim is None.")

    self_rank = len(self.shape)
    # Normalize negative dims once so the rest of the code is uniform.
    pos_dim = (dim + self_rank) % self_rank
    unsqueezed = op.Unsqueeze(self, [pos_dim + 1])
    # Tile repetitions: 1 on every axis except the freshly inserted one.
    # (The previous version concatenated `repeats` at the END of a ones
    # vector, which tiled the last axis regardless of `dim`.)
    reps = [1] * (self_rank + 1)
    reps[pos_dim + 1] = repeats
    tiled = op.Tile(unsqueezed, reps)
    # Merge the tiled axis back into `dim`.
    return aten_flatten(tiled, pos_dim, pos_dim + 1)


@torch_op("aten::repeat_interleave.Tensor", trace_only=True)
def aten_repeat_interleave_Tensor(
    self: TensorType, repeats: Optional[TensorType] = None, dim: Optional[int] = None
) -> TensorType:
    """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor

    When ``repeats`` is a tensor, each slice along the interleaved axis is
    repeated its own number of times. The decomposition mirrors this
    PyTorch snippet:

    .. code-block:: python

        import torch

        x = torch.tensor([[0, 1, 2], [3, 4, 5]])
        times = torch.tensor([2, 3], dtype=torch.int64)
        y = x.repeat_interleave(times, dim=0)

        # equivalent decomposition:
        ci = times.cumsum(dim=0)
        rows = torch.arange(ci[-1], dtype=torch.int64) < ci.reshape((-1, 1))
        srows = times.shape[0] - rows.to(torch.int64).sum(axis=0)
        indices = srows.reshape((-1,))
        y2 = x[indices, :]

    Args:
        self: tensor to repeat (or, when ``repeats`` is None, the repeats
            tensor itself per the aten overload).
        repeats: per-slice repetition counts; 1-D int64 tensor.
        dim: axis along which to repeat; ``None`` flattens first.
    """
    if repeats is None:
        # aten::repeat_interleave(Tensor repeats): `self` IS the repeats
        # tensor and the values to repeat are the indices 0..len(repeats)-1.
        repeats = self
        self = op.Range(0, op.Squeeze(op.Shape(repeats, start=-1), [0]), 1)
    if dim is None:
        # PyTorch semantics: flatten, then interleave element-wise.
        self = op.Reshape(self, [-1])
        rank = 1
    else:
        rank = len(self.shape)

    # NOTE(review): a non-None `dim` is never transposed to axis 0 here, so
    # only dim=0 (or dim=None) appears to be handled — confirm before
    # relying on other axes.
    # Normalize `self` to 2-D with the interleaved axis as axis 0.
    if rank > 2:
        shape_x0 = op.Shape(self, start=0, end=1)
        # Trailing dimensions, restored after the gather.
        shape_x = op.Shape(self, start=1)
        self = op.Reshape(self, op.Concat(shape_x0, [-1], axis=0))
    elif rank == 1:
        shape_x = None
        self = op.Reshape(self, [-1, 1])
    elif rank == 2:
        shape_x = None
    else:
        raise NotImplementedError(f"rank(self)={rank} not implemented for repeat_interleave")

    # Row i of the output comes from input row srows[i]: count, for each
    # output position, how many cumulative-sum thresholds it has passed
    # (see the decomposition in the docstring).
    ci = op.CumSum(repeats, [0])
    last_ci = op.Gather(ci, [-1])
    trange = op.Range(0, op.Squeeze(last_ci, [0]), 1)
    rows = op.Less(trange, op.Unsqueeze(ci, [-1]))
    srows = op.Sub(
        op.Shape(self, start=0, end=1),
        op.ReduceSum(op.Cast(rows, to=INT64.dtype), [0]),
    )
    indices = op.Reshape(srows, [-1])
    values = op.GatherND(self, op.Unsqueeze(indices, [-1]))
    if rank == 2:
        return values
    if shape_x is not None:
        # rank > 2: restore the trailing dimensions.
        return op.Reshape(values, op.Concat([-1], shape_x, axis=0))
    # rank == 1 (including the dim=None flatten path): back to 1-D.
    # (The previous version asserted shape_x was not None, which raised
    # AssertionError on exactly this path before its own fallback ran.)
    return op.Reshape(values, [-1])


@torch_op("aten::reshape")
Expand Down
42 changes: 42 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,48 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
_testing.assert_onnx_program(onnx_program)

def test_repeat_interleave_integer(self):
Comment thread
xadupre marked this conversation as resolved.
Outdated
class Model(torch.nn.Module):
def forward(self, x):
return torch.repeat_interleave(x, 3, dim=1)

onnx_program = torch.onnx.export(
Model(), (torch.randn(2, 3),), dynamo=True, optimize=False
)
_testing.assert_onnx_program(onnx_program)

def test_repeat_interleave_tensor(self):
class Model(torch.nn.Module):
def forward(self, x, ind):
return torch.repeat_interleave(x, ind, dim=0)

onnx_program = torch.onnx.export(
Model(),
(
torch.arange(6, dtype=torch.float32).reshape((2, 3)),
torch.tensor([1, 2], dtype=torch.int64),
),
dynamo=True,
optimize=False,
)
_testing.assert_onnx_program(onnx_program)

def test_repeat_interleave_tensor_none(self):
class Model(torch.nn.Module):
def forward(self, x, ind):
return torch.repeat_interleave(x, ind)

onnx_program = torch.onnx.export(
Comment thread Fixed
Comment thread Fixed
Comment thread Fixed

Check warning

Code scanning / CodeQL

Variable defined multiple times Warning test

This assignment to 'onnx_program' is unnecessary as it is
redefined
before this value is used.
Model(),
(
torch.arange(4, dtype=torch.float32).reshape((2, 2)),
torch.tensor([1, 2, 3, 2], dtype=torch.int64),
),
dynamo=True,
optimize=False,
)
_testing.assert_onnx_program(onnx_program)


# Allow running this test file directly: `python e2e_ops_tests.py`.
if __name__ == "__main__":
    unittest.main()
3 changes: 3 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,6 +1240,9 @@ def _where_input_wrangler(
core_ops.aten_remainder,
),
TorchLibOpInfo("repeat", core_ops.aten_repeat),
# needs to be split into two cases (OpInfo covers both overloads with one name)
# TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_int),
Comment thread
xadupre marked this conversation as resolved.
Outdated
# TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_Tensor),
TorchLibOpInfo("reshape", core_ops.aten_reshape),
TorchLibOpInfo("resolve_conj", core_ops.aten_resolve_conj),
TorchLibOpInfo("resolve_neg", core_ops.aten_resolve_neg),
Expand Down
Loading