Skip to content

Commit ae8438a

Browse files
committed
Merge branch 'main' into justinchu/remove-legacy
2 parents 60f0c1b + d7955f4 commit ae8438a

28 files changed

+1164
-344
lines changed

.github/release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ changelog:
1818
- title: ONNX IR
1919
labels:
2020
- "module: IR"
21+
- "topic: passes"
2122
- title: Torch Lib
2223
labels:
2324
- "module: torchlib"
24-
- "topic: passes"
2525
- title: Documentation
2626
labels:
2727
- "topic: documentation"

docs/intermediate_representation/tensors.md

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -192,64 +192,88 @@ To fully support arrays from other frameworks, it is usually a good idea to crea
192192
import ctypes
193193
from typing import Any
194194
195+
import numpy.typing as npt
195196
import torch
197+
196198
from onnxscript import ir
197199
198-
# Define utilities to convert PyTorch data types so users do not need to specify manually
199-
_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
200-
torch.bfloat16: ir.DataType.BFLOAT16,
201-
torch.bool: ir.DataType.BOOL,
202-
torch.complex128: ir.DataType.COMPLEX128,
203-
torch.complex64: ir.DataType.COMPLEX64,
204-
torch.float16: ir.DataType.FLOAT16,
205-
torch.float32: ir.DataType.FLOAT,
206-
torch.float64: ir.DataType.DOUBLE,
207-
torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
208-
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
209-
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
210-
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
211-
torch.int16: ir.DataType.INT16,
212-
torch.int32: ir.DataType.INT32,
213-
torch.int64: ir.DataType.INT64,
214-
torch.int8: ir.DataType.INT8,
215-
torch.uint8: ir.DataType.UINT8,
216-
}
217-
218-
219-
def _torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
220-
return _TORCH_DTYPE_TO_ONNX[dtype]
221200
222201
class TorchTensor(ir.Tensor):
223-
def __init__(self, tensor: torch.Tensor):
202+
def __init__(
203+
self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
204+
):
224205
# Pass the tensor as the raw data to ir.Tensor's constructor
225-
super().__init__(tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype))
226206
227-
def __array__(self, dtype: Any = None) -> "np.ndarray":
228-
# numpy() calls __array__ in ir.Tensor
207+
_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
208+
torch.bfloat16: ir.DataType.BFLOAT16,
209+
torch.bool: ir.DataType.BOOL,
210+
torch.complex128: ir.DataType.COMPLEX128,
211+
torch.complex64: ir.DataType.COMPLEX64,
212+
torch.float16: ir.DataType.FLOAT16,
213+
torch.float32: ir.DataType.FLOAT,
214+
torch.float64: ir.DataType.DOUBLE,
215+
torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
216+
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
217+
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
218+
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
219+
torch.int16: ir.DataType.INT16,
220+
torch.int32: ir.DataType.INT32,
221+
torch.int64: ir.DataType.INT64,
222+
torch.int8: ir.DataType.INT8,
223+
torch.uint8: ir.DataType.UINT8,
224+
torch.uint16: ir.DataType.UINT16,
225+
torch.uint32: ir.DataType.UINT32,
226+
torch.uint64: ir.DataType.UINT64,
227+
}
228+
super().__init__(
229+
tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
230+
)
231+
232+
def numpy(self) -> npt.NDArray:
233+
self.raw: torch.Tensor
229234
if self.dtype == ir.DataType.BFLOAT16:
230-
return self.raw.view(torch.uint16).__array__(dtype)
235+
return self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
231236
if self.dtype in {
232237
ir.DataType.FLOAT8E4M3FN,
233238
ir.DataType.FLOAT8E4M3FNUZ,
234239
ir.DataType.FLOAT8E5M2,
235-
ir.DataType.FLOAT8E5M2FNUZ
240+
ir.DataType.FLOAT8E5M2FNUZ,
236241
}:
237-
return self.raw.view(torch.uint8).__array__(dtype)
238-
return self.raw.__array__(dtype)
242+
return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
243+
244+
return self.raw.numpy(force=True)
245+
246+
def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
247+
del copy # Unused, but needed for the signature
248+
if dtype is None:
249+
return self.numpy()
250+
return self.numpy().__array__(dtype)
239251
240252
def tobytes(self) -> bytes:
241253
# Implement tobytes to support native PyTorch types so we can use types like bfloat16
242254
# Reading from memory directly is also more efficient because
243255
# it avoids copying to a NumPy array
244-
tensor = self.raw.detach().cpu().contiguous()
256+
import torch._subclasses.fake_tensor
257+
258+
with torch._subclasses.fake_tensor.unset_fake_temporarily(): # pylint: disable=protected-access
259+
# Disable any fake mode so calling detach() etc. will return a real tensor
260+
tensor = self.raw.detach().cpu().contiguous()
261+
262+
if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor): # pylint: disable=protected-access
263+
raise TypeError(
264+
f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
265+
"with a tensor backed by real data using ONNXProgram.apply_weights() "
266+
"or save the model without initializers by setting include_initializers=False."
267+
)
268+
245269
return bytes(
246270
(ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
247271
tensor.data_ptr()
248272
)
249273
)
250274
251275
# Test the implementation
252-
torch_tensor = torch.tensor([1,2,3], dtype=torch.bfloat16)
276+
torch_tensor = torch.tensor([1, 2, 3], dtype=torch.bfloat16)
253277
tensor = TorchTensor(torch_tensor)
254278
print("tensor: ", tensor)
255279
print("numpy: ", tensor.numpy())

0 commit comments

Comments
 (0)