Skip to content

Commit 8c176f2

Browse files
authored
Trace single op.SplitToSequence (#2817)
This pull request updates several operator registrations in `onnxscript/function_libs/torch_lib/ops/core.py` to mark them as `trace_only`, which affects how these ops are handled during tracing and export. The most important changes are: Operator registration updates (trace-only): * Marked the `aten::split` and `aten::split.Tensor` operators as `trace_only` in the `aten_split` function registration. * Marked the `aten::split_with_sizes` operator as `trace_only` in the `aten_split_with_sizes` function registration. * Marked the `aten::unsafe_split.Tensor` operator as `trace_only` in the `aten_unsafe_split` function registration.
1 parent 0d06d3b commit 8c176f2

File tree

1 file changed

+3
-3
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+3
-3
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9068,7 +9068,7 @@ def aten_sparse_mask(self: TensorType, mask: TensorType) -> TensorType:
90689068
raise NotImplementedError()
90699069

90709070

9071-
@torch_op(("aten::split", "aten::split.Tensor"))
9071+
@torch_op(("aten::split", "aten::split.Tensor"), trace_only=True)
90729072
def aten_split(self: TTensor, split_size: INT64, dim: int = 0) -> TTensor:
90739073
"""split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]"""
90749074

@@ -9081,7 +9081,7 @@ def aten_split_copy(self: TensorType, split_size: INT64, dim: int = 0) -> Tensor
90819081
raise NotImplementedError()
90829082

90839083

9084-
@torch_op("aten::split_with_sizes")
9084+
@torch_op(("aten::split_with_sizes",), trace_only=True)
90859085
def aten_split_with_sizes(self: TTensor, split_sizes: INT64, dim: int = 0) -> TTensor:
90869086
"""split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[]"""
90879087

@@ -10101,7 +10101,7 @@ def aten_unsafe_chunk(self: TensorType, chunks: int, dim: int = 0) -> TensorType
1010110101
raise NotImplementedError()
1010210102

1010310103

10104-
@torch_op("aten::unsafe_split.Tensor")
10104+
@torch_op("aten::unsafe_split.Tensor", trace_only=True)
1010510105
def aten_unsafe_split(self: TTensor, split_size: INT64, dim: int = 0) -> Sequence[TTensor]:
1010610106
"""unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]"""
1010710107

0 commit comments

Comments
 (0)