Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
f4da452
tracked list
justinchuby May 7, 2025
e64efe1
[IR] Record owning graph for input/output and initializers
justinchuby May 7, 2025
d95d276
GraphOutputs
justinchuby May 7, 2025
89ef173
format
justinchuby May 7, 2025
3c859c9
no init
justinchuby May 7, 2025
eaf0ca6
owning_graph
justinchuby May 7, 2025
510d0b9
quote the type
justinchuby May 7, 2025
a5ac719
# pylint: disable=protected-access
justinchuby May 7, 2025
24c7a42
core
justinchuby May 7, 2025
f626f07
init
justinchuby May 7, 2025
e0e6f0a
Update onnxscript/ir/_core.py
justinchuby May 7, 2025
e80de25
GraphInitializers
justinchuby May 8, 2025
b6a0fe0
owning_graph
justinchuby May 8, 2025
4b48d0d
docs
justinchuby May 8, 2025
847b48c
quote
justinchuby May 8, 2025
8e72931
syntax
justinchuby May 8, 2025
6cf7883
Rename
justinchuby May 8, 2025
41db1b2
Rename to graph to match node
justinchuby May 8, 2025
078074e
wip tests
justinchuby May 8, 2025
45898a3
Fix graph
justinchuby May 8, 2025
f1b330c
test
justinchuby May 8, 2025
6c76fb3
Check
justinchuby May 8, 2025
a4e2fc7
More tests
justinchuby May 8, 2025
751db58
wip
justinchuby May 8, 2025
66dfdb2
Data structures
justinchuby May 8, 2025
6431711
tests
justinchuby May 8, 2025
8a1635d
Apply suggestions from code review
justinchuby May 8, 2025
1108dc0
logger
justinchuby May 8, 2025
e6aa051
Fix if
justinchuby May 8, 2025
cb226dd
Fix tests
justinchuby May 8, 2025
91992b0
logger
justinchuby May 8, 2025
3739ded
Merge branch 'main' into justinchu/tracked-lists-2
justinchuby May 8, 2025
f467daf
Fix __getitem__
justinchuby May 8, 2025
9397c46
Use booleans
justinchuby May 8, 2025
4c3afc8
test
justinchuby May 8, 2025
42d678c
ref counter
justinchuby May 8, 2025
0933963
RuntimeError
justinchuby May 8, 2025
22096b4
test
justinchuby May 8, 2025
2f62c50
Fix constant lifting
justinchuby May 8, 2025
de6ad6f
Update onnxscript/ir/_graph_containers.py
justinchuby May 8, 2025
added12
Fix test
justinchuby May 9, 2025
dc0b8e2
typing
justinchuby May 9, 2025
6964109
Merge branch 'main' into justinchu/tracked-lists-2
justinchuby May 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 40 additions & 12 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import contextlib
import dataclasses
import heapq
import logging
import math
import mmap
import os
Expand All @@ -31,6 +32,7 @@
Generic,
Iterable,
Iterator,
MutableSequence,
NamedTuple,
OrderedDict,
Sequence,
Expand All @@ -50,6 +52,7 @@
_metadata,
_name_authority,
_protocols,
_tracked_containers,
_type_casting,
)

Expand Down Expand Up @@ -79,6 +82,9 @@
)


logger = logging.getLogger(__name__)


