Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
from __future__ import annotations

from collections.abc import Sequence

import numpy.typing as npt
import onnx

Expand Down Expand Up @@ -78,3 +80,32 @@ def constant(
A constant node.
"""
return op.Constant(value=ir.tensor(array, dtype=ir.DataType(dtype)))


def merge_dims(dims: Sequence[int | INT64]) -> INT64:
    """Merge consecutive constant dimensions into a single shape tensor.

    Runs of static (Python ``int``) dimensions are folded into one
    ``Constant`` node each; dynamic (symbolic ``INT64``) dimensions are kept
    as-is. The pieces are then concatenated into a single 1-D INT64 tensor
    suitable as a shape input (e.g. for Expand/Reshape).

    Args:
        dims: A sequence of static ints and/or dynamic INT64 scalar values.

    Returns:
        A 1-D INT64 value representing the merged dimensions. An empty input
        yields an empty constant.
    """
    if not dims:
        return op.Constant(value_ints=[])

    result_dims = []
    i = 0
    n = len(dims)
    while i < n:
        dim = dims[i]
        if isinstance(dim, int):
            # Merge the whole run of consecutive constant dimensions into a
            # single constant node. (Index walk avoids O(n^2) list.pop(0).)
            run = [dim]
            i += 1
            while i < n and isinstance(dims[i], int):
                run.append(dims[i])
                i += 1
            result_dims.append(op.Constant(value_ints=run))
        else:
            # A dynamic dimension, just append it
            result_dims.append(dim)
            i += 1

    if len(result_dims) == 1:
        return result_dims[0]

    # Set the output type to INT64 so op.Concat can be used
    for dim in result_dims:
        dim.dtype = ir.DataType.INT64
    return op.Concat(*result_dims, axis=0)
67 changes: 32 additions & 35 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1523,10 +1523,10 @@
raise NotImplementedError()


@torch_op("aten::broadcast_to", trace_only=True)
def aten_broadcast_to(self: TTensor, size: Sequence[INT64]) -> TTensor:
    """broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
    # Fold consecutive static dims into constant nodes so Expand receives a
    # single 1-D INT64 shape input.
    size = common_ops.merge_dims(size)
    return op.Expand(self, size)


Expand Down Expand Up @@ -3286,20 +3286,20 @@

@torch_op("aten::empty.memory_format", trace_only=True)
def aten_empty(
    size: Sequence[INT64],
    dtype: int = FLOAT.dtype,
    layout: str = "",
    device: str = "",
    pin_memory: bool = False,
    memory_format: str = "",
) -> TensorType:  # type: ignore[type-var]
    """empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
    # layout/device/pin_memory/memory_format have no ONNX equivalent and are ignored.
    if dtype == -1:
        dtype = FLOAT.dtype
    # using Zeros to simulate empty()
    # Wrap the raw int dtype in ir.DataType for consistency with aten_zeros.
    zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
    size = common_ops.merge_dims(size)

    return op.Expand(zero, size)

Expand Down Expand Up @@ -3334,7 +3334,7 @@

@torch_op("aten::empty_strided", trace_only=True)
def aten_empty_strided(
size: INT64,
size: Sequence[INT64],
stride: INT64,
layout: str = "",
device: str = "",
Expand All @@ -3343,8 +3343,8 @@
# empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

# using Zeros to simulate empty()
size = op.Cast(size, to=INT64.dtype)
zero = op.Constant(value_float=0.0)
zero = op.Constant(value=ir.tensor(0, dtype=dtype))
Comment thread Fixed
Comment thread Fixed
size = common_ops.merge_dims(size)

return op.Expand(zero, size)

Expand Down Expand Up @@ -3392,13 +3392,12 @@


@torch_op("aten::expand", trace_only=True)
def aten_expand(self: TTensor, size: Sequence[INT64], implicit: bool = False) -> TTensor:
    """expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)"""
    # NOTE: PyTorch supports `not changing dim` by -1, but ONNX supports `not changing dim` by 1.
    # To support -1 dim, we need to convert -1 to 1.
    # NOTE(review): only static int dims are rewritten here; dynamic INT64
    # dims pass through — assumes symbolic dims are never -1. TODO confirm.
    size = [1 if s == -1 else s for s in size]
    return op.Expand(self, common_ops.merge_dims(size))


@torch_op("aten::expand_as", trace_only=True)
Expand Down Expand Up @@ -7300,12 +7299,10 @@
raise NotImplementedError()


@torch_op("aten::reshape", trace_only=True)
def aten_reshape(self: TTensor, shape: Sequence[INT64]) -> TTensor:
    """reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)"""
    # Merge static dims into constants; Reshape requires an INT64 shape input.
    shape = common_ops.merge_dims(shape)
    return op.Reshape(self, shape)


Expand Down Expand Up @@ -9045,23 +9042,22 @@


@torch_op(("aten::view", "aten::_unsafe_view"), trace_only=True)
def aten_view(self: TTensor, size: Sequence[INT64]) -> TTensor:
    """view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
    # Merge static dims into constants; Reshape requires an INT64 shape input.
    size = common_ops.merge_dims(size)
    # allowzero=1 keeps an explicit 0 in `size` as a literal zero dimension
    # (matching PyTorch) instead of copying the input's dimension.
    return op.Reshape(self, size, allowzero=True)


@torch_op(("aten::view", "aten::_unsafe_view"), complex=True, trace_only=True)
def aten_view_complex(self: TTensor, size: Sequence[INT64]) -> TTensor:
    """view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
    # Complex tensors are represented with a trailing dimension of length 2
    # (real, imaginary), so append 2 to the requested shape.
    complex_size = common_ops.merge_dims([*size, 2])
    return op.Reshape(self, complex_size, allowzero=True)


@torch_op("aten::view_as")
@torch_op("aten::view_as", trace_only=True)
def aten_view_as(self: TTensor, other: TTensor2) -> TTensor:
"""view_as(Tensor(a) self, Tensor other) -> Tensor(a)"""

Expand Down Expand Up @@ -9105,11 +9101,11 @@
return op.Identity(self)


@torch_op("aten::view_copy", trace_only=True)
def aten_view_copy(self: TTensor, size: Sequence[INT64]) -> TTensor:
    """view_copy(Tensor self, SymInt[] size) -> Tensor"""
    # Annotated Sequence[INT64] (not IntType) for consistency with
    # aten_view/aten_reshape: merge_dims takes a sequence of dims.
    size = common_ops.merge_dims(size)
    return op.Reshape(self, size)


Expand Down Expand Up @@ -9137,7 +9133,8 @@
"aten::where.ScalarSelf",
"aten::where.ScalarOther",
"aten::where.self",
)
),
trace_only=True,
)
def aten_where(condition: BOOL, self: TTensor, other: TTensor) -> TTensor:
"""where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"""
Expand All @@ -9153,7 +9150,7 @@

@torch_op("aten::zeros", trace_only=True)
def aten_zeros(
size: IntType,
size: Sequence[INT64],
dtype: int = FLOAT.dtype,
layout: str = "",
device: str = "",
Expand All @@ -9162,9 +9159,9 @@
"""zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
if dtype == -1:
dtype = FLOAT.dtype
size = op.Cast(size, to=INT64.dtype)
zero = op.Constant(value_float=0.0)
zero = op.Cast(zero, to=dtype)

zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
size = common_ops.merge_dims(size)

return op.Expand(zero, size)

Expand Down
Loading