-
Notifications
You must be signed in to change notification settings - Fork 109
[IR] Record owning graph for input/output/initializers #2282
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 10 commits
Commits
Show all changes
43 commits
Select commit
Hold shift + click to select a range
f4da452
tracked list
justinchuby e64efe1
[IR] Record owning graph for input/output and initializers
justinchuby d95d276
GraphOutputs
justinchuby 89ef173
format
justinchuby 3c859c9
no init
justinchuby eaf0ca6
owning_graph
justinchuby 510d0b9
quote the type
justinchuby a5ac719
# pylint: disable=protected-access
justinchuby 24c7a42
core
justinchuby f626f07
init
justinchuby e0e6f0a
Update onnxscript/ir/_core.py
justinchuby e80de25
GraphInitializers
justinchuby b6a0fe0
owning_graph
justinchuby 4b48d0d
docs
justinchuby 847b48c
quote
justinchuby 8e72931
syntax
justinchuby 6cf7883
Rename
justinchuby 41db1b2
Rename to graph to match node
justinchuby 078074e
wip tests
justinchuby 45898a3
Fix graph
justinchuby f1b330c
test
justinchuby 6c76fb3
Check
justinchuby a4e2fc7
More tests
justinchuby 751db58
wip
justinchuby 66dfdb2
Data structures
justinchuby 6431711
tests
justinchuby 8a1635d
Apply suggestions from code review
justinchuby 1108dc0
logger
justinchuby e6aa051
Fix if
justinchuby cb226dd
Fix tests
justinchuby 91992b0
logger
justinchuby 3739ded
Merge branch 'main' into justinchu/tracked-lists-2
justinchuby f467daf
Fix __getitem__
justinchuby 9397c46
Use booleans
justinchuby 4c3afc8
test
justinchuby 42d678c
ref counter
justinchuby 0933963
RuntimeError
justinchuby 22096b4
test
justinchuby 2f62c50
Fix constant lifting
justinchuby de6ad6f
Update onnxscript/ir/_graph_containers.py
justinchuby added12
Fix test
justinchuby dc0b8e2
typing
justinchuby 6964109
Merge branch 'main' into justinchu/tracked-lists-2
justinchuby File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,164 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT License. | ||
| """Tracked containers for graph.""" | ||
|
|
||
| # pylint: disable=protected-access | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| __all__ = [ | ||
| "GraphInputs", | ||
| "GraphOutputs", | ||
| ] | ||
|
|
||
| import collections | ||
| import logging | ||
| from typing import TYPE_CHECKING, Iterable, SupportsIndex | ||
|
|
||
| import onnxscript | ||
|
|
||
| if TYPE_CHECKING: | ||
| from onnxscript.ir import _core | ||
|
|
||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class _GraphIO(collections.UserList["_core.Value"]): | ||
| """The inputs and outputs of a Graph.""" | ||
|
|
||
| def __init__(self, graph: _core.Graph, initlist=None): | ||
| super().__init__(initlist) | ||
| self._graph = graph | ||
|
|
||
| def _check_invariance(self) -> None: | ||
| """Check the invariance of the graph.""" | ||
| raise NotImplementedError | ||
|
|
||
| def _set_graph(self, value: _core.Value) -> None: | ||
| """Set the graph for the value.""" | ||
| raise NotImplementedError | ||
|
|
||
| def _unset_graph(self, value: _core.Value) -> None: | ||
| """Unset the graph for the value.""" | ||
| raise NotImplementedError | ||
|
|
||
| def append(self, item: _core.Value) -> None: | ||
| """Add a new input to the graph.""" | ||
| super().append(item) | ||
| self._set_graph(item) | ||
| self._check_invariance() | ||
|
|
||
| def extend(self, other) -> None: | ||
| """Extend the list of inputs or outputs.""" | ||
| super().extend(other) | ||
| for item in other: | ||
| self._set_graph(item) | ||
|
|
||
| def insert(self, i: int, item: _core.Value) -> None: | ||
| """Insert an input/output to the graph.""" | ||
| super().insert(i, item) | ||
| self._set_graph(item) | ||
| self._check_invariance() | ||
|
|
||
| def pop(self, i: int = -1) -> _core.Value: | ||
| """Remove an input/output from the graph.""" | ||
| value = super().pop(i) | ||
| self._unset_graph(value) | ||
| self._check_invariance() | ||
| return value | ||
|
|
||
| def remove(self, item: _core.Value) -> None: | ||
| """Remove an input/output from the graph.""" | ||
| super().remove(item) | ||
| self._unset_graph(item) | ||
| self._check_invariance() | ||
|
|
||
| def clear(self) -> None: | ||
| """Clear the list.""" | ||
| for value in self.data: | ||
| self._unset_graph(value) | ||
| super().clear() | ||
|
|
||
| def __setitem__(self, i, item) -> None: | ||
|
justinchuby marked this conversation as resolved.
Outdated
|
||
| """Replace an input/output to the node.""" | ||
| if isinstance(item, Iterable) and isinstance(i, slice): | ||
| # Modify a slice of the list | ||
| for value in self.data[i]: | ||
| self._unset_graph(value) | ||
| for value in item: | ||
| self._set_graph(value) | ||
| super().__setitem__(i, item) | ||
| self._check_invariance() | ||
| return | ||
| elif isinstance(i, SupportsIndex): | ||
| # Replace a single item | ||
| self._unset_graph(self.data[i]) | ||
| self._set_graph(item) | ||
| super().__setitem__(i, item) | ||
| self._check_invariance() | ||
| return | ||
|
|
||
| raise TypeError(f"Invalid types for __setitem__: {type(i)} and {type(item)}") | ||
|
|
||
|
|
||
| class GraphInputs(_GraphIO): | ||
| """The inputs of a Graph.""" | ||
|
|
||
| def _check_invariance(self) -> None: | ||
| """Check the invariance of the graph.""" | ||
| if not onnxscript.DEBUG: | ||
| return | ||
| for value in self.data: | ||
| if value._graph_input_of is self._graph: | ||
|
|
||
| continue | ||
| raise ValueError( | ||
| f"Invariance error: Value '{value}' is not an input of the graph: {self._graph!r}" | ||
| ) | ||
|
|
||
| def _set_graph(self, value: _core.Value) -> None: | ||
| """Set the graph for the value.""" | ||
| if value._graph_input_of is not None and value._graph_input_of is not self._graph: | ||
|
|
||
| logger.warning( | ||
| "Value '%s' is already an input of a different graph. Overwriting", | ||
| value, | ||
| ) | ||
| value._graph_input_of = self._graph | ||
|
|
||
|
|
||
| def _unset_graph(self, value: _core.Value) -> None: | ||
| """Unset the graph for the value.""" | ||
| if value._graph_input_of is not self._graph: | ||
|
|
||
| # The value is already added to a different graph | ||
| return | ||
| value._graph_input_of = None | ||
|
|
||
|
|
||
|
|
||
| class GraphOutputs(_GraphIO): | ||
| """The outputs of a Graph.""" | ||
|
|
||
| def _check_invariance(self) -> None: | ||
| """Check the invariance of the graph.""" | ||
| if not onnxscript.DEBUG: | ||
| return | ||
| for value in self.data: | ||
| if value._graph_output_of is self._graph: | ||
|
|
||
| continue | ||
| raise ValueError( | ||
| f"Invariance error: Value '{value}' is not an output of the graph: {self._graph!r}" | ||
| ) | ||
|
|
||
| def _set_graph(self, value: _core.Value) -> None: | ||
| """Set the graph for the value.""" | ||
| if value._graph_output_of is not None and value._graph_output_of is not self._graph: | ||
|
|
||
| logger.warning( | ||
| "Value '%s' is already an output of a different graph. Overwriting", | ||
| value, | ||
| ) | ||
| value._graph_output_of = self._graph | ||
|
|
||
|
|
||
| def _unset_graph(self, value: _core.Value) -> None: | ||
| """Unset the graph for the value.""" | ||
| if value._graph_output_of is not self._graph: | ||
|
|
||
| # The value is already added to a different graph | ||
| return | ||
| value._graph_output_of = None | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.