@@ -192,64 +192,88 @@ To fully support arrays from other frameworks, it is usually a good idea to crea
192192 import ctypes
193193 from typing import Any
194194
195+ import numpy.typing as npt
195196 import torch
197+
196198 from onnxscript import ir
197199
198- # Define utilities to convert PyTorch data types so users do not need to specify manually
199- _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
200- torch.bfloat16: ir.DataType.BFLOAT16,
201- torch.bool: ir.DataType.BOOL,
202- torch.complex128: ir.DataType.COMPLEX128,
203- torch.complex64: ir.DataType.COMPLEX64,
204- torch.float16: ir.DataType.FLOAT16,
205- torch.float32: ir.DataType.FLOAT,
206- torch.float64: ir.DataType.DOUBLE,
207- torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
208- torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
209- torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
210- torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
211- torch.int16: ir.DataType.INT16,
212- torch.int32: ir.DataType.INT32,
213- torch.int64: ir.DataType.INT64,
214- torch.int8: ir.DataType.INT8,
215- torch.uint8: ir.DataType.UINT8,
216- }
217-
218-
219- def _torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
220- return _TORCH_DTYPE_TO_ONNX[dtype]
221200
222201 class TorchTensor(ir.Tensor):
223- def __init__(self, tensor: torch.Tensor):
202+ def __init__(
203+ self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
204+ ):
224205 # Pass the tensor as the raw data to ir.Tensor's constructor
225- super().__init__(tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype))
226206
227- def __array__(self, dtype: Any = None) -> "np.ndarray":
228- # numpy() calls __array__ in ir.Tensor
207+ _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
208+ torch.bfloat16: ir.DataType.BFLOAT16,
209+ torch.bool: ir.DataType.BOOL,
210+ torch.complex128: ir.DataType.COMPLEX128,
211+ torch.complex64: ir.DataType.COMPLEX64,
212+ torch.float16: ir.DataType.FLOAT16,
213+ torch.float32: ir.DataType.FLOAT,
214+ torch.float64: ir.DataType.DOUBLE,
215+ torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
216+ torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
217+ torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
218+ torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
219+ torch.int16: ir.DataType.INT16,
220+ torch.int32: ir.DataType.INT32,
221+ torch.int64: ir.DataType.INT64,
222+ torch.int8: ir.DataType.INT8,
223+ torch.uint8: ir.DataType.UINT8,
224+ torch.uint16: ir.DataType.UINT16,
225+ torch.uint32: ir.DataType.UINT32,
226+ torch.uint64: ir.DataType.UINT64,
227+ }
228+ super().__init__(
229+ tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
230+ )
231+
232+ def numpy(self) -> npt.NDArray:
233+ self.raw: torch.Tensor
229234 if self.dtype == ir.DataType.BFLOAT16:
230- return self.raw.view(torch.uint16).__array__( dtype)
235+ return self.raw.view(torch.uint16).numpy(force=True).view(self. dtype.numpy() )
231236 if self.dtype in {
232237 ir.DataType.FLOAT8E4M3FN,
233238 ir.DataType.FLOAT8E4M3FNUZ,
234239 ir.DataType.FLOAT8E5M2,
235- ir.DataType.FLOAT8E5M2FNUZ
240+ ir.DataType.FLOAT8E5M2FNUZ,
236241 }:
237- return self.raw.view(torch.uint8).__array__(dtype)
238- return self.raw.__array__(dtype)
242+ return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
243+
244+ return self.raw.numpy(force=True)
245+
246+ def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
247+ del copy # Unused, but needed for the signature
248+ if dtype is None:
249+ return self.numpy()
250+ return self.numpy().__array__(dtype)
239251
240252 def tobytes(self) -> bytes:
241253 # Implement tobytes to support native PyTorch types so we can use types like bfloat16
242254 # Reading from memory directly is also more efficient because
243255 # it avoids copying to a NumPy array
244- tensor = self.raw.detach().cpu().contiguous()
256+ import torch._subclasses.fake_tensor
257+
258+ with torch._subclasses.fake_tensor.unset_fake_temporarily(): # pylint: disable=protected-access
259+ # Disable any fake mode so calling detach() etc. will return a real tensor
260+ tensor = self.raw.detach().cpu().contiguous()
261+
262+ if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor): # pylint: disable=protected-access
263+ raise TypeError(
264+ f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
265+ "with a tensor backed by real data using ONNXProgram.apply_weights() "
266+ "or save the model without initializers by setting include_initializers=False."
267+ )
268+
245269 return bytes(
246270 (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
247271 tensor.data_ptr()
248272 )
249273 )
250274
251275 # Test the implementation
252- torch_tensor = torch.tensor([1,2, 3], dtype=torch.bfloat16)
276+ torch_tensor = torch.tensor([1, 2, 3], dtype=torch.bfloat16)
253277 tensor = TorchTensor(torch_tensor)
254278 print("tensor: ", tensor)
255279 print("numpy: ", tensor.numpy())
0 commit comments