diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 4d83675589..b287cec057 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -9200,16 +9200,21 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2: def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: """unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]""" - if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"): - # We can create a definitive split op if the input shape is static - # Only torch>=2.7 supports correctly generating the correct number of outputs for Split + if isinstance(self.shape[dim], int): num_outputs = self.shape[dim] - if num_outputs != 1: - outputs = op.Split(self, axis=dim, num_outputs=num_outputs) - else: - outputs = [self] - - return [op.Squeeze(out, [dim]) for out in outputs] + results = [] + for i in range(num_outputs): + # Slice to get a single element at position i along dim + sliced = op.Slice( + self, + starts=op.Constant(value_ints=[i]), + ends=op.Constant(value_ints=[i + 1]), + axes=op.Constant(value_ints=[dim]), + ) + # Squeeze to remove the dimension of size 1 + squeezed = op.Squeeze(sliced, axes=[dim]) + results.append(squeezed) + return results return op.SplitToSequence(self, axis=dim, keepdims=False) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 1546de59bd..a2ced58c44 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -520,6 +520,112 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) + def test_unbind_dim0(self): + """Test unbind along dimension 0""" + + class UnbindModel(torch.nn.Module): + def forward(self, x): + tensors = torch.unbind(x, dim=0) + return sum(tensors) + + model = UnbindModel() + x = torch.randn(3, 4, 5) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_dim1(self): + """Test unbind along dimension 1""" + + class UnbindModel(torch.nn.Module): + def forward(self, x): + tensors = torch.unbind(x, dim=1) + return sum(tensors) + + model = UnbindModel() + x = torch.randn(2, 3, 4) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_negative_dim(self): + """Test unbind with negative dimension""" + + class UnbindModel(torch.nn.Module): + def forward(self, x): + tensors = torch.unbind(x, dim=-1) + return sum(tensors) + + model = UnbindModel() + x = torch.randn(2, 3, 4) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_size_one(self): + """Test unbind with dimension of size 1""" + + class UnbindModel(torch.nn.Module): + def forward(self, x): + tensors = torch.unbind(x, dim=0) + return tensors[0] + + model = UnbindModel() + x = torch.randn(1, 4, 5) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_with_lstm(self): + """Test unbind in LSTM context""" + + class LSTMDecoder(torch.nn.Module): + def __init__(self): + super().__init__() + self.embedding = torch.nn.Embedding(100, 64) + self.lstm = torch.nn.LSTM(64, 64, 2, batch_first=True) # 2 layers + self.fc = torch.nn.Linear(64, 100) + + def forward(self, tokens, h, c): + embedded = self.embedding(tokens).unsqueeze(0) + output, (h_out, c_out) = self.lstm(embedded, (h, c)) + logits = self.fc(output.squeeze(0).squeeze(0)) + return logits, h_out, c_out + + model = LSTMDecoder() + model.eval() + tokens = torch.tensor([1]) + h = torch.randn(2, 1, 64) # 2 layers + c = torch.randn(2, 1, 64) # 2 layers + onnx_program = torch.onnx.export(model, (tokens, h, c), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_dynamic_dim0(self): + """Test unbind with dynamic dimension 0 - triggers SplitToSequence""" + + class UnbindModel(torch.nn.Module): + def forward(self, x): + tensors = torch.unbind(x, dim=0) + return sum(tensors) + + model = UnbindModel() + x = torch.randn(3, 4, 5) + onnx_program = torch.onnx.export( + model, (x,), dynamo=True, verbose=False, dynamic_shapes=({0: "batch_size"},) + ) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_dynamic_dim1(self): + """Test unbind with dynamic dimension 1 - triggers SplitToSequence""" + + class UnbindModel(torch.nn.Module): + def forward(self, x): + tensors = torch.unbind(x, dim=1) + return sum(tensors) + + model = UnbindModel() + x = torch.randn(2, 3, 4) + onnx_program = torch.onnx.export( + model, (x,), dynamo=True, verbose=False, dynamic_shapes=({1: "seq_len"},) + ) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main()