Skip to content

Commit 8e0e86b

Browse files
authored
Add LazyTensor class to implement ir.TensorProtocol (#2232)
I used Copilot to help implement #2231. The LazyTensor class allows users to delay transformations to the tensors until serialization time, which helps with memory usage and avoids the need to cache or unload intermediate tensor data to disk. Example ```py >>> import numpy as np >>> from onnxscript import ir >>> weights = np.array([[1, 2, 3]]) >>> def create_tensor(): ... # Delay applying transformations to the weights ... weights_t = weights.transpose() ... return ir.tensor(weights_t) >>> lazy_tensor = ir.LazyTensor(create_tensor, dtype=ir.DataType.INT64, shape=ir.Shape([3, 1])) >>> print(lazy_tensor.numpy()) [[1] [2] [3]] >>> print(lazy_tensor.tobytes()) b'\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00' ``` Fixes #2231 --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/microsoft/onnxscript/pull/2232?shareId=b91d512a-8d84-4aca-8545-899243396be5).
1 parent 1d7aea3 commit 8e0e86b

4 files changed

Lines changed: 178 additions & 2 deletions

File tree

docs/ir/ir_api/core.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
ir.Tensor
4949
ir.ExternalTensor
5050
ir.StringTensor
51+
ir.LazyTensor
5152
```
5253

5354
## Enums

onnxscript/ir/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"Tensor",
1414
"ExternalTensor",
1515
"StringTensor",
16+
"LazyTensor",
1617
"SymbolicDim",
1718
"Shape",
1819
"TensorType",
@@ -104,6 +105,7 @@
104105
Graph,
105106
GraphView,
106107
Input,
108+
LazyTensor,
107109
Model,
108110
Node,
109111
OptionalType,

onnxscript/ir/_core.py

Lines changed: 142 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from typing import (
2727
AbstractSet,
2828
Any,
29+
Callable,
2930
Collection,
3031
Generic,
3132
Iterable,
@@ -113,7 +114,7 @@ def _repr_base(self) -> str:
113114
@property
114115
def size(self) -> int:
115116
"""The number of elements in the tensor."""
116-
return np.prod(self.shape.numpy()) # type: ignore[return-value,attr-defined]
117+
return math.prod(self.shape.numpy()) # type: ignore[attr-defined]
117118

118119
@property
119120
def nbytes(self) -> int:
@@ -853,6 +854,145 @@ def meta(self) -> _metadata.MetadataStore:
853854
return self._metadata
854855

855856

857+
class LazyTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
    """A tensor that lazily evaluates a function to get the actual tensor.

    This class takes a function returning an `ir.TensorProtocol`, a dtype, and a shape argument.
    The function is lazily evaluated to get the actual tensor when `tobytes()` or `numpy()` is called.

    Example::

        >>> import numpy as np
        >>> from onnxscript import ir
        >>> weights = np.array([[1, 2, 3]])
        >>> def create_tensor():  # Delay applying transformations to the weights
        ...     weights_t = weights.transpose()
        ...     return ir.tensor(weights_t)
        >>> lazy_tensor = ir.LazyTensor(create_tensor, dtype=ir.DataType.INT64, shape=ir.Shape([3, 1]))
        >>> print(lazy_tensor.numpy())
        [[1]
         [2]
         [3]]

    Attributes:
        func: The function that returns the actual tensor.
        dtype: The data type of the tensor.
        shape: The shape of the tensor.
        cache: Whether to cache the result of the function. If False,
            the function is called every time the tensor content is accessed.
            If True, the function is called only once and the result is cached in memory.
            Default is False.
        name: The name of the tensor.
        doc_string: The documentation string.
        metadata_props: The metadata properties.
    """

    __slots__ = (
        "_dtype",
        "_func",
        "_metadata",
        "_metadata_props",
        "_shape",
        "_tensor",
        "cache",
        "doc_string",
        "name",
    )

    def __init__(
        self,
        func: Callable[[], _protocols.TensorProtocol],
        dtype: _enums.DataType,
        shape: Shape,
        *,
        cache: bool = False,
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ) -> None:
        """Initialize a lazy tensor.

        Args:
            func: The function that returns the actual tensor.
            dtype: The data type of the tensor.
            shape: The shape of the tensor.
            cache: Whether to cache the result of the function.
            name: The name of the tensor.
            doc_string: The documentation string.
            metadata_props: The metadata properties.
        """
        self._func = func
        self._dtype = dtype
        self._shape = shape
        # Holds the evaluated tensor once computed; only populated when cache=True.
        self._tensor: _protocols.TensorProtocol | None = None
        self.cache = cache
        self.name = name
        self.doc_string = doc_string
        self._metadata: _metadata.MetadataStore | None = None
        self._metadata_props = metadata_props

    def _evaluate(self) -> _protocols.TensorProtocol:
        """Evaluate the function to get the actual tensor."""
        if not self.cache:
            return self._func()

        # Cache the tensor: call the function at most once and reuse the result.
        if self._tensor is None:
            self._tensor = self._func()
        return self._tensor

    def __array__(self, dtype: Any = None) -> np.ndarray:
        return self._evaluate().__array__(dtype)

    def __dlpack__(self, *, stream: Any = None) -> Any:
        return self._evaluate().__dlpack__(stream=stream)

    def __dlpack_device__(self) -> tuple[int, int]:
        return self._evaluate().__dlpack_device__()

    def __repr__(self) -> str:
        return f"{self._repr_base()}(func={self._func!r}, name={self.name!r})"

    @property
    def raw(self) -> Callable[[], _protocols.TensorProtocol]:
        # Expose the underlying factory function as the "raw" payload.
        return self._func

    @property
    def dtype(self) -> _enums.DataType:
        """The data type of the tensor. Immutable."""
        return self._dtype

    @property
    def shape(self) -> Shape:
        """The shape of the tensor. Immutable."""
        return self._shape

    def numpy(self) -> np.ndarray:
        """Return the tensor as a numpy array."""
        return self._evaluate().numpy()

    def tobytes(self) -> bytes:
        """Return the bytes of the tensor."""
        return self._evaluate().tobytes()

    @property
    def metadata_props(self) -> dict[str, str]:
        if self._metadata_props is None:
            self._metadata_props = {}
        return self._metadata_props

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
        to the ONNX proto.
        """
        if self._metadata is None:
            self._metadata = _metadata.MetadataStore()
        return self._metadata
995+
856996
class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
857997
__slots__ = ("_value",)
858998

@@ -2183,7 +2323,7 @@ def sort(self) -> None:
21832323
sorted_nodes_by_graph: dict[Graph, list[Node]] = {
21842324
graph: [] for graph in {node.graph for node in nodes if node.graph is not None}
21852325
}
2186-
# TODO: Explain why we need to store direct predecessors and children and why
2326+
# TODO(justinchuby): Explain why we need to store direct predecessors and children and why
21872327
# we only need to store the direct ones
21882328

21892329
# The depth of a node is defined as the number of direct children it has

onnxscript/ir/_core_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,5 +1312,38 @@ def test_as_graphs(self):
13121312
self.assertIsInstance(attr.as_graphs()[0], _core.Graph)
13131313

13141314

1315+
class LazyTensorTest(unittest.TestCase):
    """Tests for _core.LazyTensor."""

    def _tensor_fn(self):
        # Shared factory producing the INT64 tensor [1, 2, 3] used by all tests.
        return ir.tensor([1, 2, 3], dtype=ir.DataType.INT64)

    def test_lazy_tensor_initialization(self):
        lazy_tensor = _core.LazyTensor(
            self._tensor_fn, dtype=ir.DataType.INT64, shape=ir.Shape((3,))
        )
        self.assertEqual(lazy_tensor.dtype, ir.DataType.INT64)
        self.assertEqual(lazy_tensor.shape, (3,))

    def test_lazy_tensor_numpy(self):
        lazy_tensor = _core.LazyTensor(
            self._tensor_fn, dtype=ir.DataType.INT64, shape=ir.Shape((3,))
        )
        np.testing.assert_array_equal(lazy_tensor.numpy(), np.array([1, 2, 3]))

    def test_lazy_tensor_tobytes(self):
        lazy_tensor = _core.LazyTensor(
            self._tensor_fn, dtype=ir.DataType.INT64, shape=ir.Shape((3,))
        )
        self.assertEqual(
            lazy_tensor.tobytes(),
            b"\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00",
        )

    def test_lazy_tensor_with_cache_evaluates_function_only_once(self):
        # cache=True is the one behavior branch in LazyTensor._evaluate;
        # verify the factory is invoked exactly once across repeated accesses.
        call_count = 0

        def tensor_fn():
            nonlocal call_count
            call_count += 1
            return ir.tensor([1, 2, 3], dtype=ir.DataType.INT64)

        lazy_tensor = _core.LazyTensor(
            tensor_fn, dtype=ir.DataType.INT64, shape=ir.Shape((3,)), cache=True
        )
        np.testing.assert_array_equal(lazy_tensor.numpy(), np.array([1, 2, 3]))
        lazy_tensor.tobytes()
        self.assertEqual(call_count, 1)
1347+
13151348
if __name__ == "__main__":
13161349
unittest.main()

0 commit comments

Comments
 (0)