22# Licensed under the MIT License.
33"""Convenience methods for constructing the IR."""
44
5- # NOTE: This is a temporary solution for constructing the IR. It should be replaced
6- # with a more permanent solution in the future.
7-
85from __future__ import annotations
96
10- from typing import Any , Iterable , Iterator , List , Mapping , Optional , Sequence , Tuple
7+ from typing import (
8+ Any ,
9+ Mapping ,
10+ Optional ,
11+ Sequence ,
12+ Tuple ,
13+ )
1114
1215from onnxscript import ir
1316from onnxscript .ir import _convenience
1417
18+ # A type representing the domains/versions used in creating nodes in IR.
19+ UsedOpsets = set [Tuple [str , Optional [int ]]]
20+
21+
22+ class Tape :
23+ """Tape class.
24+
25+ A tape is a recorder that collects nodes and initializers that are created so
26+ that they can be used for creating a graph.
27+
28+ Example::
29+ from onnxscript import ir
30+
31+ tape = ir.tape.Tape()
32+ a = tape.initializer(ir.tensor([1, 2, 3], name="a"))
33+ b: ir.Value = ...
34+ c: ir.Value = ...
35+ x = tape.op("Add", [a, b], attributes={"alpha": 1.0})
36+ y = tape.op("Mul", [x, c], attributes={"beta": 2.0})
37+ model = ir.Model(
38+ graph := ir.Graph(
39+ inputs=[b, c],
40+ outputs=[y],
41+ nodes=tape.nodes,
42+ initializers=tape.initializers
43+ opset_imports={"": 20},
44+ ),
45+ ir_version=10,
46+ )
1547
16- class Tape (Iterable [ir .Node ]):
17- """A tape for recording nodes that are created."""
48+ Attributes:
49+ graph_like: The graph to append the new nodes and initializers to. When
50+ it is None, the nodes and initializers are creating without owned by a graph.
51+ Initializers will not be added to functions because it is not supported by ONNX.
52+ """
1853
19- def __init__ (self ) -> None :
54+ def __init__ (self , graph_like : ir . Graph | ir . Function | None = None ) -> None :
2055 self ._nodes : list [ir .Node ] = []
2156 self ._initializers : list [ir .Value ] = []
57+ self ._used_opsets : UsedOpsets = set ()
58+ self .graph_like = graph_like
2259
23- def __iter__ (self ) -> Iterator [ ir . Node ] :
24- return iter ( self ._nodes )
60+ def __repr__ (self ) -> str :
61+ return f"Tape(nodes= { self ._nodes } , initializers= { self . _initializers } )"
2562
2663 @property
2764 def nodes (self ) -> Sequence [ir .Node ]:
@@ -31,19 +68,43 @@ def nodes(self) -> Sequence[ir.Node]:
3168 def initializers (self ) -> Sequence [ir .Value ]:
3269 return tuple (self ._initializers )
3370
71+ @property
72+ def used_opsets (self ) -> UsedOpsets :
73+ return self ._used_opsets
74+
3475 def op (
3576 self ,
3677 op_type : str ,
3778 inputs : Sequence [ir .Value | None ],
3879 attributes : Mapping [str , _convenience .SupportedAttrTypes ] | None = None ,
80+ * ,
3981 domain : str = "" ,
82+ overload : str = "" ,
83+ version : int | None = None ,
84+ graph : ir .Graph | None = None ,
85+ name : str | None = None ,
86+ doc_string : str | None = None ,
87+ metadata_props : dict [str , str ] | None = None ,
4088 ) -> ir .Value :
4189 if attributes is None :
4290 attrs : Sequence [ir .Attr | ir .RefAttr ] = ()
4391 else :
4492 attrs = _convenience .convert_attributes (attributes )
45- node = ir .Node (domain , op_type , inputs , attributes = attrs , num_outputs = 1 )
93+ node = ir .Node (
94+ domain ,
95+ op_type ,
96+ inputs ,
97+ attributes = attrs ,
98+ num_outputs = 1 ,
99+ overload = overload ,
100+ version = version ,
101+ graph = graph or self .graph_like ,
102+ name = name ,
103+ doc_string = doc_string ,
104+ metadata_props = metadata_props ,
105+ )
46106 self ._nodes .append (node )
107+ self ._used_opsets .add ((domain , version ))
47108
48109 return node .outputs [0 ]
49110
@@ -55,13 +116,32 @@ def op_multi_output(
55116 * ,
56117 num_outputs : int ,
57118 domain : str = "" ,
119+ overload : str = "" ,
120+ version : int | None = None ,
121+ graph : ir .Graph | None = None ,
122+ name : str | None = None ,
123+ doc_string : str | None = None ,
124+ metadata_props : dict [str , str ] | None = None ,
58125 ) -> Sequence [ir .Value ]:
59126 if attributes is None :
60127 attrs : Sequence [ir .Attr | ir .RefAttr ] = ()
61128 else :
62129 attrs = _convenience .convert_attributes (attributes )
63- node = ir .Node (domain , op_type , inputs , attributes = attrs , num_outputs = num_outputs )
130+ node = ir .Node (
131+ domain ,
132+ op_type ,
133+ inputs ,
134+ attributes = attrs ,
135+ num_outputs = num_outputs ,
136+ overload = overload ,
137+ version = version ,
138+ graph = graph or self .graph_like ,
139+ name = name ,
140+ doc_string = doc_string ,
141+ metadata_props = metadata_props ,
142+ )
64143 self ._nodes .append (node )
144+ self ._used_opsets .add ((domain , version ))
65145
66146 return node .outputs
67147
@@ -74,20 +154,14 @@ def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.
74154 name = name , shape = shape , type = ir .TensorType (tensor .dtype ), const_value = tensor
75155 )
76156 self ._initializers .append (value )
157+ if isinstance (self .graph_like , ir .Graph ):
158+ self .graph_like .register_initializer (value )
77159 return value
78160
79161
80- # A type representing the domains/versions used in creating nodes in IR.
81- UsedOpsets = List [Tuple [str , Optional [int ]]]
82-
83-
84162class Builder (Tape ):
85163 """An extension of the tape that provides a more convenient API for constructing the IR."""
86164
87- def __init__ (self ):
88- super ().__init__ ()
89- self ._used_opsets : UsedOpsets = []
90-
91165 def __getattr__ (self , op_type : str ) -> Any :
92166 return lambda * args , ** kwargs : self ._make_node (op_type , args , kwargs )
93167
@@ -101,20 +175,22 @@ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str,
101175 assert isinstance (outputs , int )
102176 num_outputs = outputs
103177
104- self ._used_opsets .append ((domain , version ))
105178 if num_outputs == 1 :
106- value = super ().op (op_type , inputs = inputs , attributes = kwargs , domain = domain )
179+ value = super ().op (
180+ op_type , inputs = inputs , attributes = kwargs , domain = domain , version = version
181+ )
107182 if isinstance (outputs , Sequence ):
108183 value .name = outputs [0 ]
109184 return value
110185 values = super ().op_multi_output (
111- op_type , inputs = inputs , attributes = kwargs , domain = domain , num_outputs = num_outputs
186+ op_type ,
187+ inputs = inputs ,
188+ attributes = kwargs ,
189+ domain = domain ,
190+ version = version ,
191+ num_outputs = num_outputs ,
112192 )
113193 if isinstance (outputs , Sequence ):
114194 for value , name in zip (values , outputs ):
115195 value .name = name
116196 return values
117-
118- @property
119- def used_opsets (self ) -> UsedOpsets :
120- return self ._used_opsets
0 commit comments