Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
f4da452
tracked list
justinchuby May 7, 2025
e64efe1
[IR] Record owning graph for input/output and initializers
justinchuby May 7, 2025
d95d276
GraphOutputs
justinchuby May 7, 2025
89ef173
format
justinchuby May 7, 2025
3c859c9
no init
justinchuby May 7, 2025
eaf0ca6
owning_graph
justinchuby May 7, 2025
510d0b9
quote the type
justinchuby May 7, 2025
a5ac719
# pylint: disable=protected-access
justinchuby May 7, 2025
24c7a42
core
justinchuby May 7, 2025
f626f07
init
justinchuby May 7, 2025
e0e6f0a
Update onnxscript/ir/_core.py
justinchuby May 7, 2025
e80de25
GraphInitializers
justinchuby May 8, 2025
b6a0fe0
owning_graph
justinchuby May 8, 2025
4b48d0d
docs
justinchuby May 8, 2025
847b48c
quote
justinchuby May 8, 2025
8e72931
syntax
justinchuby May 8, 2025
6cf7883
Rename
justinchuby May 8, 2025
41db1b2
Rename to graph to match node
justinchuby May 8, 2025
078074e
wip tests
justinchuby May 8, 2025
45898a3
Fix graph
justinchuby May 8, 2025
f1b330c
test
justinchuby May 8, 2025
6c76fb3
Check
justinchuby May 8, 2025
a4e2fc7
More tests
justinchuby May 8, 2025
751db58
wip
justinchuby May 8, 2025
66dfdb2
Data structures
justinchuby May 8, 2025
6431711
tests
justinchuby May 8, 2025
8a1635d
Apply suggestions from code review
justinchuby May 8, 2025
1108dc0
logger
justinchuby May 8, 2025
e6aa051
Fix if
justinchuby May 8, 2025
cb226dd
Fix tests
justinchuby May 8, 2025
91992b0
logger
justinchuby May 8, 2025
3739ded
Merge branch 'main' into justinchu/tracked-lists-2
justinchuby May 8, 2025
f467daf
Fix __getitem__
justinchuby May 8, 2025
9397c46
Use booleans
justinchuby May 8, 2025
4c3afc8
test
justinchuby May 8, 2025
42d678c
ref counter
justinchuby May 8, 2025
0933963
RuntimeError
justinchuby May 8, 2025
22096b4
test
justinchuby May 8, 2025
2f62c50
Fix constant lifting
justinchuby May 8, 2025
de6ad6f
Update onnxscript/ir/_graph_containers.py
justinchuby May 8, 2025
added12
Fix test
justinchuby May 9, 2025
dc0b8e2
typing
justinchuby May 9, 2025
6964109
Merge branch 'main' into justinchu/tracked-lists-2
justinchuby May 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxscript/ir/_convenience/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
Returns:
A dictionary mapping names to values.
"""
values = {}
values: dict[str, _core.Value] = {}
values.update(graph.initializers)
# The names of the values can be None or "", which we need to exclude
for input in graph.inputs:
Expand Down
86 changes: 62 additions & 24 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
Generic,
Iterable,
Iterator,
MutableMapping,
MutableSequence,
NamedTuple,
OrderedDict,
Sequence,
Expand All @@ -46,6 +48,7 @@
from onnxscript.ir import (
_display,
_enums,
_graph_containers,
_linked_list,
_metadata,
_name_authority,
Expand Down Expand Up @@ -1746,18 +1749,19 @@

To find all the nodes that use this value as an input, call :meth:`uses`.

To check if the value is an output of a graph, call :meth:`is_graph_output`.
To check if the value is an is an input, output or initializer of a graph,
use :meth:`is_graph_input`, :meth:`is_graph_output` or :meth:`is_initializer`.

Attributes:
name: The name of the value. A value is always named when it is part of a graph.
shape: The shape of the value.
type: The type of the value.
metadata_props: Metadata.
Use :meth:`graph` to get the graph that owns the value.
"""

