11# Copyright (c) Microsoft Corporation.
22# Licensed under the MIT License.
3- """Tracked lists for graph and node IO ."""
3+ """Tracked containers for graph."""
44
55from __future__ import annotations
66
7+ __all__ = [
8+ "GraphInputs" ,
9+ "GraphOutputs" ,
10+ ]
11+
712import collections
8- from typing import TYPE_CHECKING , Iterable , Literal , SupportsIndex
13+ from typing import TYPE_CHECKING , Iterable , SupportsIndex
14+
15+ import onnxscript
916
1017if TYPE_CHECKING :
1118 from onnxscript .ir import _core
1219
1320
14- class GraphIO (collections .UserList [_core .Value ]):
21+ class _GraphIO (collections .UserList [_core .Value ]):
1522 """The inputs and outputs of a Graph."""
1623
17- def __init__ (self , graph : _core .Graph , typ : Literal [ "input" , "output" ], initlist = None ):
24+ def __init__ (self , graph : _core .Graph , initlist = None ):
1825 super ().__init__ (initlist )
1926 self ._graph = graph
20- assert typ in {"intput" , "output" }
21- self ._typ = typ
27+
28+ def _check_invariance (self ) -> None :
29+ """Check the invariance of the graph."""
30+ raise NotImplementedError
2231
2332 def _set_graph (self , value : _core .Value ) -> None :
2433 """Set the graph for the value."""
25- if value ._graph_input_of is not None and value ._graph_input_of is not self ._graph :
26- raise ValueError (
27- f"Value '{ value } ' is already an input of a different graph: { value ._graph_input_of !r} "
28- )
29- if value ._graph_output_of is not None and value ._graph_output_of is not self ._graph :
30- raise ValueError (
31- f"Value '{ value } ' is already an output of a different graph: { value ._graph_output_of !r} "
32- )
33-
34- if self ._typ == "input" :
35- value ._graph_input_of = self ._graph
36- else :
37- value ._graph_output_of = self ._graph
34+ raise NotImplementedError
3835
3936 def _unset_graph (self , value : _core .Value ) -> None :
4037 """Unset the graph for the value."""
41- if self ._typ == "input" :
42- value ._graph_input_of = None
43- else :
44- value ._graph_output_of = None
38+ raise NotImplementedError
4539
4640 def append (self , item : _core .Value ) -> None :
4741 """Add a new input to the graph."""
4842 super ().append (item )
4943 self ._set_graph (item )
44+ self ._check_invariance ()
5045
5146 def extend (self , other ) -> None :
5247 """Extend the list of inputs or outputs."""
@@ -58,17 +53,20 @@ def insert(self, i: int, item: _core.Value) -> None:
5853 """Insert an input/output to the graph."""
5954 super ().insert (i , item )
6055 self ._set_graph (item )
56+ self ._check_invariance ()
6157
6258 def pop (self , i : int = - 1 ) -> _core .Value :
6359 """Remove an input/output from the graph."""
6460 value = super ().pop (i )
6561 self ._unset_graph (value )
62+ self ._check_invariance ()
6663 return value
6764
6865 def remove (self , item : _core .Value ) -> None :
6966 """Remove an input/output from the graph."""
7067 super ().remove (item )
7168 self ._unset_graph (item )
69+ self ._check_invariance ()
7270
7371 def clear (self ) -> None :
7472 """Clear the list."""
@@ -85,12 +83,76 @@ def __setitem__(self, i, item) -> None:
8583 for value in item :
8684 self ._set_graph (value )
8785 super ().__setitem__ (i , item )
86+ self ._check_invariance ()
8887 return
8988 elif isinstance (item , _core .Value ) and isinstance (i , SupportsIndex ):
9089 # Replace a single item
9190 self ._unset_graph (self .data [i ])
9291 self ._set_graph (item )
9392 super ().__setitem__ (i , item )
93+ self ._check_invariance ()
9494 return
9595
9696 raise TypeError (f"Invalid types for __setitem__: { type (i )} and { type (item )} " )
97+
98+
99+ class GraphInputs (_GraphIO ):
100+ """The inputs of a Graph."""
101+
102+ def __init__ (self , graph : _core .Graph , initlist = None ):
103+ super ().__init__ (graph , initlist )
104+
105+ def _check_invariance (self ) -> None :
106+ """Check the invariance of the graph."""
107+ if not onnxscript .DEBUG :
108+ return
109+ for value in self .data :
110+ if value ._graph_input_of is self ._graph :
111+ continue
112+ raise ValueError (
113+ f"Invariance error: Value '{ value } ' is not an input of the graph: { self ._graph !r} "
114+ )
115+
116+ def _set_graph (self , value : _core .Value ) -> None :
117+ """Set the graph for the value."""
118+ if value ._graph_input_of is not None and value ._graph_input_of is not self ._graph :
119+ raise ValueError (
120+ f"Value '{ value } ' is already an input of a different graph: { value ._graph_input_of !r} "
121+ )
122+
123+ value ._graph_input_of = self ._graph
124+
125+ def _unset_graph (self , value : _core .Value ) -> None :
126+ """Unset the graph for the value."""
127+ value ._graph_input_of = None
128+
129+
130+ class GraphOutputs (_GraphIO ):
131+ """The outputs of a Graph."""
132+
133+ def __init__ (self , graph : _core .Graph , initlist = None ):
134+ super ().__init__ (graph , initlist )
135+
136+ def _check_invariance (self ) -> None :
137+ """Check the invariance of the graph."""
138+ if not onnxscript .DEBUG :
139+ return
140+ for value in self .data :
141+ if value ._graph_output_of is self ._graph :
142+ continue
143+ raise ValueError (
144+ f"Invariance error: Value '{ value } ' is not an output of the graph: { self ._graph !r} "
145+ )
146+
147+ def _set_graph (self , value : _core .Value ) -> None :
148+ """Set the graph for the value."""
149+ if value ._graph_output_of is not None and value ._graph_output_of is not self ._graph :
150+ raise ValueError (
151+ f"Value '{ value } ' is already an output of a different graph: { value ._graph_output_of !r} "
152+ )
153+
154+ value ._graph_output_of = self ._graph
155+
156+ def _unset_graph (self , value : _core .Value ) -> None :
157+ """Unset the graph for the value."""
158+ value ._graph_output_of = None
0 commit comments