diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py
index e13a3fa978..b408898f71 100644
--- a/onnxscript/ir/_core.py
+++ b/onnxscript/ir/_core.py
@@ -1924,6 +1924,8 @@ def __init__(
         # Be sure the initialize the name authority before extending the nodes
         # because it is used to name the nodes and their outputs
         self._name_authority = _name_authority.NameAuthority()
+        # TODO(justinchuby): Trigger again if inputs or initializers are modified.
+        self._set_input_and_initializer_value_names_into_name_authority()
         # Call self.extend not self._nodes.extend so the graph reference is added to the nodes
         self.extend(nodes)
 
@@ -1999,6 +2001,13 @@ def __iter__(self) -> Iterator[Node]:
     def __reversed__(self) -> Iterator[Node]:
         return reversed(self._nodes)
 
+    def _set_input_and_initializer_value_names_into_name_authority(self) -> None:
+        """Register (or name) every graph input and initializer value with the name authority."""
+        for value in self.inputs:
+            self._name_authority.register_or_name_value(value)
+        for value in self.initializers.values():
+            self._name_authority.register_or_name_value(value)
+
     def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
         """Set the graph reference for the node and assign names to it and its outputs if they don't have one."""
         if node.graph is not None and node.graph is not self:
diff --git a/onnxscript/ir/passes/common/constant_manipulation.py b/onnxscript/ir/passes/common/constant_manipulation.py
new file mode 100644
index 0000000000..3032b33d44
--- /dev/null
+++ b/onnxscript/ir/passes/common/constant_manipulation.py
@@ -0,0 +1,95 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Lift constants to initializers.""" + +from __future__ import annotations + +__all__ = [ + "LiftConstantsToInitializersPass", +] + +import logging + +import numpy as np + +from onnxscript import ir + +logger = logging.getLogger(__name__) + + +class LiftConstantsToInitializersPass(ir.passes.InPlacePass): + def call(self, model: ir.Model) -> ir.passes.PassResult: + """Convert constant nodes from node belonged graph to its initializers.""" + count = 0 + for node in ir.traversal.RecursiveGraphIterator(model.graph): + if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"): + continue + + constant_node_attribute = set(node.attributes.keys()) + if len(constant_node_attribute) != 1: + logger.debug( + "Invalid constant node '%s' has more than one attribute", node.name + ) + continue + + attr_name, attr_value = next(iter(node.attributes.items())) + initializer_name = node.outputs[0].name + assert initializer_name is not None + assert isinstance(attr_value, ir.Attr) + tensor = _constant_node_attribute_to_tensor( + attr_name, attr_value, initializer_name + ) + if tensor is None: + logger.debug( + "Invalid constant node '%s' has unsupported attribute value", node.name + ) + continue + # Register an initializer with the tensor value + initializer = ir.Value( + name=initializer_name, + shape=tensor.shape, # type: ignore[arg-type] + type=ir.TensorType(tensor.dtype), + const_value=tensor, + ) + assert node.graph is not None + assert isinstance(node.graph, ir.Graph) + node.graph.register_initializer(initializer) + # Replace the constant node with the initilizer + ir.convenience.replace_all_uses_with(node.outputs[0], initializer) + node.graph.remove(node, safe=True) + count += 1 + logger.debug( + "Converted constant node '%s' to initializer '%s'", node.name, initializer_name + ) + if count: + logger.debug("Lifted %s constants to initializers", count) + return ir.passes.PassResult(model, modified=bool(count)) + + +def _constant_node_attribute_to_tensor( + attr_name: 
str, attr_value: ir.Attr, initializer_name: str +) -> ir.Tensor | None: + """Convert constant node attribute to tensor.""" + if attr_name == "value": + tensor = attr_value.as_tensor() # type: ignore[union-attr] + elif attr_name == "value_int": + tensor = ir.tensor(attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name) + elif attr_name == "value_ints": + tensor = ir.tensor( + attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name + ) + elif attr_name == "value_float": + tensor = ir.tensor( + attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name + ) + elif attr_name == "value_floats": + tensor = ir.tensor( + attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name + ) + elif attr_name in ("value_string", "value_strings"): + tensor = ir.StringTensor( + np.array(attr_value.value, dtype=np.bytes_), name=initializer_name + ) + else: + tensor = None + return tensor # type: ignore[return-value] diff --git a/onnxscript/ir/passes/common/constant_manipulation_test.py b/onnxscript/ir/passes/common/constant_manipulation_test.py new file mode 100644 index 0000000000..2d1696e7fd --- /dev/null +++ b/onnxscript/ir/passes/common/constant_manipulation_test.py @@ -0,0 +1,189 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import parameterized
+
+from onnxscript import ir
+from onnxscript.ir.passes.common import constant_manipulation
+
+
+class TestLiftConstantsToInitializersPass(unittest.TestCase):
+    @parameterized.parameterized.expand(
+        [
+            (ir.DataType.FLOAT,),
+            (ir.DataType.INT64,),
+        ]
+    )
+    def test_pass_with_lifting_float_and_int_constants_to_initializers(self, ir_dtype):
+        inputs = [
+            ir.Value(name="input_a", type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3))),
+            ir.Value(
+                name="input_b",
+                type=ir.TensorType(ir_dtype),
+                shape=ir.Shape((2, 3)),
+            ),
+        ]
+
+        constant_tensor = ir.tensor(np.random.rand(2, 3).astype(ir_dtype.numpy()))
+        const_node = ir.node(
+            "Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1
+        )
+        add_node = ir.node("Add", inputs=[inputs[0], const_node.outputs[0]])
+        mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]])
+
+        model = ir.Model(
+            graph=ir.Graph(
+                inputs=inputs,
+                outputs=mul_node.outputs,
+                nodes=[const_node, add_node, mul_node],
+                opset_imports={"": 20},
+            ),
+            ir_version=10,
+        )
+
+        # Check that the initializer is not in the graph yet
+        self.assertEqual(len(model.graph.initializers), 0)
+        # And 1 constant node
+        self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1)
+
+        # Perform lift constants to initializers
+        result = constant_manipulation.LiftConstantsToInitializersPass()(model)
+        self.assertTrue(result.modified)
+        # Check that the constant node is lifted to an initializer
+        self.assertEqual(len(result.model.graph.initializers), 1)
+        # Check the value
+        self.assertEqual(
+            result.model.graph.initializers[
+                "val_0"
+            ].const_value,  # name created by name_authority
+            constant_tensor,
+        )
+        # And 0 constant node
+        self.assertEqual(
+            len([node for node in result.model.graph if node.op_type == "Constant"]), 0
+        )
+
+    def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
+        input_value = ir.Value(
+            name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
+        )
+
+        then_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
+        then_const_node = ir.node(
+            "Constant", inputs=[], attributes={"value": then_constant_tensor}, num_outputs=1
+        )
+        # then branch adds the constant to the input
+        # else branch multiplies the input by the constant
+        add_node = ir.node("Add", inputs=[input_value, then_const_node.outputs[0]])
+        then_graph = ir.Graph(
+            inputs=[input_value],
+            outputs=[add_node.outputs[0]],
+            nodes=[then_const_node, add_node],
+            opset_imports={"": 20},
+        )
+        else_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
+        else_const_node = ir.node(
+            "Constant", inputs=[], attributes={"value": else_constant_tensor}, num_outputs=1
+        )
+        mul_node = ir.node("Mul", inputs=[input_value, else_const_node.outputs[0]])
+        else_graph = ir.Graph(
+            inputs=[input_value],
+            outputs=[mul_node.outputs[0]],
+            nodes=[else_const_node, mul_node],
+            opset_imports={"": 20},
+        )
+        # create a conditional node that uses the then and else graphs
+        cond_node = ir.node(
+            "If",
+            inputs=[input_value],
+            attributes={"then_branch": then_graph, "else_branch": else_graph},
+            num_outputs=1,
+        )
+        # Construct the model
+        main_graph = ir.Graph(
+            inputs=[input_value],
+            outputs=cond_node.outputs,
+            nodes=[cond_node],
+            opset_imports={"": 20},
+        )
+        main_graph.sort()
+        model = ir.Model(
+            graph=main_graph,
+            ir_version=10,
+        )
+        result = constant_manipulation.LiftConstantsToInitializersPass()(model)
+        self.assertTrue(result.modified)
+        # Check that the constant node is lifted to the subgraph initializers
+        for node in ir.traversal.RecursiveGraphIterator(result.model.graph):
+            if node.op_type == "Constant":
+                raise AssertionError(
+                    f"Constant node '{node.name}' was not lifted to initializers"
+                )
+        self.assertEqual(len(else_graph.initializers), 1)
+        self.assertEqual(len(then_graph.initializers), 1)
+        self.assertIs(
+            else_graph.initializers["val_0"].const_value,
+            else_constant_tensor,
+        )
+        self.assertIs(
+            then_graph.initializers["val_0"].const_value,
+            then_constant_tensor,
+        )
+
+    @parameterized.parameterized.expand(
+        [
+            (1.0, "value_float", np.float32),
+            (1, "value_int", np.int64),
+            ("hello world!", "value_string", np.bytes_),
+            ([1.0, 2.0, 3.0], "value_floats", np.float32),
+            ([1, 2, 3], "value_ints", np.int64),
+            (["hello world!", "thank you."], "value_strings", np.bytes_),
+        ]
+    )
+    def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings(
+        self, value, constant_attribute, np_dtype
+    ):
+        input_value = ir.Value(
+            name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
+        )
+
+        constant_value = value
+        const_node = ir.node(
+            "Constant",
+            inputs=[],
+            attributes={constant_attribute: constant_value},
+            num_outputs=1,
+        )
+        identity_node_constant = ir.node(
+            "Identity", inputs=[const_node.outputs[0]], num_outputs=1
+        )
+        identity_node_input = ir.node("Identity", inputs=[input_value], num_outputs=1)
+
+        model = ir.Model(
+            graph=ir.Graph(
+                inputs=[input_value],
+                outputs=[identity_node_input.outputs[0], identity_node_constant.outputs[0]],
+                nodes=[identity_node_input, const_node, identity_node_constant],
+                opset_imports={"": 20},
+            ),
+            ir_version=10,
+        )
+
+        # Check that the initializer is not in the graph yet
+        self.assertEqual(len(model.graph.initializers), 0)
+        # And 1 constant node
+        self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1)
+
+        # Perform lift constants to initializers
+        result = constant_manipulation.LiftConstantsToInitializersPass()(model)
+        self.assertTrue(result.modified)
+        # Check that the constant node is lifted to an initializer
+        self.assertEqual(len(result.model.graph.initializers), 1)
+        np.testing.assert_array_equal(
+            result.model.graph.initializers["val_1"].const_value.numpy(),
+            np.array(constant_value, dtype=np_dtype),
+        )
diff --git a/onnxscript/optimizer/_optimizer.py b/onnxscript/optimizer/_optimizer.py
index 4b2ab2223f..9dfeb53da3 100644
--- a/onnxscript/optimizer/_optimizer.py
+++ b/onnxscript/optimizer/_optimizer.py
@@ -4,6 +4,7 @@
 
 import logging
 
+import onnxscript.ir.passes.common.constant_manipulation
 import onnxscript.ir.passes.common.unused_removal
 from onnxscript import ir, rewriter
 from onnxscript.optimizer import _constant_folding, _inliner
@@ -52,6 +53,7 @@ def optimize_ir(
             early_stop=stop_if_no_change,
         ),
         onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
+        onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass(),
     )
     assert optimizer_pass.in_place
     result = optimizer_pass(model)