__slots__ = (
"_const_value",
"_graph",
"_index",
"_is_graph_input",
"_is_graph_output",
"_is_initializer",
"_metadata",
"_metadata_props",
"_name",
Expand Down Expand Up @@ -1808,6 +1812,14 @@
self._uses: dict[Usage, None] = {}
self.doc_string = doc_string

# The graph this value belongs to. It is set *only* when the value is added as
# a graph input, output or initializer.
# The four properties can only be set by the Graph class (_GraphIO and GraphInitializers).
self._graph: Graph | None = None
self._is_graph_input: bool = False
self._is_graph_output: bool = False
self._is_initializer: bool = False

def __repr__(self) -> str:
value_name = self.name if self.name else "anonymous:" + str(id(self))
type_text = f", type={self.type!r}" if self.type is not None else ""
Expand Down Expand Up @@ -1846,11 +1858,35 @@
return f"{{{self.const_value.__class__.__name__}(...)}}"
return ""

@property
def graph(self) -> Graph | None:
Comment thread
justinchuby marked this conversation as resolved.
"""Return the graph that defines this value.

When the value is an input/output/initializer of a graph, the owning graph
is that graph. When the value is an output of a node, the owning graph is the
graph that the node belongs to. When the value is not owned by any graph,
it returns ``None``.
"""
if self._graph is not None:
return self._graph
if self._producer is not None:
return self._producer.graph

Check warning on line 1873 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L1873

Added line #L1873 was not covered by tests
return None

def _owned_by_graph(self) -> bool:
"""Return True if the value is owned by a graph."""
result = self._is_graph_input or self._is_graph_output or self._is_initializer
if result:
assert self._graph is not None
return result

def producer(self) -> Node | None:
"""The node that produces this value.

When producer is ``None``, the value does not belong to a node, and is
typically a graph input or an initializer.
typically a graph input or an initializer. You can use :meth:`graph``
to find the graph that owns this value. Use :meth:`is_graph_input`, :meth:`is_graph_output`
or :meth:`is_initializer` to check if the value is an input, output or initializer of a graph.
"""
return self._producer

Expand Down Expand Up @@ -1986,15 +2022,17 @@
self._metadata_props = {}
return self._metadata_props

def is_graph_input(self) -> bool:
"""Whether the value is an input of a graph."""
return self._is_graph_input

def is_graph_output(self) -> bool:
"""Whether the value is an output of a graph."""
if (producer := self.producer()) is None:
return False
if (graph := producer.graph) is None:
return False
# Cannot use `in` because __eq__ may be defined by subclasses, even though
# it is not recommended
return any(output is self for output in graph.outputs)
return self._is_graph_output

def is_initializer(self) -> bool:
"""Whether the value is an initializer of a graph."""
return self._is_initializer


def Input(
Expand Down Expand Up @@ -2104,9 +2142,9 @@
self.name = name

# Private fields that are not to be accessed by any other classes
self._inputs = list(inputs)
self._outputs = list(outputs)
self._initializers = {}
self._inputs = _graph_containers.GraphInputs(self, inputs)
self._outputs = _graph_containers.GraphOutputs(self, outputs)
self._initializers = _graph_containers.GraphInitializers(self)
for initializer in initializers:
if isinstance(initializer, str):
raise TypeError(
Expand All @@ -2131,15 +2169,15 @@
self.extend(nodes)

@property
def inputs(self) -> list[Value]:
def inputs(self) -> MutableSequence[Value]:
return self._inputs

@property
def outputs(self) -> list[Value]:
def outputs(self) -> MutableSequence[Value]:
return self._outputs

@property
def initializers(self) -> dict[str, Value]:
def initializers(self) -> MutableMapping[str, Value]:
return self._initializers

def register_initializer(self, value: Value) -> None:
Expand All @@ -2159,15 +2197,15 @@
ValueError: If the initializer is produced by a node.
ValueError: If the value does not have its ``.const_value`` set.
"""
if not value.name:
raise ValueError(f"Initializer must have a name: {value!r}")
if value.name in self._initializers:
if self._initializers[value.name] is not value:
raise ValueError(
f"Initializer '{value.name}' is already registered, but"
" it is not the same object: existing={self._initializers[value.name]!r},"
f" new={value!r}"
)
if not value.name:
raise ValueError(f"Initializer must have a name: {value!r}")
if value.producer() is not None:
raise ValueError(
f"Value '{value!r}' is produced by a node and cannot be an initializer."
Expand Down Expand Up @@ -2858,11 +2896,11 @@
self._overload = value

@property
def inputs(self) -> list[Value]:
def inputs(self) -> MutableSequence[Value]:
return self._graph.inputs

@property
def outputs(self) -> list[Value]:
def outputs(self) -> MutableSequence[Value]:
return self._graph.outputs

@property
Expand Down
Loading
Loading