Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1924,6 +1924,8 @@ def __init__(
# Be sure the initialize the name authority before extending the nodes
# because it is used to name the nodes and their outputs
self._name_authority = _name_authority.NameAuthority()
# TODO(justinchuby): Trigger again if inputs or initializers are modified.
self._set_input_and_initializer_value_names_into_name_authority()
# Call self.extend not self._nodes.extend so the graph reference is added to the nodes
self.extend(nodes)

Expand Down Expand Up @@ -1999,6 +2001,12 @@ def __iter__(self) -> Iterator[Node]:
def __reversed__(self) -> Iterator[Node]:
return reversed(self._nodes)

def _set_input_and_initializer_value_names_into_name_authority(self):
Comment thread
titaiwangms marked this conversation as resolved.
for value in self.inputs:
self._name_authority.register_or_name_value(value)
for value in self.initializers.values():
self._name_authority.register_or_name_value(value)

def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
"""Set the graph reference for the node and assign names to it and its outputs if they don't have one."""
if node.graph is not None and node.graph is not self:
Expand Down
94 changes: 94 additions & 0 deletions onnxscript/ir/passes/common/constant_manipulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Lift constants to initializers."""

from __future__ import annotations

__all__ = [
"LiftConstantsToInitializersPass",
]

import logging

import numpy as np

from onnxscript import ir

logger = logging.getLogger(__name__)


class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
    """Lift ``Constant`` nodes into initializers of their owning graph.

    Walks the model graph recursively (including subgraphs). Every supported
    Constant node is replaced by an initializer registered on the graph that
    owns the node, all uses are rewired to the initializer, and the node is
    removed.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Convert constant nodes into initializers of the graph they belong to.

        Args:
            model: The model to modify in place.

        Returns:
            A PassResult whose ``modified`` flag is True if any constant was lifted.
        """
        count = 0
        for node in ir.traversal.RecursiveGraphIterator(model.graph):
            # BUG FIX: the official ONNX default domain string is "ai.onnx",
            # not "onnx.ai"; the previous check silently skipped Constant
            # nodes that spell the standard domain explicitly.
            if node.op_type != "Constant" or node.domain not in ("", "ai.onnx"):
                continue

            # Constant nodes must carry exactly one of the value_* attributes.
            constant_node_attribute = set(node.attributes.keys())
            if len(constant_node_attribute) != 1:
                logger.debug(
                    "Invalid constant node '%s' has more than one attribute", node.name
                )
                continue

            attr_name, attr_value = next(iter(node.attributes.items()))
            initializer_name = node.outputs[0].name
            assert initializer_name is not None
            assert isinstance(attr_value, ir.Attr)
            tensor = _constant_node_attribute_to_tensor(
                attr_name, attr_value, initializer_name
            )
            if tensor is None:
                logger.debug(
                    "Invalid constant node '%s' has unsupported attribute value", node.name
                )
                continue

            # Register an initializer with the tensor value
            initializer = ir.Value(
                name=initializer_name,
                shape=tensor.shape,  # type: ignore[arg-type]
                type=ir.TensorType(tensor.dtype),
                const_value=tensor,
            )
            assert node.graph is not None
            node.graph.register_initializer(initializer)
            # Replace the constant node with the initializer
            ir.convenience.replace_all_uses_with(node.outputs[0], initializer)
            node.graph.remove(node, safe=True)
            count += 1
            logger.debug(
                "Converted constant node '%s' to initializer '%s'", node.name, initializer_name
            )
        if count:
            logger.debug("Lifted %s constants to initializers", count)
        return ir.passes.PassResult(model, modified=bool(count))


def _constant_node_attribute_to_tensor(
    attr_name: str, attr_value: ir.Attr, initializer_name: str
) -> ir.Tensor | None:
    """Convert a Constant node attribute into a tensor named *initializer_name*.

    Args:
        attr_name: The attribute key on the Constant node (e.g. ``"value"``).
        attr_value: The attribute itself.
        initializer_name: Name to give the resulting tensor (the Constant
            output's name, so the initializer matches downstream uses).

    Returns:
        The converted tensor, or None when the attribute kind is unsupported
        (e.g. sparse values).
    """
    if attr_name == "value":
        tensor = attr_value.as_tensor()  # type: ignore[union-attr]
        # FIX: rename the lifted tensor to the initializer name for consistency
        # with the other branches, which all pass name=initializer_name.
        tensor.name = initializer_name
    elif attr_name == "value_int":
        tensor = ir.tensor(attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name)
    elif attr_name == "value_ints":
        tensor = ir.tensor(
            attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name
        )
    elif attr_name == "value_float":
        tensor = ir.tensor(
            attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name
        )
    elif attr_name == "value_floats":
        tensor = ir.tensor(
            attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name
        )
    elif attr_name in ("value_string", "value_strings"):
        # Strings are stored as a bytes ndarray wrapped in a StringTensor.
        tensor = ir.StringTensor(
            np.array(attr_value.value, dtype=np.bytes_), name=initializer_name
        )
    else:
        # Unsupported attribute kinds (e.g. sparse_value) are left in place.
        tensor = None
    return tensor  # type: ignore[return-value]
142 changes: 142 additions & 0 deletions onnxscript/ir/passes/common/constant_manipulation_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright (c) Microsoft Corporation.
Comment thread Fixed
Comment thread Fixed
Comment thread Fixed
# Licensed under the MIT License.
from __future__ import annotations
Comment thread Fixed

import unittest

import numpy as np
import parameterized

from onnxscript import ir
from onnxscript.ir.passes.common import constant_manipulation


class TestLiftConstantsToInitializersPass(unittest.TestCase):
    """Tests for LiftConstantsToInitializersPass."""

    @parameterized.parameterized.expand(
        [
            (ir.DataType.FLOAT, np.float32),
            (ir.DataType.INT64, np.int64),
        ]
    )
    def test_pass_with_lifting_constants_to_initializers(self, ir_dtype, numpy_dtype):
        # Graph computes (input_a + constant) * input_b.
        graph_inputs = [
            ir.Value(name="input_a", type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3))),
            ir.Value(
                name="input_b",
                type=ir.TensorType(ir_dtype),
                shape=ir.Shape((2, 3)),
            ),
        ]

        constant_tensor = ir.tensor(np.random.rand(2, 3).astype(numpy_dtype))
        attrs = ir.convenience.convert_attributes({"value": constant_tensor})
        const_node = ir.Node("", "Constant", inputs=[], attributes=attrs, num_outputs=1)
        add_node = ir.Node("", "Add", inputs=[graph_inputs[0], const_node.outputs[0]])
        mul_node = ir.Node("", "Mul", inputs=[add_node.outputs[0], graph_inputs[1]])

        model = ir.Model(
            graph=ir.Graph(
                inputs=graph_inputs,
                outputs=mul_node.outputs,
                nodes=[const_node, add_node, mul_node],
                opset_imports={"": 20},
            ),
            ir_version=10,
        )

        # Before the pass: no initializers and exactly one Constant node.
        self.assertEqual(len(model.graph.initializers), 0)
        self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1)

        # Run the pass.
        result = constant_manipulation.LiftConstantsToInitializersPass()(model)
        self.assertTrue(result.modified)
        # The constant was lifted into a single initializer.
        self.assertEqual(len(result.model.graph.initializers), 1)
        # "val_0" is the name assigned by the graph's name authority.
        self.assertEqual(
            result.model.graph.initializers["val_0"].const_value,
            constant_tensor,
        )
        # The Constant node itself is gone.
        self.assertEqual(
            len([node for node in result.model.graph if node.op_type == "Constant"]), 0
        )

    def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
        input_value = ir.Value(
            name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
        )

        # then branch adds a constant to the input.
        then_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
        then_attrs = ir.convenience.convert_attributes({"value": then_constant_tensor})
        then_const_node = ir.Node(
            "", "Constant", inputs=[], attributes=then_attrs, num_outputs=1
        )
        add_node = ir.Node("", "Add", inputs=[input_value, then_const_node.outputs[0]])
        then_graph = ir.Graph(
            inputs=[input_value],
            outputs=[add_node.outputs[0]],
            nodes=[then_const_node, add_node],
            opset_imports={"": 20},
        )

        # else branch multiplies the input by a different constant.
        else_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
        else_attrs = ir.convenience.convert_attributes({"value": else_constant_tensor})
        else_const_node = ir.Node(
            "", "Constant", inputs=[], attributes=else_attrs, num_outputs=1
        )
        mul_node = ir.Node("", "Mul", inputs=[input_value, else_const_node.outputs[0]])
        else_graph = ir.Graph(
            inputs=[input_value],
            outputs=[mul_node.outputs[0]],
            nodes=[else_const_node, mul_node],
            opset_imports={"": 20},
        )

        # Conditional node dispatching between the two branches.
        if_attrs = ir.convenience.convert_attributes(
            {"then_branch": then_graph, "else_branch": else_graph}
        )
        cond_node = ir.Node(
            "",
            "If",
            inputs=[input_value],
            attributes=if_attrs,
            num_outputs=1,
        )

        # Construct the model around the conditional.
        main_graph = ir.Graph(
            inputs=[input_value],
            outputs=cond_node.outputs,
            nodes=[cond_node],
            opset_imports={"": 20},
        )
        main_graph.sort()
        model = ir.Model(
            graph=main_graph,
            ir_version=10,
        )

        result = constant_manipulation.LiftConstantsToInitializersPass()(model)
        self.assertTrue(result.modified)
        # Each subgraph's constant must now live in that subgraph's initializers.
        for node in ir.traversal.RecursiveGraphIterator(result.model.graph):
            if node.op_type == "Constant":
                raise AssertionError(
                    f"Constant node '{node.name}' was not lifted to initializers"
                )
            if node.op_type == "Add":
                self.assertEqual(len(node.graph.initializers), 1)
                self.assertEqual(
                    node.graph.initializers["val_0"].const_value,
                    then_constant_tensor,
                )
            if node.op_type == "Mul":
                self.assertEqual(len(node.graph.initializers), 1)
                self.assertEqual(
                    node.graph.initializers["val_0"].const_value,
                    else_constant_tensor,
                )
Comment thread
justinchuby marked this conversation as resolved.
Outdated
2 changes: 2 additions & 0 deletions onnxscript/optimizer/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import logging

import onnxscript.ir.passes.common.constant_manipulation
import onnxscript.ir.passes.common.unused_removal
import onnxscript.optimizer
from onnxscript import ir, rewriter
Expand Down Expand Up @@ -70,6 +71,7 @@ def optimize_ir(
early_stop=stop_if_no_change,
),
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass(),
)
assert optimizer_pass.in_place
result = optimizer_pass(model)
Expand Down
Loading