Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
f4da452
tracked list
justinchuby May 7, 2025
e64efe1
[IR] Record owning graph for input/output and initializers
justinchuby May 7, 2025
d95d276
GraphOutputs
justinchuby May 7, 2025
89ef173
format
justinchuby May 7, 2025
3c859c9
no init
justinchuby May 7, 2025
eaf0ca6
owning_graph
justinchuby May 7, 2025
510d0b9
quote the type
justinchuby May 7, 2025
a5ac719
# pylint: disable=protected-access
justinchuby May 7, 2025
24c7a42
core
justinchuby May 7, 2025
f626f07
init
justinchuby May 7, 2025
e0e6f0a
Update onnxscript/ir/_core.py
justinchuby May 7, 2025
e80de25
GraphInitializers
justinchuby May 8, 2025
b6a0fe0
owning_graph
justinchuby May 8, 2025
4b48d0d
docs
justinchuby May 8, 2025
847b48c
quote
justinchuby May 8, 2025
8e72931
syntax
justinchuby May 8, 2025
6cf7883
Rename
justinchuby May 8, 2025
41db1b2
Rename to graph to match node
justinchuby May 8, 2025
078074e
wip tests
justinchuby May 8, 2025
45898a3
Fix graph
justinchuby May 8, 2025
f1b330c
test
justinchuby May 8, 2025
6c76fb3
Check
justinchuby May 8, 2025
a4e2fc7
More tests
justinchuby May 8, 2025
751db58
wip
justinchuby May 8, 2025
66dfdb2
Data structures
justinchuby May 8, 2025
6431711
tests
justinchuby May 8, 2025
8a1635d
Apply suggestions from code review
justinchuby May 8, 2025
1108dc0
logger
justinchuby May 8, 2025
e6aa051
Fix if
justinchuby May 8, 2025
cb226dd
Fix tests
justinchuby May 8, 2025
91992b0
logger
justinchuby May 8, 2025
3739ded
Merge branch 'main' into justinchu/tracked-lists-2
justinchuby May 8, 2025
f467daf
Fix __getitem__
justinchuby May 8, 2025
9397c46
Use booleans
justinchuby May 8, 2025
4c3afc8
test
justinchuby May 8, 2025
42d678c
ref counter
justinchuby May 8, 2025
0933963
RuntimeError
justinchuby May 8, 2025
22096b4
test
justinchuby May 8, 2025
2f62c50
Fix constant lifting
justinchuby May 8, 2025
de6ad6f
Update onnxscript/ir/_graph_containers.py
justinchuby May 8, 2025
added12
Fix test
justinchuby May 9, 2025
dc0b8e2
typing
justinchuby May 9, 2025
6964109
Merge branch 'main' into justinchu/tracked-lists-2
justinchuby May 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxscript/ir/_convenience/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
Returns:
A dictionary mapping names to values.
"""
values = {}
values: dict[str, _core.Value] = {}
values.update(graph.initializers)
# The names of the values can be None or "", which we need to exclude
for input in graph.inputs:
Expand Down
53 changes: 27 additions & 26 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import contextlib
import dataclasses
import heapq
import logging
import math
import mmap
import os
Expand All @@ -32,6 +31,7 @@
Generic,
Iterable,
Iterator,
MutableMapping,
MutableSequence,
NamedTuple,
OrderedDict,
Expand Down Expand Up @@ -82,9 +82,6 @@
)


logger = logging.getLogger(__name__)


def _compatible_with_numpy(obj: Any) -> TypeGuard[_protocols.ArrayCompatible]:
"""Use this function to check if an object is compatible with numpy.

Expand Down Expand Up @@ -1760,10 +1757,11 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):

