Skip to content

Commit 208ed94

Browse files
authored
Merge branch 'main' into justinchu/debug-pass
2 parents 318025b + 1048faf commit 208ed94

File tree

17 files changed

+672
-237
lines changed

17 files changed

+672
-237
lines changed

docs/intermediate_representation/tensors.md

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -192,64 +192,88 @@ To fully support arrays from other frameworks, it is usually a good idea to crea
192192
import ctypes
193193
from typing import Any
194194
195+
import numpy.typing as npt
195196
import torch
197+
196198
from onnxscript import ir
197199
198-
# Define utilities to convert PyTorch data types so users do not need to specify manually
199-
_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
200-
torch.bfloat16: ir.DataType.BFLOAT16,
201-
torch.bool: ir.DataType.BOOL,
202-
torch.complex128: ir.DataType.COMPLEX128,
203-
torch.complex64: ir.DataType.COMPLEX64,
204-
torch.float16: ir.DataType.FLOAT16,
205-
torch.float32: ir.DataType.FLOAT,
206-
torch.float64: ir.DataType.DOUBLE,
207-
torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
208-
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
209-
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
210-
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
211-
torch.int16: ir.DataType.INT16,
212-
torch.int32: ir.DataType.INT32,
213-
torch.int64: ir.DataType.INT64,
214-
torch.int8: ir.DataType.INT8,
215-
torch.uint8: ir.DataType.UINT8,
216-
}
217-
218-
219-
def _torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
220-
return _TORCH_DTYPE_TO_ONNX[dtype]
221200
222201
class TorchTensor(ir.Tensor):
223-
def __init__(self, tensor: torch.Tensor):
202+
def __init__(
203+
self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
204+
):
224205
# Pass the tensor as the raw data to ir.Tensor's constructor
225-
super().__init__(tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype))
226206
227-
def __array__(self, dtype: Any = None) -> "np.ndarray":
228-
# numpy() calls __array__ in ir.Tensor
207+
_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
208+
torch.bfloat16: ir.DataType.BFLOAT16,
209+
torch.bool: ir.DataType.BOOL,
210+
torch.complex128: ir.DataType.COMPLEX128,
211+
torch.complex64: ir.DataType.COMPLEX64,
212+
torch.float16: ir.DataType.FLOAT16,
213+
torch.float32: ir.DataType.FLOAT,
214+
torch.float64: ir.DataType.DOUBLE,
215+
torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
216+
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
217+
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
218+
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
219+
torch.int16: ir.DataType.INT16,
220+
torch.int32: ir.DataType.INT32,
221+
torch.int64: ir.DataType.INT64,
222+
torch.int8: ir.DataType.INT8,
223+
torch.uint8: ir.DataType.UINT8,
224+
torch.uint16: ir.DataType.UINT16,
225+
torch.uint32: ir.DataType.UINT32,
226+
torch.uint64: ir.DataType.UINT64,
227+
}
228+
super().__init__(
229+
tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
230+
)
231+
232+
def numpy(self) -> npt.NDArray:
233+
self.raw: torch.Tensor
229234
if self.dtype == ir.DataType.BFLOAT16:
230-
return self.raw.view(torch.uint16).__array__(dtype)
235+
return self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
231236
if self.dtype in {
232237
ir.DataType.FLOAT8E4M3FN,
233238
ir.DataType.FLOAT8E4M3FNUZ,
234239
ir.DataType.FLOAT8E5M2,
235-
ir.DataType.FLOAT8E5M2FNUZ
240+
ir.DataType.FLOAT8E5M2FNUZ,
236241
}:
237-
return self.raw.view(torch.uint8).__array__(dtype)
238-
return self.raw.__array__(dtype)
242+
return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
243+
244+
return self.raw.numpy(force=True)
245+
246+
def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
247+
del copy # Unused, but needed for the signature
248+
if dtype is None:
249+
return self.numpy()
250+
return self.numpy().__array__(dtype)
239251
240252
def tobytes(self) -> bytes:
241253
# Implement tobytes to support native PyTorch types so we can use types like bfloat16
242254
# Reading from memory directly is also more efficient because
243255
# it avoids copying to a NumPy array
244-
tensor = self.raw.detach().cpu().contiguous()
256+
import torch._subclasses.fake_tensor
257+
258+
with torch._subclasses.fake_tensor.unset_fake_temporarily(): # pylint: disable=protected-access
259+
# Disable any fake mode so calling detach() etc. will return a real tensor
260+
tensor = self.raw.detach().cpu().contiguous()
261+
262+
if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor): # pylint: disable=protected-access
263+
raise TypeError(
264+
f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
265+
"with a tensor backed by real data using ONNXProgram.apply_weights() "
266+
"or save the model without initializers by setting include_initializers=False."
267+
)
268+
245269
return bytes(
246270
(ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
247271
tensor.data_ptr()
248272
)
249273
)
250274
251275
# Test the implementation
252-
torch_tensor = torch.tensor([1,2,3], dtype=torch.bfloat16)
276+
torch_tensor = torch.tensor([1, 2, 3], dtype=torch.bfloat16)
253277
tensor = TorchTensor(torch_tensor)
254278
print("tensor: ", tensor)
255279
print("numpy: ", tensor.numpy())

