Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1924,6 +1924,8 @@ def __init__(
# Be sure the initialize the name authority before extending the nodes
# because it is used to name the nodes and their outputs
self._name_authority = _name_authority.NameAuthority()
# TODO(justinchuby): Trigger again if inputs or initializers are modified.
self._set_input_and_initializer_value_names_into_name_authority()
# Call self.extend not self._nodes.extend so the graph reference is added to the nodes
self.extend(nodes)

Expand Down Expand Up @@ -1999,6 +2001,12 @@ def __iter__(self) -> Iterator[Node]:
def __reversed__(self) -> Iterator[Node]:
return reversed(self._nodes)

def _set_input_and_initializer_value_names_into_name_authority(self):
Comment thread
titaiwangms marked this conversation as resolved.
for value in self.inputs:
self._name_authority.register_or_name_value(value)
for value in self.initializers.values():
self._name_authority.register_or_name_value(value)

def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
"""Set the graph reference for the node and assign names to it and its outputs if they don't have one."""
if node.graph is not None and node.graph is not self:
Expand Down
95 changes: 95 additions & 0 deletions onnxscript/ir/passes/common/constant_manipulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Lift constants to initializers."""

from __future__ import annotations

__all__ = [
"LiftConstantsToInitializersPass",
]

import logging

import numpy as np

from onnxscript import ir

logger = logging.getLogger(__name__)


class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
def call(self, model: ir.Model) -> ir.passes.PassResult:
"""Convert constant nodes from node belonged graph to its initializers."""
count = 0
for node in ir.traversal.RecursiveGraphIterator(model.graph):
if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"):
continue

constant_node_attribute = set(node.attributes.keys())
if len(constant_node_attribute) != 1:
logger.debug(

Check warning on line 30 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L30

Added line #L30 was not covered by tests
"Invalid constant node '%s' has more than one attribute", node.name
)
continue

Check warning on line 33 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L33

Added line #L33 was not covered by tests

attr_name, attr_value = next(iter(node.attributes.items()))
initializer_name = node.outputs[0].name
assert initializer_name is not None
assert isinstance(attr_value, ir.Attr)
tensor = _constant_node_attribute_to_tensor(
attr_name, attr_value, initializer_name
)
if tensor is None:
logger.debug(

Check warning on line 43 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L43

Added line #L43 was not covered by tests
"Invalid constant node '%s' has unsupported attribute value", node.name
)
continue

Check warning on line 46 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L46

Added line #L46 was not covered by tests
# Register an initializer with the tensor value
initializer = ir.Value(
name=initializer_name,
shape=tensor.shape, # type: ignore[arg-type]
type=ir.TensorType(tensor.dtype),
const_value=tensor,
)
assert node.graph is not None
assert isinstance(node.graph, ir.Graph)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@justinchuby does this make sense? I think there should not be any ir.Function node coming out from recursive iterator?

Copy link
Copy Markdown
Collaborator

@justinchuby justinchuby Apr 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does make sense. Thanks! The reason why it was annotated with graph | function is that the “owning graph” can be a function when the node is part of a function. Maybe there are better ways to do it 🤔

Copy link
Copy Markdown
Collaborator

@gramalingam gramalingam Apr 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Unrelated to this PR): Isn't a Function object a wrapper around a Graph object? Does node.graph not return that graph object even in the case of function nodes?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically the graph in a function is private and not used directly. It is currently an implementation detail

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry I was wrong. It is pointed to a graph when we call function.append, but it is not when we call ir.Node(graph=function). I need to figure out how to reconcile this. Suggestions appreciated. #2181

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

node.graph.register_initializer(initializer)
# Replace the constant node with the initilizer
ir.convenience.replace_all_uses_with(node.outputs[0], initializer)
node.graph.remove(node, safe=True)
count += 1
logger.debug(
"Converted constant node '%s' to initializer '%s'", node.name, initializer_name
)
if count:
logger.debug("Lifted %s constants to initializers", count)
return ir.passes.PassResult(model, modified=bool(count))


def _constant_node_attribute_to_tensor(
attr_name: str, attr_value: ir.Attr, initializer_name: str
) -> ir.Tensor | None:
"""Convert constant node attribute to tensor."""
if attr_name == "value":
tensor = attr_value.as_tensor() # type: ignore[union-attr]
elif attr_name == "value_int":
tensor = ir.tensor(attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name)
elif attr_name == "value_ints":
tensor = ir.tensor(
attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name
)
elif attr_name == "value_float":
tensor = ir.tensor(
attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name
)
elif attr_name == "value_floats":
tensor = ir.tensor(
attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name
)
elif attr_name in ("value_string", "value_strings"):
Comment thread
justinchuby marked this conversation as resolved.
tensor = ir.StringTensor(
np.array(attr_value.value, dtype=np.bytes_), name=initializer_name
)
else:
tensor = None

Check warning on line 94 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L94

Added line #L94 was not covered by tests
return tensor # type: ignore[return-value]
193 changes: 193 additions & 0 deletions onnxscript/ir/passes/common/constant_manipulation_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# Copyright (c) Microsoft Corporation.
Comment thread Fixed
Comment thread Fixed
Comment thread Fixed
# Licensed under the MIT License.
from __future__ import annotations
Comment thread Fixed

import unittest

import numpy as np
import parameterized

from onnxscript import ir
from onnxscript.ir.passes.common import constant_manipulation


class TestLiftConstantsToInitializersPass(unittest.TestCase):
@parameterized.parameterized.expand(
[
(ir.DataType.FLOAT, np.float32),
(ir.DataType.INT64, np.int64),
]
)
def test_pass_with_lifting_float_and_int_constants_to_initializers(
self, ir_dtype, numpy_dtype
):
inputs = [
ir.Value(name="input_a", type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3))),
ir.Value(
name="input_b",
type=ir.TensorType(ir_dtype),
shape=ir.Shape((2, 3)),
),
]

constant_tensor = ir.tensor(np.random.rand(2, 3).astype(numpy_dtype))
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
const_node = ir.node(
"Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1
)
add_node = ir.node("Add", inputs=[inputs[0], const_node.outputs[0]])
mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]])

model = ir.Model(
graph=ir.Graph(
inputs=inputs,
outputs=mul_node.outputs,
nodes=[const_node, add_node, mul_node],
opset_imports={"": 20},
),
ir_version=10,
)

# Check that the initializer is not in the graph yet
self.assertEqual(len(model.graph.initializers), 0)
# And 1 constant node
self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1)

# Perform lift constants to initializers
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
self.assertTrue(result.modified)
# Check that the constant node is lifted to an initializer
self.assertEqual(len(result.model.graph.initializers), 1)
# Check the value
self.assertEqual(
result.model.graph.initializers[
"val_0"
].const_value, # name created by name_authority
constant_tensor,
)
# And 0 constant node
self.assertEqual(
len([node for node in result.model.graph if node.op_type == "Constant"]), 0
)

def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
input_value = ir.Value(
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
)

then_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
then_const_node = ir.node(
"Constant", inputs=[], attributes={"value": then_constant_tensor}, num_outputs=1
)
# then branch adds the constant to the input
# else branch multiplies the input by the constant
add_node = ir.node("Add", inputs=[input_value, then_const_node.outputs[0]])
then_graph = ir.Graph(
inputs=[input_value],
outputs=[add_node.outputs[0]],
nodes=[then_const_node, add_node],
opset_imports={"": 20},
)
else_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
else_const_node = ir.node(
"Constant", inputs=[], attributes={"value": else_constant_tensor}, num_outputs=1
)
mul_node = ir.node("Mul", inputs=[input_value, else_const_node.outputs[0]])
else_graph = ir.Graph(
inputs=[input_value],
outputs=[mul_node.outputs[0]],
nodes=[else_const_node, mul_node],
opset_imports={"": 20},
)
# create a conditional node that uses the then and else graphs
cond_node = ir.node(
"If",
inputs=[input_value],
attributes={"then_branch": then_graph, "else_branch": else_graph},
num_outputs=1,
)
# construnct the model
main_graph = ir.Graph(
inputs=[input_value],
outputs=cond_node.outputs,
nodes=[cond_node],
opset_imports={"": 20},
)
main_graph.sort()
model = ir.Model(
graph=main_graph,
ir_version=10,
)
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
self.assertTrue(result.modified)
# Check that the constant node is lifted to the subgraph initializers
for node in ir.traversal.RecursiveGraphIterator(result.model.graph):
if node.op_type == "Constant":
raise AssertionError(

Check warning on line 125 in onnxscript/ir/passes/common/constant_manipulation_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation_test.py#L125

Added line #L125 was not covered by tests
f"Constant node '{node.name}' was not lifted to initializers"
)
self.assertEqual(len(else_graph.initializers), 1)
self.assertEqual(len(then_graph.initializers), 1)
self.assertEqual(
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
else_graph.initializers["val_0"].const_value,
else_constant_tensor,
)
self.assertEqual(
then_graph.initializers["val_0"].const_value,
then_constant_tensor,
)

@parameterized.parameterized.expand(
[
(1.0, "value_float", np.float32),
(1, "value_int", np.int64),
("hello world!", "value_string", np.bytes_),
([1.0, 2.0, 3.0], "value_floats", np.float32),
([1, 2, 3], "value_ints", np.int64),
(["hello world!", "thank you."], "value_strings", np.bytes_),
]
)
def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings(
self, value, constant_attribute, np_dtype
):
input_value = ir.Value(
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
)

constant_value = value
const_node = ir.node(
"Constant",
inputs=[],
attributes={constant_attribute: constant_value},
num_outputs=1,
)
identity_node_constant = ir.node(
"Identity", inputs=[const_node.outputs[0]], num_outputs=1
)
identity_node_input = ir.node("Identity", inputs=[input_value], num_outputs=1)

model = ir.Model(
graph=ir.Graph(
inputs=[input_value],
outputs=[identity_node_input.outputs[0], identity_node_constant.outputs[0]],
nodes=[identity_node_input, const_node, identity_node_constant],
opset_imports={"": 20},
),
ir_version=10,
)

# Check that the initializer is not in the graph yet
assert len(model.graph.initializers) == 0
# And 1 constant node
assert len([node for node in model.graph if node.op_type == "Constant"]) == 1
Comment thread
justinchuby marked this conversation as resolved.
Outdated

# Perform lift constants to initializers
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
assert result.modified
Comment thread
justinchuby marked this conversation as resolved.
Outdated
# Check that the constant node is lifted to an initializer
assert len(result.model.graph.initializers) == 1
Comment thread
justinchuby marked this conversation as resolved.
Outdated
self.assertTrue(
Comment thread
justinchuby marked this conversation as resolved.
Outdated
np.array_equal(
result.model.graph.initializers["val_1"].const_value.raw,
np.array(constant_value, dtype=np_dtype),
)
)
2 changes: 2 additions & 0 deletions onnxscript/optimizer/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import logging

import onnxscript.ir.passes.common.constant_manipulation
import onnxscript.ir.passes.common.unused_removal
from onnxscript import ir, rewriter
from onnxscript.optimizer import _constant_folding, _inliner
Expand Down Expand Up @@ -52,6 +53,7 @@ def optimize_ir(
early_stop=stop_if_no_change,
),
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass(),
)
assert optimizer_pass.in_place
result = optimizer_pass(model)
Expand Down
Loading