forked from microsoft/onnxscript
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathunused_removal.py
More file actions
204 lines (171 loc) · 7.84 KB
/
unused_removal.py
File metadata and controls
204 lines (171 loc) · 7.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
__all__ = [
"RemoveUnusedNodesPass",
"RemoveUnusedFunctionsPass",
"RemoveUnusedOpsetsPass",
]
import logging
import onnx
from onnxscript import ir
logger = logging.getLogger(__name__)
def _remove_unused_optional_outputs(
    node: ir.Node, graph_outputs: frozenset[ir.Value], onnx_opset_version: int
) -> None:
    """Clear optional outputs of ``node`` that are neither graph outputs nor consumed.

    An output is "cleared" by setting its name to ``""`` so serialization treats
    it as absent. Only standard ONNX ops are handled, since clearing requires
    consulting the op schema to know which outputs are optional.

    Args:
        node: The node whose outputs are inspected in place.
        graph_outputs: Values that are outputs of the enclosing graph; these are
            never cleared.
        onnx_opset_version: Opset version used for the schema lookup.
    """
    try:
        # The ONNX default domain is "" with alias "ai.onnx".
        # (Fixed: previously checked the non-existent domain "onnx.ai".)
        if node.domain not in {"", "ai.onnx"}:
            return
        op_schema = onnx.defs.get_schema(node.op_type, onnx_opset_version, domain=node.domain)
    except Exception:  # pylint: disable=broad-exception-caught
        logger.info(
            "Failed to get schema for %s, skipping optional output removal",
            node,
            stack_info=True,
        )
        return
    if node.op_type == "BatchNormalization":
        # BatchNormalization op has 3 outputs: Y, running_mean, running_var
        # If running_mean and running_var are not used, remove them, and the training_mode attribute
        def is_used_output(i: int) -> bool:
            if i < len(node.outputs):
                val = node.outputs[i]
                return val in graph_outputs or bool(val.uses())
            return False

        if is_used_output(1) or is_used_output(2):
            return
        if len(node.outputs) > 1:
            node.outputs[1].name = ""
        if len(node.outputs) > 2:
            node.outputs[2].name = ""
        node.attributes.pop("training_mode", None)
        return
    optional_info = []
    for o in op_schema.outputs:
        # Current ops do not have optional outputs if they have variable number of outputs
        if o.option == onnx.defs.OpSchema.FormalParameterOption.Variadic:
            return
        optional_info.append(o.option == onnx.defs.OpSchema.FormalParameterOption.Optional)
    # If the spec declares no optional outputs there is nothing to remove.
    # (Fixed: the old check `len([o == 1 for o in optional_info]) == 0` measured
    # the list's length, not whether any entry was True.)
    if not any(optional_info):
        return
    for i, out in enumerate(node.outputs):
        if out not in graph_outputs and (not out.uses()) and optional_info[i]:
            out.name = ""
def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph) -> int:
    """Remove nodes whose outputs are all unused, recursing into subgraph attributes.

    Args:
        function_or_graph: The graph or function to prune in place.

    Returns:
        The number of nodes removed, including those removed from nested subgraphs.
    """
    graph_outputs = frozenset(function_or_graph.outputs)
    onnx_opset_version = function_or_graph.opset_imports.get("", None)
    removed = 0
    # Walk backwards so consumers are visited (and possibly removed) before
    # their producers, letting one pass cascade removals.
    for node in reversed(function_or_graph):
        is_live = any(out in graph_outputs or out.uses() for out in node.outputs)
        if not is_live:
            function_or_graph.remove(node, safe=True)
            removed += 1
            continue
        # Node stays: try to clear its unused optional outputs, then recurse
        # into any subgraph-valued attributes.
        if onnx_opset_version is not None:
            _remove_unused_optional_outputs(node, graph_outputs, onnx_opset_version)
        for attr in node.attributes.values():
            if not isinstance(attr, ir.Attr):
                continue
            if attr.type == ir.AttributeType.GRAPH:
                removed += _remove_unused_nodes_in_graph_like(attr.as_graph())
            elif attr.type == ir.AttributeType.GRAPHS:
                for subgraph in attr.as_graphs():
                    removed += _remove_unused_nodes_in_graph_like(subgraph)
    return removed