onnxscript/function_libs/torch_lib/ops/fft.py

Lines changed: 118 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -21,98 +21,33 @@
2121
from onnxscript.onnx_types import TensorType
2222

2323

24-
@torch_op(
25-
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
26-
private=True,
27-
complex=True,
28-
trace_only=True,
29-
)
3024
def _fftn_onnx_normalization(
31-
self,
32-
transformed: TFloat,
25+
self: TFloat,
3326
normalization: int,
34-
forward: bool,
35-
dims: Sequence[int],
36-
) -> TFloat:
37-
# Obtain the total_sample_count (n) for normalization
38-
self_shape = op.Shape(self)
39-
total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0)
40-
total_sample_count = op.CastLike(total_sample_count, transformed)
41-
42-
# Normalize the result
43-
# Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn
44-
# Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42
45-
if normalization == 1:
46-
# "forward" - normalize by 1/n
47-
if forward:
48-
result = op.Div(transformed, op.Sqrt(total_sample_count))
49-
else:
50-
result = op.Mul(transformed, op.Sqrt(total_sample_count))
51-
elif normalization == 2:
52-
# "ortho" - normalize by 1/sqrt(n)
53-
if forward:
54-
result = op.Div(transformed, total_sample_count)
55-
else:
56-
result = transformed
57-
else:
58-
# "backward" - no normalization
59-
if forward:
60-
result = transformed
61-
else:
62-
result = op.Mul(transformed, total_sample_count)
63-
64-
return result
65-
66-
67-
@torch_op(
68-
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
69-
trace_only=True,
70-
private=True,
71-
complex=True,
72-
)
73-
def _fftn_onnx(
74-
self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool
27+
signal_size: INT64,
28+
inverse: bool = False,
7529
) -> TFloat:
76-
"""Standard complex to complex or real to complex FFT (forward or backward).
77-
78-
This is a private shared function for implementing the various FFT functions.
79-
80-
Args:
81-
self: The input tensor.
82-
dims: The dimensions to apply FFT.
83-
normalization: The normalization mode.
84-
inverse: Whether to compute the inverse FFT.
85-
onesided: Whether to compute the one-sided FFT, which retains only the
86-
positive frequencies.
87-
88-
Returns:
89-
The transformed tensor.
90-
"""
91-
92-
# NOTE: trace_only because we need to process each dimension in a loop
93-
# NOTE: SymInt dim is not supported because DFT-17 needs a static axis
94-
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
95-
96-
# The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new
97-
# dimension at the beginning to represent the batch dimension.
98-
transformed = op.Unsqueeze(self, axes=[0])
99-
100-
# Add 1 to account for the batch dimension when counting axes from the left
101-
new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims]
102-
103-
for dim in new_dims[:-1]:
104-
transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False)
105-
106-
# Torch computes one-sided FFT on the last dimension only.
107-
if onesided:
108-
transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=True)
30+
"""Normalize in forward or backward direction."""
31+
# Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131
32+
# Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19
33+
# Modes:
34+
# 0: no normalization (backward)
35+
# 1: "ortho" - divide by 1/sqrt(signal_size) (ortho)
36+
# 2: divide by signal_size (forward)
37+
signal_size = op.CastLike(signal_size, self)
38+
if not inverse:
39+
# Forward normalization
40+
if normalization == 1:
41+
self = op.Div(self, op.Sqrt(signal_size))
42+
elif normalization == 2:
43+
self = op.Div(self, signal_size)
10944
else:
110-
transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=False)
111-
112-
# Remove the batch dimension
113-
transformed = op.Squeeze(transformed, axes=[0])
114-
115-
return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims)
45+
# Backward normalization, accounting for op.DFT already dividing by signal_size
46+
if normalization == 0:
47+
self = op.Mul(self, signal_size)
48+
elif normalization == 1:
49+
self = op.Mul(self, op.Sqrt(signal_size))
50+
return self
11651

11752

