Skip to content

Commit 4633a3a

Browse files
[Pass] Fix bugs in LiftConstantsToInitializersPass (#2189)
Fix #2184. (1) Fix the corner case where the constant is a graph output: such a constant is no longer lifted. (2) Add an option to the pass that controls whether all constants are lifted to initializers, or only those with the "value" attribute (following the ort pass: https://github.com/microsoft/onnxruntime/blob/d7c688e15c1dc40f57140bff08c78e01a88b19fc/onnxruntime/python/tools/transformers/onnx_model.py#L525). The option defaults to False, in which case only constants with the "value" attribute are lifted. --------- Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 6bdfcfd commit 4633a3a

File tree

2 files changed

+136
-55
lines changed

2 files changed

+136
-55
lines changed

onnxscript/ir/passes/common/constant_manipulation.py

Lines changed: 54 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,28 @@
1818

1919

2020
class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
21+
"""Lift constants to initializers.
22+
23+
Attributes:
24+
lift_all_constants: Whether to lift all Constant nodes, including those that do not contain a tensor attribute (e.g. those with value_ints etc.).
25+
Defaults to False, in which case only Constants with the ``value`` attribute are lifted.
26+
"""
27+
28+
def __init__(self, lift_all_constants: bool = False):
29+
super().__init__()
30+
self._lift_all_constants = lift_all_constants
31+
2132
def call(self, model: ir.Model) -> ir.passes.PassResult:
22-
"""Convert constant nodes from node belonged graph to its initializers."""
2333
count = 0
2434
for node in ir.traversal.RecursiveGraphIterator(model.graph):
35+
assert node.graph is not None
2536
if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"):
2637
continue
27-
38+
if node.outputs[0].is_graph_output():
39+
logger.debug(
40+
"Constant node '%s' is used as output, so it can't be lifted.", node.name
41+
)
42+
continue
2843
constant_node_attribute = set(node.attributes.keys())
2944
if len(constant_node_attribute) != 1:
3045
logger.debug(
@@ -36,13 +51,11 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
3651
initializer_name = node.outputs[0].name
3752
assert initializer_name is not None
3853
assert isinstance(attr_value, ir.Attr)
39-
tensor = _constant_node_attribute_to_tensor(
40-
attr_name, attr_value, initializer_name
54+
tensor = self._constant_node_attribute_to_tensor(
55+
node, attr_name, attr_value, initializer_name
4156
)
4257
if tensor is None:
43-
logger.debug(
44-
"Invalid constant node '%s' has unsupported attribute value", node.name
45-
)
58+
# The reason for None is logged in _constant_node_attribute_to_tensor
4659
continue
4760
# Register an initializer with the tensor value
4861
initializer = ir.Value(
@@ -51,7 +64,6 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
5164
type=ir.TensorType(tensor.dtype),
5265
const_value=tensor,
5366
)
54-
assert node.graph is not None
5567
assert isinstance(node.graph, ir.Graph)
5668
node.graph.register_initializer(initializer)
5769
# Replace the constant node with the initializer
@@ -65,31 +77,38 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
6577
logger.debug("Lifted %s constants to initializers", count)
6678
return ir.passes.PassResult(model, modified=bool(count))
6779

80+
def _constant_node_attribute_to_tensor(
81+
self, node, attr_name: str, attr_value: ir.Attr, initializer_name: str
82+
) -> ir.Tensor | None:
83+
"""Convert constant node attribute to tensor."""
84+
if not self._lift_all_constants and attr_name != "value":
85+
logger.debug(
86+
"Constant node '%s' has non-tensor attribute '%s'", node.name, attr_name
87+
)
88+
return None
6889

69-
def _constant_node_attribute_to_tensor(
70-
attr_name: str, attr_value: ir.Attr, initializer_name: str
71-
) -> ir.Tensor | None:
72-
"""Convert constant node attribute to tensor."""
73-
if attr_name == "value":
74-
tensor = attr_value.as_tensor() # type: ignore[union-attr]
75-
elif attr_name == "value_int":
76-
tensor = ir.tensor(attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name)
77-
elif attr_name == "value_ints":
78-
tensor = ir.tensor(
79-
attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name
80-
)
81-
elif attr_name == "value_float":
82-
tensor = ir.tensor(
83-
attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name
84-
)
85-
elif attr_name == "value_floats":
86-
tensor = ir.tensor(
87-
attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name
88-
)
89-
elif attr_name in ("value_string", "value_strings"):
90-
tensor = ir.StringTensor(
91-
np.array(attr_value.value, dtype=np.bytes_), name=initializer_name
92-
)
93-
else:
94-
tensor = None
95-
return tensor # type: ignore[return-value]
90+
if attr_name == "value":
91+
tensor = attr_value.as_tensor() # type: ignore[union-attr]
92+
elif attr_name == "value_int":
93+
tensor = ir.tensor(
94+
attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name
95+
)
96+
elif attr_name == "value_ints":
97+
tensor = ir.tensor(
98+
attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name
99+
)
100+
elif attr_name == "value_float":
101+
tensor = ir.tensor(
102+
attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name
103+
)
104+
elif attr_name == "value_floats":
105+
tensor = ir.tensor(
106+
attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name
107+
)
108+
elif attr_name in ("value_string", "value_strings"):
109+
tensor = ir.StringTensor(
110+
np.array(attr_value.value, dtype=np.bytes_), name=initializer_name
111+
)
112+
else:
113+
tensor = None
114+
return tensor # type: ignore[return-value]

onnxscript/ir/passes/common/constant_manipulation_test.py

Lines changed: 82 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@
1414
class TestLiftConstantsToInitializersPass(unittest.TestCase):
1515
@parameterized.parameterized.expand(
1616
[
17-
(ir.DataType.FLOAT,),
18-
(ir.DataType.INT64,),
17+
(ir.DataType.FLOAT, True),
18+
(ir.DataType.FLOAT, False),
19+
(ir.DataType.INT64, True),
20+
(ir.DataType.INT64, False),
1921
]
2022
)
21-
def test_pass_with_lifting_float_and_int_constants_to_initializers(self, ir_dtype):
23+
def test_pass_with_lifting_float_and_int_constants_to_initializers(
24+
self, ir_dtype: ir.DataType, lift_all_constants: bool
25+
):
2226
inputs = [
2327
ir.Value(name="input_a", type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3))),
2428
ir.Value(
@@ -51,7 +55,9 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers(self, ir_dtyp
5155
self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1)
5256

5357
# Perform lift constants to initializers
54-
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
58+
result = constant_manipulation.LiftConstantsToInitializersPass(
59+
lift_all_constants=lift_all_constants
60+
)(model)
5561
self.assertTrue(result.modified)
5662
# Check that the constant node is lifted to an initializer
5763
self.assertEqual(len(result.model.graph.initializers), 1)
@@ -67,7 +73,15 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers(self, ir_dtyp
6773
len([node for node in result.model.graph if node.op_type == "Constant"]), 0
6874
)
6975

70-
def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
76+
@parameterized.parameterized.expand(
77+
[
78+
(True,),
79+
(False,),
80+
]
81+
)
82+
def test_pass_with_lifting_constants_to_initializers_within_subgraph(
83+
self, lift_all_constants: bool
84+
):
7185
input_value = ir.Value(
7286
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
7387
)
@@ -115,7 +129,9 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
115129
graph=main_graph,
116130
ir_version=10,
117131
)
118-
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
132+
result = constant_manipulation.LiftConstantsToInitializersPass(
133+
lift_all_constants=lift_all_constants
134+
)(model)
119135
self.assertTrue(result.modified)
120136
# Check that the constant node is lifted to the subgraph initializers
121137
for node in ir.traversal.RecursiveGraphIterator(result.model.graph):
@@ -136,16 +152,26 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
136152

137153
@parameterized.parameterized.expand(
138154
[
139-
(1.0, "value_float", np.float32),
140-
(1, "value_int", np.int64),
141-
("hello world!", "value_string", np.bytes_),
142-
([1.0, 2.0, 3.0], "value_floats", np.float32),
143-
([1, 2, 3], "value_ints", np.int64),
144-
(["hello world!", "thank you."], "value_strings", np.bytes_),
155+
(1.0, "value_float", np.float32, True),
156+
(1.0, "value_float", np.float32, False),
157+
(1, "value_int", np.int64, True),
158+
(1, "value_int", np.int64, False),
159+
("hello world!", "value_string", np.bytes_, True),
160+
("hello world!", "value_string", np.bytes_, False),
161+
([1.0, 2.0, 3.0], "value_floats", np.float32, True),
162+
([1.0, 2.0, 3.0], "value_floats", np.float32, False),
163+
([1, 2, 3], "value_ints", np.int64, True),
164+
([1, 2, 3], "value_ints", np.int64, False),
165+
(["hello world!", "thank you."], "value_strings", np.bytes_, True),
166+
(["hello world!", "thank you."], "value_strings", np.bytes_, False),
145167
]
146168
)
147169
def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings(
148-
self, value, constant_attribute, np_dtype
170+
self,
171+
value: float | int | str | list[float] | list[int] | list[str],
172+
constant_attribute: str,
173+
np_dtype: type[np.dtype],
174+
lift_all_constants: bool,
149175
):
150176
input_value = ir.Value(
151177
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
@@ -179,11 +205,47 @@ def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings(
179205
self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1)
180206

181207
# Perform lift constants to initializers
182-
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
183-
self.assertTrue(result.modified)
184-
# Check that the constant node is lifted to an initializer
185-
self.assertEqual(len(result.model.graph.initializers), 1)
186-
np.testing.assert_array_equal(
187-
result.model.graph.initializers["val_1"].const_value.numpy(),
188-
np.array(constant_value, dtype=np_dtype),
208+
result = constant_manipulation.LiftConstantsToInitializersPass(
209+
lift_all_constants=lift_all_constants
210+
)(model)
211+
if lift_all_constants:
212+
self.assertTrue(result.modified)
213+
# Check that the constant node is lifted to an initializer
214+
self.assertEqual(len(result.model.graph.initializers), 1)
215+
np.testing.assert_array_equal(
216+
result.model.graph.initializers["val_1"].const_value.numpy(),
217+
np.array(constant_value, dtype=np_dtype),
218+
)
219+
else:
220+
self.assertFalse(result.modified)
221+
# Check that the constant node is not lifted to an initializer
222+
self.assertEqual(len(result.model.graph.initializers), 0)
223+
224+
def test_not_lifting_constants_to_initializers_when_it_is_output(self):
225+
input_value = ir.Value(
226+
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
189227
)
228+
identity_node_input = ir.node("Identity", inputs=[input_value], num_outputs=1)
229+
230+
constant_value = ir.tensor(np.random.rand(2, 3).astype(np.float32))
231+
const_node = ir.node(
232+
"Constant",
233+
inputs=[],
234+
attributes={"value": constant_value},
235+
num_outputs=1,
236+
)
237+
238+
model = ir.Model(
239+
graph=ir.Graph(
240+
inputs=[input_value],
241+
outputs=[identity_node_input.outputs[0], const_node.outputs[0]],
242+
nodes=[identity_node_input, const_node],
243+
opset_imports={"": 20},
244+
),
245+
ir_version=10,
246+
)
247+
248+
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
249+
self.assertFalse(result.modified)
250+
# Check that the constant node is not lifted to an initializer
251+
self.assertEqual(len(result.model.graph.initializers), 0)

0 commit comments

Comments
 (0)