def _compatible_with_numpy(obj: Any) -> TypeGuard[_protocols.ArrayCompatible]:
"""Use this function to check if an object is compatible with numpy.

Expand Down Expand Up @@ -1757,6 +1763,9 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):

__slots__ = (
"_const_value",
"_graph_initializer_of",
Comment thread
justinchuby marked this conversation as resolved.
Outdated
"_graph_input_of",
"_graph_output_of",
"_index",
"_metadata",
"_metadata_props",
Expand Down Expand Up @@ -1808,6 +1817,12 @@ def __init__(
self._uses: dict[Usage, None] = {}
self.doc_string = doc_string

# The graph this value belongs to. It is set *only* when the value is added as
# a graph input or a graph output.
# The two properties can only be set by the Graph class (GraphIO).
Comment thread
justinchuby marked this conversation as resolved.
Outdated
Comment thread
justinchuby marked this conversation as resolved.
Outdated
self._graph_input_of: Graph | None = None
self._graph_output_of: Graph | None = None

def __repr__(self) -> str:
value_name = self.name if self.name else "anonymous:" + str(id(self))
type_text = f", type={self.type!r}" if self.type is not None else ""
Expand Down Expand Up @@ -1846,12 +1861,27 @@ def _constant_tensor_part(self) -> str:
return f"{{{self.const_value.__class__.__name__}(...)}}"
return ""

def owning_graph(self) -> Graph | None:
"""Return the graph that defines this value when it is a graph input."""
if self._producer is not None and self._graph_input_of is not None:
logger.warning(
"The value is owned by a node but it is simultaneously a graph input. "
"The graph is invalid."
Comment thread
justinchuby marked this conversation as resolved.
Outdated
)
return self._graph_input_of

def producer(self) -> Node | None:
"""The node that produces this value.

When producer is ``None``, the value does not belong to a node, and is
typically a graph input or an initializer.
typically a graph input or an initializer, and should have ``owning_graph()``
set.
"""
if self._producer is not None and self._graph_input_of is not None:
logger.warning(
"The value is owned by a node but it is simultaneously a graph input. "
"The graph is invalid."
)
return self._producer

def consumers(self) -> Sequence[Node]:
Expand Down Expand Up @@ -1986,15 +2016,13 @@ def metadata_props(self) -> dict[str, str]:
self._metadata_props = {}
return self._metadata_props

def is_graph_input(self) -> bool:
"""Whether the value is an input of a graph."""
return self._graph_input_of is not None

def is_graph_output(self) -> bool:
"""Whether the value is an output of a graph."""
if (producer := self.producer()) is None:
return False
if (graph := producer.graph) is None:
return False
# Cannot use `in` because __eq__ may be defined by subclasses, even though
# it is not recommended
return any(output is self for output in graph.outputs)
return self._graph_output_of is not None


def Input(
Expand Down Expand Up @@ -2104,8 +2132,8 @@ def __init__(
self.name = name

# Private fields that are not to be accessed by any other classes
self._inputs = list(inputs)
self._outputs = list(outputs)
self._inputs = _tracked_containers.GraphInputs(self, inputs)
self._outputs = _tracked_containers.GraphOutputs(self, outputs)
self._initializers = {}
for initializer in initializers:
if isinstance(initializer, str):
Expand All @@ -2131,11 +2159,11 @@ def __init__(
self.extend(nodes)

@property
def inputs(self) -> list[Value]:
def inputs(self) -> MutableSequence[Value]:
return self._inputs

@property
def outputs(self) -> list[Value]:
def outputs(self) -> MutableSequence[Value]:
return self._outputs

@property
Expand Down
164 changes: 164 additions & 0 deletions onnxscript/ir/_tracked_containers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Tracked containers for graph."""

# pylint: disable=protected-access

from __future__ import annotations

__all__ = [
"GraphInputs",
"GraphOutputs",
]

import collections
import logging
from typing import TYPE_CHECKING, Iterable, SupportsIndex

import onnxscript

if TYPE_CHECKING:
from onnxscript.ir import _core


logger = logging.getLogger(__name__)


class _GraphIO(collections.UserList["_core.Value"]):
"""The inputs and outputs of a Graph."""

def __init__(self, graph: _core.Graph, initlist=None):
super().__init__(initlist)
self._graph = graph

def _check_invariance(self) -> None:
"""Check the invariance of the graph."""
raise NotImplementedError

def _set_graph(self, value: _core.Value) -> None:
"""Set the graph for the value."""
raise NotImplementedError