class RemoveUnusedNodesPass(ir.passes.InPlacePass):
    """Remove unused nodes, and optionally unused initialized inputs and initializers."""

    def __init__(self, remove_initialized_inputs: bool = True):
        """Initialize the pass.

        Args:
            remove_initialized_inputs: If ``True`` (default), unused graph inputs
                that have a corresponding initializer are removed as well — such
                inputs are effectively initializers rather than real inputs.
                If ``False``, unused inputs are kept even when they have a
                default initializer. Regular (non-initialized) inputs are never
                removed either way.
        """
        super().__init__()
        self.remove_initialized_inputs = remove_initialized_inputs

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        count = _remove_unused_nodes_in_graph_like(model.graph)
        graph_outputs = frozenset(model.graph.outputs)
        initializers = model.graph.initializers
        if self.remove_initialized_inputs:
            graph_inputs = model.graph.inputs
            # Iterate in reverse so deletion by index remains valid.
            # (Renamed loop variable: `input` shadowed the builtin.)
            for i, graph_input in reversed(list(enumerate(graph_inputs))):
                if graph_input.name in initializers and not (
                    graph_input in graph_outputs or graph_input.uses()
                ):
                    del graph_inputs[i]
                    count += 1
        for init in list(initializers.values()):
            if not (init in graph_outputs or init.uses()):
                assert init.name is not None
                del initializers[init.name]
                count += 1
        for function in model.functions.values():
            count += _remove_unused_nodes_in_graph_like(function)
        if count:
            logger.info("Removed %s unused nodes", count)
        return ir.passes.PassResult(model, modified=bool(count))
class RemoveUnusedFunctionsPass(ir.passes.InPlacePass):
    """Remove functions never referenced (transitively) from the main graph."""

    def __init__(self):
        super().__init__()
        # Scratch state for one call(): identifiers of functions found reachable.
        self._used: set[ir.OperatorIdentifier] | None = None

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        self._used = set()
        for node in ir.traversal.RecursiveGraphIterator(model.graph):
            self._call_node(model, node)
        # Update the model to remove unused functions
        unused = set(model.functions) - self._used
        # Reset scratch state on every path, not only when something is removed
        # (the original leaked self._used on the early-return path).
        self._used = None
        if not unused:
            logger.info("No unused functions to remove")
            return ir.passes.PassResult(model, modified=False)
        for op_identifier in unused:
            del model.functions[op_identifier]
        logger.info("Removed %s unused functions", len(unused))
        logger.debug("Functions left: %s", list(model.functions))
        logger.debug("Functions removed: %s", unused)
        return ir.passes.PassResult(model, modified=True)

    def _call_function(self, model: ir.Model, function: ir.Function) -> None:
        """Mark ``function`` and everything it transitively calls as used."""
        assert self._used is not None
        if function.identifier() in self._used:
            # The function and its nodes are already recorded as used
            return
        self._used.add(function.identifier())
        for node in ir.traversal.RecursiveGraphIterator(function):
            self._call_node(model, node)

    def _call_node(self, model: ir.Model, node: ir.Node) -> None:
        """If ``node`` invokes one of the model's functions, mark it used."""
        op_identifier = node.op_identifier()
        if op_identifier not in model.functions:
            return
        self._call_function(model, model.functions[op_identifier])
class RemoveUnusedOpsetsPass(ir.passes.InPlacePass):
    """Remove unused opset imports from the model and functions.

    Attributes:
        process_functions: Whether to process functions in the model. If True, the pass will
            remove unused opset imports from functions as well. If False, only the main graph
            will be processed.
    """

    def __init__(self, process_functions: bool = True):
        super().__init__()
        self.process_functions = process_functions

    def _process_graph_like(
        self, graph_like: ir.Graph | ir.Function, used_domains: set[str]
    ) -> bool:
        # Record every domain actually referenced by a node (subgraphs included),
        # then drop any opset import not in that set.
        used_domains.update(
            node.domain for node in ir.traversal.RecursiveGraphIterator(graph_like)
        )
        stale = set(graph_like.opset_imports) - used_domains
        for domain in stale:
            del graph_like.opset_imports[domain]
        return bool(stale)

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        # Domains of the model's functions count as used by the main graph.
        seeded_domains = {function.domain for function in model.functions.values()}
        modified = self._process_graph_like(model.graph, used_domains=seeded_domains)
        if self.process_functions:
            for function in model.functions.values():
                modified |= self._process_graph_like(function, used_domains=set())
        return ir.passes.PassResult(model, modified=modified)