Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions onnxscript/ir/passes/common/constant_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
const_value=tensor,
)
assert node.graph is not None
assert isinstance(node.graph, ir.Graph)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@justinchuby does this make sense? I think there should not be any ir.Function node coming out from recursive iterator?

Copy link
Copy Markdown
Collaborator

@justinchuby justinchuby Apr 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does make sense. Thanks! The reason why it was annotated with graph | function is that the “owning graph” can be a function when the node is part of a function. Maybe there are better ways to do it 🤔

Copy link
Copy Markdown
Collaborator

@gramalingam gramalingam Apr 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Unrelated to this PR): Isn't a Function object a wrapper around a Graph object? Does node.graph not return that graph object even in the case of function nodes?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically the graph in a function is private and not used directly. It is currently an implementation detail

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry I was wrong. It is pointed to a graph when we call function.append, but it is not when we call ir.Node(graph=function). I need to figure out how to reconcile this. Suggestions appreciated. #2181

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

node.graph.register_initializer(initializer)
# Replace the constant node with the initializer
ir.convenience.replace_all_uses_with(node.outputs[0], initializer)
Expand Down
113 changes: 82 additions & 31 deletions onnxscript/ir/passes/common/constant_manipulation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ class TestLiftConstantsToInitializersPass(unittest.TestCase):
(ir.DataType.INT64, np.int64),
]
)
def test_pass_with_lifting_constants_to_initializers(self, ir_dtype, numpy_dtype):
def test_pass_with_lifting_float_and_int_constants_to_initializers(
self, ir_dtype, numpy_dtype
):
inputs = [
ir.Value(name="input_a", type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3))),
ir.Value(
Expand All @@ -29,10 +31,11 @@ def test_pass_with_lifting_constants_to_initializers(self, ir_dtype, numpy_dtype
]

constant_tensor = ir.tensor(np.random.rand(2, 3).astype(numpy_dtype))
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
attribute = ir.convenience.convert_attributes({"value": constant_tensor})
const_node = ir.Node("", "Constant", inputs=[], attributes=attribute, num_outputs=1)
add_node = ir.Node("", "Add", inputs=[inputs[0], const_node.outputs[0]])
mul_node = ir.Node("", "Mul", inputs=[add_node.outputs[0], inputs[1]])
const_node = ir.node(
"Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1
)
add_node = ir.node("Add", inputs=[inputs[0], const_node.outputs[0]])
mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]])

model = ir.Model(
graph=ir.Graph(
Expand Down Expand Up @@ -72,40 +75,34 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
)

then_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
attribute = ir.convenience.convert_attributes({"value": then_constant_tensor})
then_const_node = ir.Node(
"", "Constant", inputs=[], attributes=attribute, num_outputs=1
then_const_node = ir.node(
"Constant", inputs=[], attributes={"value": then_constant_tensor}, num_outputs=1
)
# then branch adds the constant to the input
# else branch multiplies the input by the constant
add_node = ir.Node("", "Add", inputs=[input_value, then_const_node.outputs[0]])
add_node = ir.node("Add", inputs=[input_value, then_const_node.outputs[0]])
then_graph = ir.Graph(
inputs=[input_value],
outputs=[add_node.outputs[0]],
nodes=[then_const_node, add_node],
opset_imports={"": 20},
)
else_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
attribute = ir.convenience.convert_attributes({"value": else_constant_tensor})
else_const_node = ir.Node(
"", "Constant", inputs=[], attributes=attribute, num_outputs=1
else_const_node = ir.node(
"Constant", inputs=[], attributes={"value": else_constant_tensor}, num_outputs=1
)
mul_node = ir.Node("", "Mul", inputs=[input_value, else_const_node.outputs[0]])
mul_node = ir.node("Mul", inputs=[input_value, else_const_node.outputs[0]])
else_graph = ir.Graph(
inputs=[input_value],
outputs=[mul_node.outputs[0]],
nodes=[else_const_node, mul_node],
opset_imports={"": 20},
)
# create a conditional node that uses the then and else graphs
attribute = ir.convenience.convert_attributes(
{"then_branch": then_graph, "else_branch": else_graph}
)
cond_node = ir.Node(
"",
cond_node = ir.node(
"If",
inputs=[input_value],
attributes=attribute,
attributes={"then_branch": then_graph, "else_branch": else_graph},
num_outputs=1,
)
# construct the model
Expand All @@ -128,15 +125,69 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
raise AssertionError(
f"Constant node '{node.name}' was not lifted to initializers"
)
if node.op_type == "Add":
self.assertEqual(len(node.graph.initializers), 1)
self.assertEqual(
node.graph.initializers["val_0"].const_value,
then_constant_tensor,
)
if node.op_type == "Mul":
self.assertEqual(len(node.graph.initializers), 1)
self.assertEqual(
node.graph.initializers["val_0"].const_value,
else_constant_tensor,
)
self.assertEqual(len(else_graph.initializers), 1)
self.assertEqual(len(then_graph.initializers), 1)
self.assertEqual(
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
else_graph.initializers["val_0"].const_value,
else_constant_tensor,
)
self.assertEqual(
then_graph.initializers["val_0"].const_value,
then_constant_tensor,
)

@parameterized.parameterized.expand(
[
(1.0, "value_float", np.float32),
(1, "value_int", np.int64),
("hello world!", "value_string", np.bytes_),
([1.0, 2.0, 3.0], "value_floats", np.float32),
([1, 2, 3], "value_ints", np.int64),
(["hello world!", "thank you."], "value_strings", np.bytes_),
]
)
def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings(
self, value, constant_attribute, np_dtype
):
input_value = ir.Value(
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
)

constant_value = value
const_node = ir.node(
"Constant",
inputs=[],
attributes={constant_attribute: constant_value},
num_outputs=1,
)
identity_node_constant = ir.node(
"Identity", inputs=[const_node.outputs[0]], num_outputs=1
)
identity_node_input = ir.node("Identity", inputs=[input_value], num_outputs=1)

model = ir.Model(
graph=ir.Graph(
inputs=[input_value],
outputs=[identity_node_input.outputs[0], identity_node_constant.outputs[0]],
nodes=[identity_node_input, const_node, identity_node_constant],
opset_imports={"": 20},
),
ir_version=10,
)

# Check that the initializer is not in the graph yet
assert len(model.graph.initializers) == 0
# And 1 constant node
assert len([node for node in model.graph if node.op_type == "Constant"]) == 1
Comment thread
justinchuby marked this conversation as resolved.
Outdated

# Perform lift constants to initializers
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
assert result.modified
Comment thread
justinchuby marked this conversation as resolved.
Outdated
# Check that the constant node is lifted to an initializer
assert len(result.model.graph.initializers) == 1
Comment thread
justinchuby marked this conversation as resolved.
Outdated
self.assertTrue(
Comment thread
justinchuby marked this conversation as resolved.
Outdated
np.array_equal(
result.model.graph.initializers["val_1"].const_value.raw,
np.array(constant_value, dtype=np_dtype),
)
)
Loading