def _unset_graph(self, value: _core.Value) -> None:
"""Unset the graph for the value."""
raise NotImplementedError

def append(self, item: _core.Value) -> None:
"""Add a new input to the graph."""
super().append(item)
self._set_graph(item)
self._check_invariance()

def extend(self, other) -> None:
"""Extend the list of inputs or outputs."""
super().extend(other)
for item in other:
self._set_graph(item)

def insert(self, i: int, item: _core.Value) -> None:
"""Insert an input/output to the graph."""
super().insert(i, item)
self._set_graph(item)
self._check_invariance()

def pop(self, i: int = -1) -> _core.Value:
"""Remove an input/output from the graph."""
value = super().pop(i)
self._unset_graph(value)
self._check_invariance()
return value

def remove(self, item: _core.Value) -> None:
"""Remove an input/output from the graph."""
super().remove(item)
self._unset_graph(item)
self._check_invariance()

def clear(self) -> None:
"""Clear the list."""
for value in self.data:
self._unset_graph(value)
super().clear()

def __setitem__(self, i, item) -> None:
Comment thread
justinchuby marked this conversation as resolved.
Outdated
"""Replace an input/output to the node."""
if isinstance(item, Iterable) and isinstance(i, slice):
# Modify a slice of the list
for value in self.data[i]:
self._unset_graph(value)
for value in item:
self._set_graph(value)
super().__setitem__(i, item)
self._check_invariance()
return
elif isinstance(i, SupportsIndex):
# Replace a single item
self._unset_graph(self.data[i])
self._set_graph(item)
super().__setitem__(i, item)
self._check_invariance()
return

raise TypeError(f"Invalid types for __setitem__: {type(i)} and {type(item)}")


class GraphInputs(_GraphIO):
"""The inputs of a Graph."""

def _check_invariance(self) -> None:
"""Check the invariance of the graph."""
if not onnxscript.DEBUG:
return
for value in self.data:
if value._graph_input_of is self._graph:
Comment thread Fixed
continue
raise ValueError(
f"Invariance error: Value '{value}' is not an input of the graph: {self._graph!r}"
)

def _set_graph(self, value: _core.Value) -> None:
"""Set the graph for the value."""
if value._graph_input_of is not None and value._graph_input_of is not self._graph:
Comment thread Fixed
Comment thread Fixed
Comment thread Fixed
Comment thread Fixed
logger.warning(
"Value '%s' is already an input of a different graph. Overwriting",
value,
)
value._graph_input_of = self._graph
Comment thread Fixed

def _unset_graph(self, value: _core.Value) -> None:
"""Unset the graph for the value."""
if value._graph_input_of is not self._graph:
Comment thread Fixed
# The value is already added to a different graph
return
value._graph_input_of = None
Comment thread Fixed


class GraphOutputs(_GraphIO):
"""The outputs of a Graph."""

def _check_invariance(self) -> None:
"""Check the invariance of the graph."""
if not onnxscript.DEBUG:
return
for value in self.data:
if value._graph_output_of is self._graph:
Comment thread Fixed
continue
raise ValueError(
f"Invariance error: Value '{value}' is not an output of the graph: {self._graph!r}"
)

def _set_graph(self, value: _core.Value) -> None:
"""Set the graph for the value."""
if value._graph_output_of is not None and value._graph_output_of is not self._graph:
Comment thread Fixed
Comment thread Fixed
Comment thread Fixed
Comment thread Fixed
logger.warning(
"Value '%s' is already an output of a different graph. Overwriting",
value,
)
value._graph_output_of = self._graph
Comment thread Fixed

def _unset_graph(self, value: _core.Value) -> None:
"""Unset the graph for the value."""
if value._graph_output_of is not self._graph:
Comment thread Fixed
# The value is already added to a different graph
return
value._graph_output_of = None
Comment thread Fixed
Loading