-
Notifications
You must be signed in to change notification settings - Fork 20
Expand file tree
/
Copy path_tape.py
More file actions
211 lines (185 loc) · 6.84 KB
/
_tape.py
File metadata and controls
211 lines (185 loc) · 6.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Convenience methods for constructing the IR."""
from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import (
Any,
Optional,
)
import onnx_ir as ir
from onnx_ir import _convenience
# A type representing the domains/versions used in creating nodes in IR.
# Each entry is a ``(domain, version)`` pair as passed when a node is created;
# ``version`` is None when no explicit opset version was given for that node.
UsedOpsets = set[tuple[str, Optional[int]]]
class Tape:
    """Tape class.

    A tape is a recorder that collects nodes and initializers that are created so
    that they can be used for creating a graph.

    Example::

        import onnx_ir as ir

        tape = ir.tape.Tape()
        a = tape.initializer(ir.tensor([1, 2, 3], name="a"))
        b: ir.Value = ...
        c: ir.Value = ...
        x = tape.op("Add", [a, b], attributes={"alpha": 1.0})
        y = tape.op("Mul", [x, c], attributes={"beta": 2.0})
        model = ir.Model(
            graph := ir.Graph(
                inputs=[b, c],
                outputs=[y],
                nodes=tape.nodes,
                initializers=tape.initializers,
                opset_imports={"": 20},
            ),
            ir_version=10,
        )

    Attributes:
        graph_like: The graph to append the new nodes and initializers to. When
            it is None, the nodes and initializers are created without being
            owned by a graph. Initializers will not be added to functions
            because it is not supported by ONNX.
    """

    def __init__(self, graph_like: ir.Graph | ir.Function | None = None) -> None:
        """Create an empty tape, optionally bound to a graph or function."""
        self._nodes: list[ir.Node] = []
        self._initializers: list[ir.Value] = []
        self._used_opsets: UsedOpsets = set()
        self.graph_like = graph_like

    def __repr__(self) -> str:
        return f"Tape(nodes={self._nodes}, initializers={self._initializers})"

    @property
    def nodes(self) -> Sequence[ir.Node]:
        """Immutable snapshot of the nodes recorded so far, in creation order."""
        return tuple(self._nodes)

    @property
    def initializers(self) -> Sequence[ir.Value]:
        """Immutable snapshot of the initializers recorded so far, in creation order."""
        return tuple(self._initializers)

    @property
    def used_opsets(self) -> UsedOpsets:
        """The ``(domain, version)`` pairs of every node recorded so far."""
        return self._used_opsets

    def _record_node(
        self,
        op_type: str,
        inputs: Sequence[ir.Value | None],
        attributes: Mapping[str, _convenience.SupportedAttrTypes] | None,
        output_kwargs: dict[str, Any],
        *,
        domain: str,
        overload: str,
        version: int | None,
        graph: ir.Graph | None,
        name: str | None,
        doc_string: str | None,
        metadata_props: dict[str, str] | None,
    ) -> ir.Node:
        """Build a node, append it to the tape and note the opset it uses.

        Shared implementation of :meth:`op` and :meth:`op_multi_out`;
        ``output_kwargs`` carries either ``num_outputs`` or ``outputs`` for the
        node constructor.
        """
        attrs: Sequence[ir.Attr] = (
            () if attributes is None else _convenience.convert_attributes(attributes)
        )
        # Compare against None explicitly instead of using ``graph or ...``:
        # onnx_ir graphs are sequences of nodes, so an explicitly passed graph
        # that is still empty is falsy and would be silently discarded.
        owner = self.graph_like if graph is None else graph
        node = ir.Node(
            domain,
            op_type,
            inputs,
            attributes=attrs,
            **output_kwargs,
            overload=overload,
            version=version,
            graph=owner,
            name=name,
            doc_string=doc_string,
            metadata_props=metadata_props,
        )
        self._nodes.append(node)
        self._used_opsets.add((domain, version))
        return node

    def op(
        self,
        op_type: str,
        inputs: Sequence[ir.Value | None],
        attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None,
        *,
        domain: str = "",
        overload: str = "",
        version: int | None = None,
        graph: ir.Graph | None = None,
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
        output: ir.Value | None = None,
    ) -> ir.Value:
        """Record a single-output node and return its output value.

        Args:
            op_type: The operator type of the node.
            inputs: The input values of the node.
            attributes: Attributes of the node, converted with
                ``_convenience.convert_attributes``. None means no attributes.
            domain: The operator domain.
            overload: The operator overload name.
            version: The opset version of the operator, if any.
            graph: The graph that owns the node. When None, the tape's
                ``graph_like`` is used.
            name: The name of the node.
            doc_string: The documentation string of the node.
            metadata_props: Metadata properties of the node.
            output: An existing value to use as the node's output. When None,
                the node is created with one new output.

        Returns:
            The first (and only) output of the created node.
        """
        output_kwargs: dict[str, Any]
        if output is None:
            output_kwargs = {"num_outputs": 1}
        else:
            output_kwargs = {"outputs": [output]}
        node = self._record_node(
            op_type,
            inputs,
            attributes,
            output_kwargs,
            domain=domain,
            overload=overload,
            version=version,
            graph=graph,
            name=name,
            doc_string=doc_string,
            metadata_props=metadata_props,
        )
        return node.outputs[0]

    def op_multi_out(
        self,
        op_type: str,
        inputs: Sequence[ir.Value | None],
        attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None,
        *,
        num_outputs: int | None = None,
        outputs: Sequence[ir.Value] | None = None,
        domain: str = "",
        overload: str = "",
        version: int | None = None,
        graph: ir.Graph | None = None,
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ) -> Sequence[ir.Value]:
        """Record a node with several outputs and return all of its outputs.

        Exactly one of ``num_outputs`` and ``outputs`` must be given.

        Args:
            op_type: The operator type of the node.
            inputs: The input values of the node.
            attributes: Attributes of the node, converted with
                ``_convenience.convert_attributes``. None means no attributes.
            num_outputs: The number of new outputs to create for the node.
            outputs: Existing values to use as the node's outputs.
            domain: The operator domain.
            overload: The operator overload name.
            version: The opset version of the operator, if any.
            graph: The graph that owns the node. When None, the tape's
                ``graph_like`` is used.
            name: The name of the node.
            doc_string: The documentation string of the node.
            metadata_props: Metadata properties of the node.

        Returns:
            The outputs of the created node.

        Raises:
            ValueError: If neither or both of ``num_outputs`` and ``outputs``
                are provided.
        """
        if num_outputs is None and outputs is None:
            raise ValueError("Either num_outputs or outputs must be provided.")
        if num_outputs is not None and outputs is not None:
            raise ValueError("Both num_outputs and outputs cannot be provided simultaneously.")
        output_kwargs: dict[str, Any]
        if outputs is None:
            output_kwargs = {"num_outputs": num_outputs}
        else:
            output_kwargs = {"outputs": outputs}
        node = self._record_node(
            op_type,
            inputs,
            attributes,
            output_kwargs,
            domain=domain,
            overload=overload,
            version=version,
            graph=graph,
            name=name,
            doc_string=doc_string,
            metadata_props=metadata_props,
        )
        return node.outputs

    def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
        """Record an initializer value backed by ``tensor``.

        Args:
            tensor: The tensor holding the initializer data.
            name: The name of the initializer. Falls back to ``tensor.name``
                when not given.

        Returns:
            The created value, with ``tensor`` as its constant value.

        Raises:
            ValueError: If neither ``name`` nor ``tensor.name`` provides a name.
        """
        name = name or tensor.name
        if name is None:
            raise ValueError("Name must be provided for initializer.")
        # Integer dimensions are used as-is; any other dimension object is
        # unwrapped through its ``value`` attribute (presumably a symbolic
        # dimension -- confirm against the tensor protocol).
        shape = ir.Shape((d if isinstance(d, int) else d.value) for d in tensor.shape.dims)
        value = ir.Value(
            name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor
        )
        self._initializers.append(value)
        # Only graphs get the initializer registered; functions are skipped.
        if isinstance(self.graph_like, ir.Graph):
            self.graph_like.register_initializer(value)
        return value
class Builder(Tape):
    """An extension of the tape that provides a more convenient API for constructing the IR.

    Any attribute that is not otherwise defined on the builder is treated as an
    operator type: ``builder.Add(x, y)`` records an ``Add`` node whose inputs are
    the positional arguments and whose attributes are the keyword arguments.

    The keyword arguments ``_domain``, ``_version`` and ``_outputs`` are reserved
    and are not turned into attributes:

    * ``_domain``: the operator domain (default ``""``).
    * ``_version``: the opset version (default None).
    * ``_outputs``: either the number of outputs (default 1), a single output
      name, or a sequence of output names.
    """

    def __getattr__(self, op_type: str) -> Any:
        """Return a callable that records a node of type ``op_type``.

        Raises:
            AttributeError: For dunder names, so that protocol probes made by
                ``copy``, ``pickle`` and similar machinery (for example
                ``__deepcopy__`` or ``__setstate__``) see a missing attribute
                instead of accidentally creating a node.
        """
        if op_type.startswith("__") and op_type.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {op_type!r}"
            )
        return lambda *args, **kwargs: self._make_node(op_type, args, kwargs)

    def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]):
        """Record a node from a dynamic ``builder.OpType(...)`` call.

        Args:
            op_type: The operator type of the node.
            inputs: The input values of the node.
            kwargs: Node attributes plus the reserved ``_domain``, ``_version``
                and ``_outputs`` options. The mapping itself is not modified.

        Returns:
            A single value when the node has exactly one output, otherwise the
            sequence of output values.

        Raises:
            TypeError: If ``_outputs`` is not an int, a str or a sequence.
        """
        # Work on a copy so the reserved options can be removed without
        # mutating the caller's mapping.
        attributes = dict(kwargs)
        domain = attributes.pop("_domain", "")
        version = attributes.pop("_version", None)
        outputs = attributes.pop("_outputs", 1)

        output_names: Sequence[str] | None
        if isinstance(outputs, str):
            # A str is itself a Sequence; treat it as one output name rather
            # than one output per character.
            output_names = [outputs]
        elif isinstance(outputs, Sequence):
            output_names = outputs
        elif isinstance(outputs, int):
            output_names = None
        else:
            raise TypeError(
                "_outputs must be an int, a str, or a sequence of names, "
                f"got {type(outputs).__name__}"
            )
        num_outputs = outputs if output_names is None else len(output_names)

        if num_outputs == 1:
            value = super().op(
                op_type, inputs=inputs, attributes=attributes, domain=domain, version=version
            )
            if output_names is not None:
                value.name = output_names[0]
            return value

        values = super().op_multi_out(
            op_type,
            inputs=inputs,
            attributes=attributes,
            domain=domain,
            version=version,
            num_outputs=num_outputs,
        )
        if output_names is not None:
            for value, output_name in zip(values, output_names):
                value.name = output_name
        return values