Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9200,16 +9200,21 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2:
def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
"""unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]"""

if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"):
# We can create a definitive split op if the input shape is static
# Only torch>=2.7 supports correctly generating the correct number of outputs for Split
if isinstance(self.shape[dim], int):
num_outputs = self.shape[dim]
if num_outputs != 1:
outputs = op.Split(self, axis=dim, num_outputs=num_outputs)
else:
outputs = [self]

return [op.Squeeze(out, [dim]) for out in outputs]
results = []
for i in range(num_outputs):
# Slice to get a single element at position i along dim
sliced = op.Slice(
self,
starts=op.Constant(value_ints=[i]),
ends=op.Constant(value_ints=[i + 1]),
axes=op.Constant(value_ints=[dim]),
)
# Squeeze to remove the dimension of size 1
squeezed = op.Squeeze(sliced, axes=[dim])
results.append(squeezed)
return results

return op.SplitToSequence(self, axis=dim, keepdims=False)

Expand Down
106 changes: 106 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,112 @@ def forward(self, x):
)
_testing.assert_onnx_program(onnx_program)

def test_unbind_dim0(self):
"""Test unbind along dimension 0"""

class UnbindModel(torch.nn.Module):
def forward(self, x):
tensors = torch.unbind(x, dim=0)
return sum(tensors)

model = UnbindModel()
x = torch.randn(3, 4, 5)
onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False)
_testing.assert_onnx_program(onnx_program)

def test_unbind_dim1(self):
"""Test unbind along dimension 1"""

class UnbindModel(torch.nn.Module):
def forward(self, x):
tensors = torch.unbind(x, dim=1)
return sum(tensors)

model = UnbindModel()
x = torch.randn(2, 3, 4)
onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False)
_testing.assert_onnx_program(onnx_program)

def test_unbind_negative_dim(self):
"""Test unbind with negative dimension"""

class UnbindModel(torch.nn.Module):
def forward(self, x):
tensors = torch.unbind(x, dim=-1)
return sum(tensors)

model = UnbindModel()
x = torch.randn(2, 3, 4)
onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False)
_testing.assert_onnx_program(onnx_program)

def test_unbind_size_one(self):
"""Test unbind with dimension of size 1"""

class UnbindModel(torch.nn.Module):
def forward(self, x):
tensors = torch.unbind(x, dim=0)
return tensors[0]

model = UnbindModel()
x = torch.randn(1, 4, 5)
onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False)
_testing.assert_onnx_program(onnx_program)

def test_unbind_with_lstm(self):
"""Test unbind in LSTM context"""

class LSTMDecoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.embedding = torch.nn.Embedding(100, 64)
self.lstm = torch.nn.LSTM(64, 64, 2, batch_first=True) # 2 layers
self.fc = torch.nn.Linear(64, 100)

def forward(self, tokens, h, c):
embedded = self.embedding(tokens).unsqueeze(0)
output, (h_out, c_out) = self.lstm(embedded, (h, c))
logits = self.fc(output.squeeze(0).squeeze(0))
return logits, h_out, c_out

model = LSTMDecoder()
model.eval()
tokens = torch.tensor([1])
h = torch.randn(2, 1, 64) # 2 layers
c = torch.randn(2, 1, 64) # 2 layers
onnx_program = torch.onnx.export(model, (tokens, h, c), dynamo=True, verbose=False)
_testing.assert_onnx_program(onnx_program)

def test_unbind_dynamic_dim0(self):
"""Test unbind with dynamic dimension 0 - triggers SplitToSequence"""

class UnbindModel(torch.nn.Module):
def forward(self, x):
tensors = torch.unbind(x, dim=0)
return sum(tensors)

model = UnbindModel()
x = torch.randn(3, 4, 5)
onnx_program = torch.onnx.export(
model, (x,), dynamo=True, verbose=False, dynamic_shapes=({0: "batch_size"},)
)
_testing.assert_onnx_program(onnx_program)

def test_unbind_dynamic_dim1(self):
"""Test unbind with dynamic dimension 1 - triggers SplitToSequence"""

class UnbindModel(torch.nn.Module):
def forward(self, x):
tensors = torch.unbind(x, dim=1)
return sum(tensors)

model = UnbindModel()
x = torch.randn(2, 3, 4)
onnx_program = torch.onnx.export(
model, (x,), dynamo=True, verbose=False, dynamic_shapes=({1: "seq_len"},)
)
_testing.assert_onnx_program(onnx_program)


if __name__ == "__main__":
unittest.main()
Loading