Skip to content

Commit 22df674

Browse files
committed
add new tests
1 parent 3b15f45 commit 22df674

2 files changed

Lines changed: 83 additions & 31 deletions

File tree

onnxscript/ir/passes/common/constant_manipulation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
5252
const_value=tensor,
5353
)
5454
assert node.graph is not None
55+
assert isinstance(node.graph, ir.Graph)
5556
node.graph.register_initializer(initializer)
5657
# Replace the constant node with the initilizer
5758
ir.convenience.replace_all_uses_with(node.outputs[0], initializer)

onnxscript/ir/passes/common/constant_manipulation_test.py

Lines changed: 82 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ class TestLiftConstantsToInitializersPass(unittest.TestCase):
1818
(ir.DataType.INT64, np.int64),
1919
]
2020
)
21-
def test_pass_with_lifting_constants_to_initializers(self, ir_dtype, numpy_dtype):
21+
def test_pass_with_lifting_float_and_int_constants_to_initializers(
22+
self, ir_dtype, numpy_dtype
23+
):
2224
inputs = [
2325
ir.Value(name="input_a", type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3))),
2426
ir.Value(
@@ -29,10 +31,11 @@ def test_pass_with_lifting_constants_to_initializers(self, ir_dtype, numpy_dtype
2931
]
3032

3133
constant_tensor = ir.tensor(np.random.rand(2, 3).astype(numpy_dtype))
32-
attribute = ir.convenience.convert_attributes({"value": constant_tensor})
33-
const_node = ir.Node("", "Constant", inputs=[], attributes=attribute, num_outputs=1)
34-
add_node = ir.Node("", "Add", inputs=[inputs[0], const_node.outputs[0]])
35-
mul_node = ir.Node("", "Mul", inputs=[add_node.outputs[0], inputs[1]])
34+
const_node = ir.node(
35+
"Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1
36+
)
37+
add_node = ir.node("Add", inputs=[inputs[0], const_node.outputs[0]])
38+
mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]])
3639

3740
model = ir.Model(
3841
graph=ir.Graph(
@@ -72,40 +75,34 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
7275
)
7376

7477
then_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
75-
attribute = ir.convenience.convert_attributes({"value": then_constant_tensor})
76-
then_const_node = ir.Node(
77-
"", "Constant", inputs=[], attributes=attribute, num_outputs=1
78+
then_const_node = ir.node(
79+
"Constant", inputs=[], attributes={"value": then_constant_tensor}, num_outputs=1
7880
)
7981
# then branch adds the constant to the input
8082
# else branch multiplies the input by the constant
81-
add_node = ir.Node("", "Add", inputs=[input_value, then_const_node.outputs[0]])
83+
add_node = ir.node("Add", inputs=[input_value, then_const_node.outputs[0]])
8284
then_graph = ir.Graph(
8385
inputs=[input_value],
8486
outputs=[add_node.outputs[0]],
8587
nodes=[then_const_node, add_node],
8688
opset_imports={"": 20},
8789
)
8890
else_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
89-
attribute = ir.convenience.convert_attributes({"value": else_constant_tensor})
90-
else_const_node = ir.Node(
91-
"", "Constant", inputs=[], attributes=attribute, num_outputs=1
91+
else_const_node = ir.node(
92+
"Constant", inputs=[], attributes={"value": else_constant_tensor}, num_outputs=1
9293
)
93-
mul_node = ir.Node("", "Mul", inputs=[input_value, else_const_node.outputs[0]])
94+
mul_node = ir.node("Mul", inputs=[input_value, else_const_node.outputs[0]])
9495
else_graph = ir.Graph(
9596
inputs=[input_value],
9697
outputs=[mul_node.outputs[0]],
9798
nodes=[else_const_node, mul_node],
9899
opset_imports={"": 20},
99100
)
100101
# create a conditional node that uses the then and else graphs
101-
attribute = ir.convenience.convert_attributes(
102-
{"then_branch": then_graph, "else_branch": else_graph}
103-
)
104-
cond_node = ir.Node(
105-
"",
102+
cond_node = ir.node(
106103
"If",
107104
inputs=[input_value],
108-
attributes=attribute,
105+
attributes={"then_branch": then_graph, "else_branch": else_graph},
109106
num_outputs=1,
110107
)
111108
# construnct the model
@@ -128,15 +125,69 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
128125
raise AssertionError(
129126
f"Constant node '{node.name}' was not lifted to initializers"
130127
)
131-
if node.op_type == "Add":
132-
self.assertEqual(len(node.graph.initializers), 1)
133-
self.assertEqual(
134-
node.graph.initializers["val_0"].const_value,
135-
then_constant_tensor,
136-
)
137-
if node.op_type == "Mul":
138-
self.assertEqual(len(node.graph.initializers), 1)
139-
self.assertEqual(
140-
node.graph.initializers["val_0"].const_value,
141-
else_constant_tensor,
142-
)
128+
self.assertEqual(len(else_graph.initializers), 1)
129+
self.assertEqual(len(then_graph.initializers), 1)
130+
self.assertEqual(
131+
else_graph.initializers["val_0"].const_value,
132+
else_constant_tensor,
133+
)
134+
self.assertEqual(
135+
then_graph.initializers["val_0"].const_value,
136+
then_constant_tensor,
137+
)
138+
139+
@parameterized.parameterized.expand(
140+
[
141+
(1.0, "value_float", np.float32),
142+
(1, "value_int", np.int64),
143+
("hello world!", "value_string", np.bytes_),
144+
([1.0, 2.0, 3.0], "value_floats", np.float32),
145+
([1, 2, 3], "value_ints", np.int64),
146+
(["hello world!", "thank you."], "value_strings", np.bytes_),
147+
]
148+
)
149+
def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings(
150+
self, value, constant_attribute, np_dtype
151+
):
152+
input_value = ir.Value(
153+
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
154+
)
155+
156+
constant_value = value
157+
const_node = ir.node(
158+
"Constant",
159+
inputs=[],
160+
attributes={constant_attribute: constant_value},
161+
num_outputs=1,
162+
)
163+
identity_node_constant = ir.node(
164+
"Identity", inputs=[const_node.outputs[0]], num_outputs=1
165+
)
166+
identity_node_input = ir.node("Identity", inputs=[input_value], num_outputs=1)
167+
168+
model = ir.Model(
169+
graph=ir.Graph(
170+
inputs=[input_value],
171+
outputs=[identity_node_input.outputs[0], identity_node_constant.outputs[0]],
172+
nodes=[identity_node_input, const_node, identity_node_constant],
173+
opset_imports={"": 20},
174+
),
175+
ir_version=10,
176+
)
177+
178+
# Check that the initializer is not in the graph yet
179+
assert len(model.graph.initializers) == 0
180+
# And 1 constant node
181+
assert len([node for node in model.graph if node.op_type == "Constant"]) == 1
182+
183+
# Perform lift constants to initializers
184+
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
185+
assert result.modified
186+
# Check that the constant node is lifted to an initializer
187+
assert len(result.model.graph.initializers) == 1
188+
self.assertTrue(
189+
np.array_equal(
190+
result.model.graph.initializers["val_1"].const_value.raw,
191+
np.array(constant_value, dtype=np_dtype),
192+
)
193+
)

0 commit comments

Comments
 (0)