-
Notifications
You must be signed in to change notification settings - Fork 109
Expand file tree
/
Copy pathcommon.py
More file actions
114 lines (87 loc) · 3.81 KB
/
common.py
File metadata and controls
114 lines (87 loc) · 3.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Common operators shared in the torchlib library."""
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
from __future__ import annotations
from collections.abc import Sequence
import numpy.typing as npt
import onnx
import onnxscript
import onnxscript.values
from onnxscript import BOOL, INT64, ir
from onnxscript import opset18 as op
from onnxscript.function_libs.torch_lib import _constants, tensor_typing
from onnxscript.function_libs.torch_lib.tensor_typing import RealType
from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT, TensorType
COMPLEX64_TYPE = COMPLEX64.dtype
COMPLEX128_TYPE = COMPLEX128.dtype
DOMAIN = f"{_constants.DOMAIN}.common"
common_opset = onnxscript.values.Opset(domain=DOMAIN, version=1)
@onnxscript.script(common_opset)
def Rank(input: tensor_typing.TTensor) -> INT64:
    """Take the rank of the input tensor.

    The rank is computed as Size(Shape(input)) — the number of entries in
    the tensor's shape — so a scalar yields 0. Registered as an ONNX
    function in the shared ``common`` opset so other torchlib functions
    can call it.
    """
    return op.Size(op.Shape(input))
@onnxscript.script(common_opset)
def IsScalar(input: tensor_typing.TTensor) -> BOOL:
    """Return whether the input has rank 0, or is a scalar.

    Implemented as Size(Shape(input)) == 0: a scalar has an empty shape,
    so the size of its shape tensor is 0. Registered as an ONNX function
    in the shared ``common`` opset.
    """
    return op.Equal(op.Size(op.Shape(input)), op.Constant(value_int=0))
def cast_to(a: RealType, dtype: int) -> RealType:
    """Cast input to dtype while handling complex types.

    Complex dtypes have no native ONNX runtime representation here, so they
    are encoded as real tensors with a trailing axis of size 2 holding the
    (real, imaginary) components; the imaginary part is zero-filled. For
    real dtypes this is a plain Cast.

    Args:
        a: The tensor to cast.
        dtype: The target ONNX data type.

    Returns:
        The casted tensor. For complex targets, the real-valued encoding
        with an extra last dimension of size 2.
    """
    # Traced function because different if branches return different dtypes
    # which is not supported in an ONNX function
    if dtype == COMPLEX128_TYPE:
        # Cast to the real representation of the complex type
        casted = op.Cast(a, to=DOUBLE.dtype)
        # Create a complex number: real part at [..., 0], zero imaginary at [..., 1]
        real_part = op.Unsqueeze(casted, axes=[-1])
        imag_part = op.Expand(op.Cast(0.0, to=DOUBLE.dtype), op.Shape(real_part))
        result = op.Concat(real_part, imag_part, axis=-1)
    elif dtype == COMPLEX64_TYPE:
        # Cast to the real representation of the complex type
        casted = op.Cast(a, to=FLOAT.dtype)
        # Create a complex number
        real_part = op.Unsqueeze(casted, axes=[-1])
        # Cast the zero fill explicitly to FLOAT, mirroring the COMPLEX128
        # branch above; previously a bare Python 0.0 relied on the default
        # constant dtype instead of stating it.
        imag_part = op.Expand(op.Cast(0.0, to=FLOAT.dtype), op.Shape(real_part))
        result = op.Concat(real_part, imag_part, axis=-1)
    else:
        # Cast to real numbers
        result = op.Cast(a, to=dtype)
    return result
def constant(
    array: npt.ArrayLike | onnx.TensorProto | ir.DLPackCompatible | ir.ArrayCompatible,
    dtype: int | onnx.TensorProto.DataType | ir.DataType,
) -> TensorType:
    """Utility for creating a constant tensor.

    Args:
        array: The array to convert to a constant tensor.
        dtype: The data type of the tensor.

    Returns:
        A constant node.
    """
    # Normalize dtype to ir.DataType, build the IR tensor, then wrap it in
    # a Constant node.
    tensor_value = ir.tensor(array, dtype=ir.DataType(dtype))
    return op.Constant(value=tensor_value)
def merge_dims(dims: Sequence[int | INT64]) -> INT64:
    """Merge consecutive constant dimensions.

    Runs of consecutive Python ints in ``dims`` are folded into a single
    Constant node; dynamic dimensions are reshaped to rank-1 tensors and
    kept separate. The resulting pieces are concatenated into one 1-D
    INT64 tensor, preserving order.

    Args:
        dims: Sequence of dimensions, each either a Python int (static)
            or an INT64 value (dynamic).

    Returns:
        A 1-D INT64 tensor containing all dimensions in order.
    """
    if not dims:
        return op.Constant(value_ints=ir.AttrInt64s("value_ints", []))
    result_dims = []
    # Walk with an index cursor instead of list.pop(0): pop(0) shifts the
    # whole list each call (O(n) per pop, O(n^2) total) and forced a copy
    # of `dims`; the cursor is O(n) overall with identical behavior.
    i = 0
    n = len(dims)
    while i < n:
        current_dim = dims[i]
        if isinstance(current_dim, int):
            # Fold this run of consecutive static dims into one Constant node.
            merged_dims = [current_dim]
            i += 1
            while i < n and isinstance(dims[i], int):
                merged_dims.append(dims[i])
                i += 1
            result_dims.append(op.Constant(value_ints=merged_dims))
        else:
            # A dynamic dimension: reshape to rank 1 so it can be concatenated.
            result_dims.append(
                op.Reshape(
                    current_dim, op.Constant(value_ints=ir.AttrInt64s("value_ints", [1]))
                )
            )
            i += 1
    if len(result_dims) == 1:
        return result_dims[0]
    # Set the output type to INT64 so op.Concat can be used
    for dim in result_dims:
        dim.dtype = ir.DataType.INT64
    return op.Concat(*result_dims, axis=0)