Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1924,6 +1924,8 @@ def __init__(
# Be sure the initialize the name authority before extending the nodes
# because it is used to name the nodes and their outputs
self._name_authority = _name_authority.NameAuthority()
# TODO(justinchuby): Trigger again if inputs or initializers are modified.
self._set_input_and_initializer_value_names_into_name_authority()
# Call self.extend not self._nodes.extend so the graph reference is added to the nodes
self.extend(nodes)

Expand Down Expand Up @@ -1999,6 +2001,12 @@ def __iter__(self) -> Iterator[Node]:
def __reversed__(self) -> Iterator[Node]:
return reversed(self._nodes)

def _set_input_and_initializer_value_names_into_name_authority(self):
Comment thread
titaiwangms marked this conversation as resolved.
for value in self.inputs:
self._name_authority.register_or_name_value(value)
for value in self.initializers.values():
self._name_authority.register_or_name_value(value)

def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
"""Set the graph reference for the node and assign names to it and its outputs if they don't have one."""
if node.graph is not None and node.graph is not self:
Expand Down
100 changes: 100 additions & 0 deletions onnxscript/ir/passes/common/constant_manipulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Lift constants to initializers."""

from __future__ import annotations

__all__ = [
"LiftConstantsToInitializersPass",
]

import logging

import numpy as np

from onnxscript import ir

logger = logging.getLogger(__name__)


class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
def call(self, model: ir.Model) -> ir.passes.PassResult:
"""Convert constant nodes in main graph to initializers."""
count = 0
for node in model.graph:
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"):
continue

allowed_constant_attributes = {
"value",
"value_int",
"value_ints",
"value_float",
"value_floats",
"value_string",
"value_strings",
}
constant_node_attribute = set(node.attributes.keys())
if len(constant_node_attribute) != 1:
logger.debug(

Check warning on line 39 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L39

Added line #L39 was not covered by tests
"Invalid constant node '%s' has more than one attribute", node.name
)
continue

Check warning on line 42 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L42

Added line #L42 was not covered by tests
if constant_node_attribute not in allowed_constant_attributes:
logger.debug("Invalid constant node '%s' has unsupported attribute", node.name)
continue

initializer_name = node.outputs[0].name
assert initializer_name is not None

Check warning on line 48 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L47-L48

Added lines #L47 - L48 were not covered by tests
# The value of attribute can only be ir.Attr, as
# ir.RefAttr is only defined in Functions.
attr_value = node.attributes[constant_node_attribute]

Check warning on line 51 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L51

Added line #L51 was not covered by tests
Comment thread Fixed
if constant_node_attribute == "value":
tensor = attr_value.as_tensor() # type: ignore[union-attr]

Check warning on line 53 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L53

Added line #L53 was not covered by tests
elif constant_node_attribute == "value_int":
tensor = ir.Tensor(

Check warning on line 55 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L55

Added line #L55 was not covered by tests
np.array(attr_value.as_int(), dtype=np.int64), name=initializer_name
Comment thread Fixed
)
elif constant_node_attribute == "value_ints":
tensor = ir.Tensor(

Check warning on line 59 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L59

Added line #L59 was not covered by tests
np.array(attr_value.as_ints(), dtype=np.int64), name=initializer_name
Comment thread Fixed
)
elif constant_node_attribute == "value_float":
tensor = ir.Tensor(

Check warning on line 63 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L63

Added line #L63 was not covered by tests
np.array(attr_value.as_float(), dtype=np.float32), name=initializer_name
Comment thread Fixed
)
elif constant_node_attribute == "value_floats":
tensor = ir.Tensor(

Check warning on line 67 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L67

Added line #L67 was not covered by tests
np.array(attr_value.as_floats(), dtype=np.float32), name=initializer_name
Comment thread Fixed
)
elif constant_node_attribute == "value_string":
tensor = ir.Tensor(

Check warning on line 71 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L71

Added line #L71 was not covered by tests
np.array(attr_value.as_string(), dtype=np.object_), name=initializer_name
Comment thread Fixed
)
elif constant_node_attribute == "value_strings":
tensor = ir.Tensor(

Check warning on line 75 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L75

Added line #L75 was not covered by tests
np.array(attr_value.as_strings(), dtype=np.object_), name=initializer_name
Comment thread Fixed
)
else:
logger.debug("Invalid constant node '%s' has unsupported attribute", node.name)
continue

Check warning on line 80 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L79-L80

Added lines #L79 - L80 were not covered by tests
# Register an initializer with the tensor value
initializer = ir.Value(

Check warning on line 82 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L82

Added line #L82 was not covered by tests
name=initializer_name,
shape=tensor.shape, # type: ignore[arg-type]
type=ir.TensorType(tensor.dtype),
const_value=tensor,
)
# TODO(titaiwang): Is it possible that the initializer name has
# been taken?
model.graph.register_initializer(initializer)

Check warning on line 90 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L90

Added line #L90 was not covered by tests
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
# Replace the constant node with the initilizer
ir.convenience.replace_all_uses_with(node.outputs[0], initializer)
model.graph.remove(node, safe=True)
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
count += 1
logger.info(

Check warning on line 95 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L92-L95

Added lines #L92 - L95 were not covered by tests
"Converted constant node '%s' to initializer '%s'", node.name, initializer_name
)
if count:
logger.info("Lifted %s constants to initializers", count)

Check warning on line 99 in onnxscript/ir/passes/common/constant_manipulation.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation.py#L99

Added line #L99 was not covered by tests
return ir.passes.PassResult(model, modified=bool(count))
62 changes: 62 additions & 0 deletions onnxscript/ir/passes/common/constant_manipulation_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) Microsoft Corporation.
Comment thread Fixed
Comment thread Fixed
Comment thread Fixed
# Licensed under the MIT License.
from __future__ import annotations
Comment thread Fixed

import unittest

import numpy as np

from onnxscript import ir
from onnxscript.ir.passes.common import constant_manipulation


class TestLiftConstantsToInitializersPass(unittest.TestCase):
def test_pass_with_lifting_constants_to_initializers(self):
inputs = [
ir.Value(
name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
),
ir.Value(
name="input_b",
type=ir.TensorType(ir.DataType.FLOAT),
shape=ir.Shape((2, 3)),
),
]

constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
attribute = ir.convenience.convert_attributes({"value": constant_tensor})
const_node = ir.Node("", "Constant", inputs=[], attributes=attribute, num_outputs=1)
add_node = ir.Node("", "Add", inputs=[inputs[0], const_node.outputs[0]])
mul_node = ir.Node("", "Mul", inputs=[add_node.outputs[0], inputs[1]])

model = ir.Model(
graph=ir.Graph(
inputs=inputs,
outputs=mul_node.outputs,
nodes=[const_node, add_node, mul_node],
opset_imports={"": 20},
),
ir_version=10,
)

# Check that the initializer is not in the graph yet
self.assertEqual(len(model.graph.initializers), 0)
# And 1 constant node
self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1)

# Perform lift constants to initializers
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
self.assertTrue(result.modified)
# Check that the constant node is lifted to an initializer
self.assertEqual(len(result.model.graph.initializers), 1)

Check warning on line 51 in onnxscript/ir/passes/common/constant_manipulation_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation_test.py#L51

Added line #L51 was not covered by tests
# Check the value
self.assertEqual(

Check warning on line 53 in onnxscript/ir/passes/common/constant_manipulation_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation_test.py#L53

Added line #L53 was not covered by tests
result.model.graph.initializers[
"val_0"
].const_value, # name created by name_authority
constant_tensor,
)
# And 0 constant node
self.assertEqual(

Check warning on line 60 in onnxscript/ir/passes/common/constant_manipulation_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/constant_manipulation_test.py#L60

Added line #L60 was not covered by tests
len([node for node in result.model.graph if node.op_type == "Constant"]), 0
)
2 changes: 2 additions & 0 deletions onnxscript/optimizer/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import logging

import onnxscript.ir.passes.common.constant_manipulation
import onnxscript.ir.passes.common.unused_removal
import onnxscript.optimizer
from onnxscript import ir, rewriter
Expand Down Expand Up @@ -70,6 +71,7 @@ def optimize_ir(
early_stop=stop_if_no_change,
),
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass(),
)
assert optimizer_pass.in_place
result = optimizer_pass(model)
Expand Down
Loading