__slots__ = (
"_const_value",
"_graph_initializer_of",
"_graph_input_of",
"_graph_output_of",
"_graph",
"_index",
"_is_graph_input",
"_is_graph_output",
"_is_initializer",
"_metadata",
"_metadata_props",
"_name",
Expand Down Expand Up @@ -1816,10 +1814,11 @@ def __init__(

# The graph this value belongs to. It is set *only* when the value is added as
# a graph input, output or initializer.
# The two properties can only be set by the Graph class (GraphIO).
self._graph_input_of: Graph | None = None
self._graph_output_of: Graph | None = None
self._graph_initializer_of: Graph | None = None
# The four properties can only be set by the Graph class (_GraphIO and GraphInitializers).
self._graph: Graph | None = None
self._is_graph_input: bool = False
self._is_graph_output: bool = False
self._is_initializer: bool = False

def __repr__(self) -> str:
value_name = self.name if self.name else "anonymous:" + str(id(self))
Expand Down Expand Up @@ -1868,17 +1867,19 @@ def graph(self) -> Graph | None:
graph that the node belongs to. When the value is not owned by any graph,
it returns ``None``.
"""
if self._graph_initializer_of is not None:
return self._graph_initializer_of
if self._graph_input_of is not None:
return self._graph_input_of
if self._graph_output_of is not None:
return self._graph_output_of

if self._graph is not None:
return self._graph
if self._producer is not None:
return self._producer.graph
return None

def _owned_by_graph(self) -> bool:
"""Return True if the value is owned by a graph."""
result = self._is_graph_input or self._is_graph_output or self._is_initializer
if result:
assert self._graph is not None
return result

def producer(self) -> Node | None:
"""The node that produces this value.

Expand Down Expand Up @@ -2023,15 +2024,15 @@ def metadata_props(self) -> dict[str, str]:

def is_graph_input(self) -> bool:
"""Whether the value is an input of a graph."""
return self._graph_input_of is not None
return self._is_graph_input

def is_graph_output(self) -> bool:
"""Whether the value is an output of a graph."""
return self._graph_output_of is not None
return self._is_graph_output

def is_initializer(self) -> bool:
"""Whether the value is an initializer of a graph."""
return self._graph_initializer_of is not None
return self._is_initializer


def Input(
Expand Down Expand Up @@ -2176,7 +2177,7 @@ def outputs(self) -> MutableSequence[Value]:
return self._outputs

@property
def initializers(self) -> dict[str, Value]:
def initializers(self) -> MutableMapping[str, Value]:
return self._initializers

def register_initializer(self, value: Value) -> None:
Expand All @@ -2196,15 +2197,15 @@ def register_initializer(self, value: Value) -> None:
ValueError: If the initializer is produced by a node.
ValueError: If the value does not have its ``.const_value`` set.
"""
if not value.name:
raise ValueError(f"Initializer must have a name: {value!r}")
if value.name in self._initializers:
if self._initializers[value.name] is not value:
raise ValueError(
f"Initializer '{value.name}' is already registered, but"
" it is not the same object: existing={self._initializers[value.name]!r},"
f" new={value!r}"
)
if not value.name:
raise ValueError(f"Initializer must have a name: {value!r}")
if value.producer() is not None:
raise ValueError(
f"Value '{value!r}' is produced by a node and cannot be an initializer."
Expand Down Expand Up @@ -2895,11 +2896,11 @@ def overload(self, value: str) -> None:
self._overload = value

@property
def inputs(self) -> list[Value]:
def inputs(self) -> MutableSequence[Value]:
return self._graph.inputs

@property
def outputs(self) -> list[Value]:
def outputs(self) -> MutableSequence[Value]:
return self._graph.outputs

@property
Expand Down
88 changes: 87 additions & 1 deletion onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1198,7 +1198,7 @@ def test_append_to_inputs(self):
def test_append_input_raises_when_input_belongs_to_another_graph(self):
other_graph = _core.Graph(inputs=(), outputs=(), nodes=())
other_graph.inputs.append(self.value1)
with self.assertRaisesRegex(ValueError, "is already an input of a different graph"):
with self.assertRaisesRegex(ValueError, "is already owned by a different graph"):
self.graph.inputs.append(self.value1)
# Append is ok after the value is removed from the old graph
other_graph.inputs.clear()
Expand All @@ -1223,6 +1223,14 @@ def test_pop_from_inputs(self):
self.assertFalse(self.value1.is_graph_input())
self.assertIsNone(self.value1.graph)

def test_pop_from_duplicated_inputs(self):
self.graph.inputs.extend([self.value1, self.value1])
popped = self.graph.inputs.pop()
self.assertIs(popped, self.value1)
self.assertIn(self.value1, self.graph.inputs)
self.assertTrue(self.value1.is_graph_input())
self.assertIs(self.value1.graph, self.graph)

def test_pop_from_inputs_raises_when_empty(self):
with self.assertRaises(IndexError):
self.graph.inputs.pop()
Expand All @@ -1249,6 +1257,13 @@ def test_clear_inputs(self):
self.assertFalse(self.value2.is_graph_input())
self.assertIsNone(self.value2.graph)

def test_clear_duplicated_inputs(self):
self.graph.inputs.extend([self.value1, self.value1])
self.graph.inputs.clear()
self.assertEqual(len(self.graph.inputs), 0)
self.assertFalse(self.value1.is_graph_input())
self.assertIsNone(self.value1.graph)

def test_inputs_set_items(self):
self.graph.inputs.append(self.value1)
self.graph.inputs[-1] = self.value2
Expand All @@ -1260,6 +1275,34 @@ def test_inputs_set_items(self):
self.assertFalse(self.value1.is_graph_input())
self.assertIsNone(self.value1.graph)

def test_inputs_set_items_slices(self):
self.graph.inputs.extend([self.value1, self.value2])
# Replace with one existing and one new input
self.graph.inputs[0:2] = [self.value2, self.value3]
self.assertNotIn(self.value1, self.graph.inputs)
self.assertIn(self.value2, self.graph.inputs)
self.assertIn(self.value3, self.graph.inputs)
self.assertIs(self.value2.graph, self.graph)
self.assertIs(self.value3.graph, self.graph)
self.assertTrue(self.value2.is_graph_input())
self.assertTrue(self.value3.is_graph_input())
self.assertFalse(self.value1.is_graph_input())
self.assertIsNone(self.value1.graph)

def test_take_inputs(self):
self.graph.inputs.extend([self.value1, self.value2, self.value3])
inputs = self.graph.inputs[:2]
self.graph.inputs.clear()
self.graph.inputs.extend(inputs)
self.assertEqual(len(self.graph.inputs), 2)
self.assertEqual(self.graph.inputs, [self.value1, self.value2])
self.assertTrue(self.value1.is_graph_input())
self.assertTrue(self.value2.is_graph_input())
self.assertFalse(self.value3.is_graph_input())
self.assertIs(self.value1.graph, self.graph)
self.assertIs(self.value2.graph, self.graph)
self.assertIsNone(self.value3.graph)

def test_append_to_outputs(self):
self.graph.outputs.append(self.value2)
self.assertIn(self.value2, self.graph.outputs)
Expand Down Expand Up @@ -1289,6 +1332,14 @@ def test_pop_from_outputs(self):
self.assertFalse(self.value2.is_graph_output())
self.assertIsNone(self.value2.graph)

def test_pop_from_duplicated_outputs(self):
self.graph.outputs.extend([self.value1, self.value1])
popped = self.graph.outputs.pop()
self.assertIs(popped, self.value1)
self.assertIn(self.value1, self.graph.outputs)
self.assertTrue(self.value1.is_graph_output())
self.assertIs(self.value1.graph, self.graph)

def test_pop_from_outputs_raises_when_empty(self):
with self.assertRaises(IndexError):
self.graph.outputs.pop()
Expand All @@ -1315,6 +1366,13 @@ def test_clear_outputs(self):
self.assertFalse(self.value2.is_graph_output())
self.assertIsNone(self.value2.graph)

def test_clear_duplicated_outputs(self):
self.graph.outputs.extend([self.value1, self.value1])
self.graph.outputs.clear()
self.assertEqual(len(self.graph.outputs), 0)
self.assertFalse(self.value1.is_graph_output())
self.assertIsNone(self.value1.graph)

def test_outputs_set_items(self):
self.graph.outputs.append(self.value1)
self.graph.outputs[-1] = self.value2
Expand All @@ -1326,6 +1384,34 @@ def test_outputs_set_items(self):
self.assertFalse(self.value1.is_graph_output())
self.assertIsNone(self.value1.graph)

def test_outputs_set_items_slices(self):
self.graph.outputs.extend([self.value1, self.value2])
# Replace with one existing and one new output
self.graph.outputs[0:2] = [self.value2, self.value3]
self.assertNotIn(self.value1, self.graph.outputs)
self.assertIn(self.value2, self.graph.outputs)
self.assertIn(self.value3, self.graph.outputs)
self.assertIs(self.value2.graph, self.graph)
self.assertIs(self.value3.graph, self.graph)
self.assertTrue(self.value2.is_graph_output())
self.assertTrue(self.value3.is_graph_output())
self.assertFalse(self.value1.is_graph_output())
self.assertIsNone(self.value1.graph)

def test_take_outputs(self):
self.graph.outputs.extend([self.value1, self.value2, self.value3])
outputs = self.graph.outputs[:2]
self.graph.outputs.clear()
self.graph.outputs.extend(outputs)
self.assertEqual(len(self.graph.outputs), 2)
self.assertEqual(self.graph.outputs, [self.value1, self.value2])
self.assertTrue(self.value1.is_graph_output())
self.assertTrue(self.value2.is_graph_output())
self.assertFalse(self.value3.is_graph_output())
self.assertIs(self.value1.graph, self.graph)
self.assertIs(self.value2.graph, self.graph)
self.assertIsNone(self.value3.graph)

def test_set_initializers(self):
self.graph.initializers["initializer1"] = self.value3
self.assertIn("initializer1", self.graph.initializers)
Expand Down
Loading
Loading