Merged
110 changes: 106 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -7292,12 +7292,114 @@ def aten_repeat(self: TTensor, repeats: Sequence[TInt]) -> TTensor:
return op.Tile(self_expanded, repeats)


def aten_repeat_interleave(
repeats: TensorType, output_size: Optional[int] = None
@torch_op("aten::repeat_interleave.self_int", trace_only=True)
def aten_repeat_interleave_self_int(
self: TensorType, repeats: int, dim: Optional[int] = None
) -> TensorType:
"""repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor"""
"""repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor

raise NotImplementedError()
    The trick is to insert a new axis after the repeated dimension, tile along it, and then reshape.

.. code-block:: python

x = torch.tensor([[0, 1, 2], [3, 4, 5]])
x.repeat_interleave(2, dim=0)

is equivalent to:

.. code-block:: python

x = torch.tensor([[0, 1, 2], [3, 4, 5]])
        x.repeat((1, 2)).reshape((-1, x.shape[1]))
"""
if dim is None:
raise NotImplementedError("No conversion available yet when dim is None.")

self_rank = len(self.shape)
pos_dim = (dim + self_rank) % self_rank
unsqueezed = op.Unsqueeze(self, [pos_dim + 1])
tiles = [1] * (self_rank + 1)
tiles[pos_dim + 1] = repeats
tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype))
tiled = op.Tile(unsqueezed, tile_repeat)
if self_rank == 1:
return op.Identity(tiled)
final_shape = op.Concat(
        op.Shape(self, start=0, end=pos_dim),
op.Constant(value_ints=[-1]),
        op.Shape(self, start=pos_dim + 1),

Collaborator: I suggest using pos_dim instead of dim ... otherwise, dim + 1 can cause problems when dim == -1.

Collaborator: Would be good to add test cases for negative dim, including -1.

Member Author: done


axis=0,
)
return op.Reshape(tiled, final_shape)
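
To sanity-check the Unsqueeze/Tile/Reshape trick outside the converter, here is a minimal sketch in plain PyTorch (not part of the PR; it mirrors the ONNX ops with tensor methods):

.. code-block:: python

    import torch

    x = torch.tensor([[0, 1, 2], [3, 4, 5]])
    repeats, dim = 2, 0
    expected = x.repeat_interleave(repeats, dim=dim)

    # Mirror the ONNX graph: Unsqueeze -> Tile -> Reshape.
    pos_dim = (dim + x.ndim) % x.ndim
    unsqueezed = x.unsqueeze(pos_dim + 1)              # shape (2, 1, 3)
    tiles = [1] * (x.ndim + 1)
    tiles[pos_dim + 1] = repeats
    tiled = unsqueezed.repeat(*tiles)                  # shape (2, 2, 3)
    final_shape = list(x.shape[:pos_dim]) + [-1] + list(x.shape[pos_dim + 1:])
    assert torch.equal(tiled.reshape(final_shape), expected)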


@torch_op("aten::repeat_interleave.Tensor", trace_only=True)
def aten_repeat_interleave_Tensor(
self: TensorType, repeats: Optional[TensorType] = None, dim: Optional[int] = None
) -> TensorType:
"""repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor

    When `repeats` is a tensor, each row of `self` is repeated
    a different number of times.
There are multiple strategies. Here is one.

.. code-block:: python

import torch

x = torch.tensor([[0, 1, 2], [3, 4, 5]])
times = torch.tensor([2, 3], dtype=torch.int64)
y = x.repeat_interleave(times, dim=0)
print("repeat_interleave")
print(y)

ci = times.cumsum(dim=0)
rows = torch.arange(ci[-1], dtype=torch.int64) < ci.reshape((-1, 1))
srows = times.shape[0] - rows.to(torch.int64).sum(axis=0)
indices = srows.reshape((-1, ))
print("decomposed")
print(x[indices, :])
"""
if repeats is None:
repeats = self
self = op.Range(0, op.Squeeze(op.Shape(repeats, start=-1), [0]), 1)
if dim is None:
# flatten
self = op.Reshape(self, [-1])
        rank = 1

Collaborator: Suggested change: rk = 1 -> rank = 1

Member Author: done

else:
        rank = len(self.shape)

    if rank > 2:
shape_x0 = op.Shape(self, start=0, end=1)
shape_x = op.Shape(self, start=1)
self = op.Reshape(self, op.Concat(shape_x0, [-1], axis=0))
    elif rank == 1:
shape_x = None
self = op.Reshape(self, [-1, 1])
else:
        if rank != 2:
            raise NotImplementedError(f"rank(self)={rank} not implemented for repeat_interleave")
        shape_x = None

ci = op.CumSum(repeats, [0])
last_ci = op.Gather(ci, [-1])
trange = op.Range(0, op.Squeeze(last_ci, [0]), 1)
rows = op.Less(trange, op.Unsqueeze(ci, [-1]))
srows = op.Sub(
op.Shape(self, start=0, end=1),
op.ReduceSum(op.Cast(rows, to=INT64.dtype), [0]),
)
indices = op.Reshape(srows, [-1])
values = op.GatherND(self, op.Unsqueeze(indices, [-1]))
    if rank == 2:
return values
    # rank > 2: restore the trailing dimensions; rank == 1: flatten.
    return op.Reshape(
        values,
        op.Concat([-1], shape_x, axis=0) if shape_x is not None else [-1],
)
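
The `repeats is None` branch above implements the repeats-only overload, where `torch.repeat_interleave(repeats)` expands `arange(len(repeats))`. A minimal sketch of the CumSum/Range/Less/ReduceSum index computation in plain PyTorch (an illustration, not part of the PR):

.. code-block:: python

    import torch

    repeats = torch.tensor([2, 3], dtype=torch.int64)
    expected = torch.repeat_interleave(repeats)   # tensor([0, 0, 1, 1, 1])

    # Same index computation as the ONNX graph above.
    ci = repeats.cumsum(dim=0)                    # [2, 5]
    trange = torch.arange(ci[-1])                 # [0, 1, 2, 3, 4]
    rows = trange < ci.unsqueeze(-1)              # (2, 5) boolean mask
    indices = repeats.shape[0] - rows.to(torch.int64).sum(dim=0)
    assert torch.equal(indices, expected)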


@torch_op("aten::reshape")
61 changes: 61 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -76,6 +76,67 @@
)
_testing.assert_onnx_program(onnx_program)

def test_repeat_interleave_integer_1(self):
class Model(torch.nn.Module):
def forward(self, x):
return torch.repeat_interleave(x, 3, dim=1)

onnx_program = torch.onnx.export(
Model(), (torch.randn(2, 3),), dynamo=True, optimize=False
)
_testing.assert_onnx_program(onnx_program)

def test_repeat_interleave_integer_2(self):
class Model(torch.nn.Module):
def forward(self, x):
return torch.repeat_interleave(x, 3, dim=1)

onnx_program = torch.onnx.export(
Model(), (torch.randn(2, 3, 4),), dynamo=True, optimize=False
)
_testing.assert_onnx_program(onnx_program)
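
Per the review request above, a test for negative dim could look like the following sketch (the test name is hypothetical; the merged PR may cover this case differently):

.. code-block:: python

    def test_repeat_interleave_integer_negative_dim(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                return torch.repeat_interleave(x, 3, dim=-1)

        onnx_program = torch.onnx.export(
            Model(), (torch.randn(2, 3),), dynamo=True, optimize=False
        )
        _testing.assert_onnx_program(onnx_program)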

def test_repeat_interleave_tensor(self):
class Model(torch.nn.Module):
def forward(self, x, ind):
return torch.repeat_interleave(x, ind, dim=0)

onnx_program = torch.onnx.export(
Model(),
(
torch.arange(6, dtype=torch.float32).reshape((2, 3)),
torch.tensor([1, 2], dtype=torch.int64),
),
dynamo=True,
optimize=False,
)
_testing.assert_onnx_program(onnx_program)

def test_repeat_interleave_tensor_none(self):
class Model(torch.nn.Module):
def forward(self, x, ind):
return torch.repeat_interleave(x, ind)

inputs = (
torch.arange(4, dtype=torch.float32).reshape((2, 2)),
torch.tensor([1, 2, 3, 2], dtype=torch.int64),
)
Check warning (Code scanning / CodeQL, "Variable defined multiple times"): this assignment to 'onnx_program' is unnecessary, as it is redefined before this value is used.

        onnx_program = torch.onnx.export(
Model(),
inputs,
dynamo=True,
optimize=False,
)
onnx_program = torch.onnx.export(
Model(),
inputs,
input_names=["x", "ind"],
output_names=["output"],
opset_version=18,
dynamo=True,
)
_testing.assert_onnx_program(onnx_program)

def test_sdpa_with_bool_attn_mask(self):
class ScaledDotProductAttention(torch.nn.Module):
def forward(self, query, key, value, attn_mask):
34 changes: 34 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -1250,6 +1250,40 @@ def _where_input_wrangler(
core_ops.aten_remainder,
),
TorchLibOpInfo("repeat", core_ops.aten_repeat),
TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_self_int)
.skip(
matcher=lambda sample: not isinstance(sample.kwargs.get("repeats", None), int),
reason=("ignore cases when repeasts is a Tensor"),
)
.skip(
dtypes=(torch.bool,),
reason="bool not supported",
)
.skip(
matcher=lambda sample: sample.kwargs.get("dim") is None,
reason="fixme: conversion not implemented if dim is None",
)
.skip(
matcher=lambda sample: sample.input.numel() == 0,
reason="fixme: conversion not implemented when input tensor is empty",
),
TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_Tensor)
.skip(
matcher=lambda sample: isinstance(sample.kwargs.get("repeats", None), int),
reason=("ignore cases when repeasts is an int"),
)
.skip(
dtypes=(torch.bool,),
reason="bool not supported",
)
.skip(
matcher=lambda sample: sample.kwargs.get("dim") is None,
reason="fixme: conversion not implemented if dim is None",
)
.skip(
matcher=lambda sample: sample.input.numel() == 0,
reason="fixme: conversion not implemented when input tensor is empty",
),
TorchLibOpInfo("reshape", core_ops.aten_reshape),
TorchLibOpInfo("resolve_conj", core_ops.aten_resolve_conj),
TorchLibOpInfo("resolve_neg", core_ops.aten_resolve_neg),
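
Both repeat_interleave entries above register the same torch op, so the paired skip matchers route each OpInfo sample to exactly one overload by the type of `repeats`. A minimal illustration (the `Sample` stand-in is hypothetical):

.. code-block:: python

    import torch

    class Sample:  # hypothetical stand-in for an OpInfo sample
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    int_sample = Sample(repeats=3, dim=0)
    tensor_sample = Sample(repeats=torch.tensor([1, 2]), dim=0)

    is_int = lambda s: isinstance(s.kwargs.get("repeats", None), int)
    # self_int skips tensor samples; the Tensor overload skips int samples,
    # so each sample is tested against exactly one converter.
    assert is_int(int_sample) and not is_int(tensor_sample)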