11853
@torch_op("aten::_fft_c2c", trace_only=True, complex=True)
@@ -124,39 +59,87 @@ def aten__fft_c2c(
12459
Standard complex to complex FFT (forward or backward).
12560
"""
12661

127-
# NOTE: trace_only because we need to negate forward
128-
# NOTE: SymInt dim is not supported because DFT-17 needs a static axis
129-
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
62+
# NOTE: SymInt dim is not supported because DFT-17 needs a static axis
13063

13164
# ONNX DFT input assumes the last dimension is the complex dimension.
132-
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
133-
dim = [d - 1 if d < 0 else d for d in dim]
134-
return _fftn_onnx(self, dim, normalization, inverse=not forward, onesided=False)
65+
66+
unsqueeze_first_dim = 0 in dim
67+
# 1. Add a new dimension for the end and batch dimension, if needed
68+
# 2. ONNX DFT input assumes the last dimension is the complex dimension.
69+
# If needed, add 1 to account for the batch dimension.
70+
71+
if unsqueeze_first_dim:
72+
transformed = op.Unsqueeze(self, axes=[0])
73+
dim = [d + 1 for d in dim]
74+
else:
75+
transformed = self
76+
77+
for dimension in reversed(dim):
78+
transformed = op.DFT(transformed, axis=dimension, inverse=not forward, onesided=False)
79+
transformed = _fftn_onnx_normalization(
80+
transformed,
81+
normalization,
82+
op.Shape(transformed, start=dimension, end=dimension + 1),
83+
not forward,
84+
)
85+
86+
if unsqueeze_first_dim:
87+
transformed = op.Squeeze(transformed, axes=[0])
88+
89+
return transformed
13590

13691

13792
@torch_op("aten::_fft_c2r", trace_only=True, complex=True)
13893
def aten__fft_c2r(
13994
self: TFloat,
14095
dim: Sequence[int],
14196
normalization: int,
142-
last_dim_size: INT64, # pylint: disable=unused-argument
97+
last_dim_size: INT64,
14398
) -> TFloat:
14499
"""_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor
145100
146-
Complex to real inverse FFT.
101+
Complex to real inverse FFT. Assumes that input tensor is output of previous FFT operation.
147102
"""
148-
149-
# TODO(justinchuby): Figure out what last_dim_size does
150-
151-
self_rank = len(self.shape)
152-
# ONNX DFT input assumes the last dimension is the complex dimension.
153-
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
154-
dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
155-
transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False)
156-
# Take only the real part
157-
real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1])
158-
159-
return op.Squeeze(real_part, axes=[-1])
103+
if len(dim) != 1:
104+
raise NotImplementedError("Only one dimension is supported for inverse FFT")
105+
106+
dimension = dim[0]
107+
unsqueeze_first_dim = dimension == 0
108+
# 1. Add a new dimension for batch dimension, if needed
109+
# 2. ONNX DFT input assumes the last dimension is the complex dimension.
110+
# If needed, add 1 to account for the batch dimension.
111+
112+
if unsqueeze_first_dim:
113+
transformed = op.Unsqueeze(self, axes=[0])
114+
dimension = 1
115+
else:
116+
transformed = self
117+
118+
# Torch truncates/pads on the last dimension only. Typically, the only valid values that can be passed
119+
# into PyTorch are n or n//2+1, where n is self.shape[dim[-1]], but this is not always the case, so we
120+
# place no such restriction on the ONNX side.
121+
transformed = op.DFT(
122+
transformed,
123+
dft_length=last_dim_size,
124+
axis=dimension,
125+
inverse=True,
126+
onesided=False,
127+
)
128+
transformed = _fftn_onnx_normalization(
129+
transformed,
130+
normalization,
131+
op.Shape(transformed, start=dimension, end=dimension + 1),
132+
inverse=True,
133+
)
134+
135+
if unsqueeze_first_dim:
136+
transformed = op.Squeeze(transformed, axes=[0])
137+
138+
# Remove the imaginary part
139+
transformed = op.Slice(transformed, [0], [1], [-1])
140+
transformed = op.Squeeze(transformed, axes=[-1])
141+
142+
return transformed
160143

161144

162145
@torch_op("aten::_fft_r2c", trace_only=True)
@@ -168,17 +151,37 @@ def aten__fft_r2c(
168151
Real to complex forward FFT.
169152
"""
170153

171-
# Add a new dimension at the end
172-
signal = op.Unsqueeze(self, axes=[-1])
173154
# No need to fill the imaginary part because ONNX DFT accepts real inputs
174155
# https://onnx.ai/onnx/operators/onnx__DFT.html#inputs
175156

176-
self_rank = len(self.shape)
177-
# ONNX DFT input assumes the last dimension is the complex dimension.
178-
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
179-
dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
157+
unsqueeze_first_dim = 0 in dim
158+
# 1. Add a new dimension for the end and batch dimension, if needed
159+
# 2. ONNX DFT input assumes the last dimension is the complex dimension.
160+
# If needed, add 1 to account for the batch dimension.
161+
162+
if unsqueeze_first_dim:
163+
transformed = op.Unsqueeze(self, axes=[0, -1])
164+
dim = [d + 1 for d in dim]
165+
else:
166+
transformed = op.Unsqueeze(self, axes=[-1])
167+
168+
for idx, dimension in enumerate(reversed(dim)):
169+
transformed = _fftn_onnx_normalization(
170+
transformed,
171+
normalization,
172+
op.Shape(transformed, start=dimension, end=dimension + 1),
173+
inverse=False,
174+
)
175+
if idx > 0:
176+
transformed = op.DFT(transformed, axis=dimension, inverse=False, onesided=False)
177+
else:
178+
# Torch computes one-sided FFT on the last dimension only.
179+
transformed = op.DFT(transformed, axis=dimension, inverse=False, onesided=onesided)
180+
181+
if unsqueeze_first_dim:
182+
transformed = op.Squeeze(transformed, axes=[0])
180183

181-
return _fftn_onnx(signal, dim, normalization, inverse=False, onesided=onesided)
184+
return transformed
182185

183186

184187
def aten_fft_fft(

0 commit comments

Comments
 (0)