Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1924,6 +1924,8 @@ def __init__(
# Be sure the initialize the name authority before extending the nodes
# because it is used to name the nodes and their outputs
self._name_authority = _name_authority.NameAuthority()
# TODO(justinchuby): Trigger again if inputs or initializers are modified.
self._set_input_and_initializer_value_names_into_name_authority()
# Call self.extend not self._nodes.extend so the graph reference is added to the nodes
self.extend(nodes)

Expand Down Expand Up @@ -1999,6 +2001,12 @@ def __iter__(self) -> Iterator[Node]:
def __reversed__(self) -> Iterator[Node]:
return reversed(self._nodes)

def _set_input_and_initializer_value_names_into_name_authority(self):
Comment thread
titaiwangms marked this conversation as resolved.
for value in self.inputs:
self._name_authority.register_or_name_value(value)
for value in self.initializers.values():
self._name_authority.register_or_name_value(value)

def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
"""Set the graph reference for the node and assign names to it and its outputs if they don't have one."""
if node.graph is not None and node.graph is not self:
Expand Down
52 changes: 52 additions & 0 deletions onnxscript/ir/passes/common/constant_manipulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Lift constants to initializers."""

from __future__ import annotations

__all__ = [
"LiftConstantsToInitializersPass",
]

import logging

from onnxscript import ir

logger = logging.getLogger(__name__)


class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
def call(self, model: ir.Model) -> ir.passes.PassResult:
"""Convert constant nodes in main graph to initializers."""
count = 0
for node in model.graph:
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"):
continue
if "value" not in node.attributes:
logger.debug("Constant node '%s' has no 'value' attribute", node.name)
continue
# The value of attribute can only be ir.Attr, as
# ir.RefAttr is only defined in Functions.
tensor = node.attributes["value"].as_tensor() # type: ignore[union-attr]
# Register an initializer with the tensor value
initializer_name = node.outputs[0].name
assert initializer_name is not None
initializer = ir.Value(
name=initializer_name,
shape=tensor.shape, # type: ignore[arg-type]
type=ir.TensorType(tensor.dtype),
const_value=tensor,
)
# TODO(titaiwang): Is it possible that the initializer name has
# been taken?
model.graph.register_initializer(initializer)
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
# Replace the constant node with the initilizer
ir.convenience.replace_all_uses_with(node.outputs[0], initializer)
model.graph.remove(node, safe=True)
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
count += 1
logger.info(
"Converted constant node '%s' to initializer '%s'", node.name, initializer_name
)
if count:
logger.info("Lifted %s constants to initializers", count)
return ir.passes.PassResult(model, modified=bool(count))
61 changes: 61 additions & 0 deletions onnxscript/ir/passes/common/constant_manipulation_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) Microsoft Corporation.
Comment thread Fixed
Comment thread Fixed
Comment thread Fixed
# Licensed under the MIT License.
from __future__ import annotations
Comment thread Fixed

import unittest

import numpy as np

from onnxscript import ir
from onnxscript.ir.passes.common import constant_manipulation

import onnx
Comment thread Fixed
Comment thread Fixed
Comment thread Fixed
Comment thread Fixed

class TestLiftConstantsToInitializersPass(unittest.TestCase):
def test_pass_with_lifting_constants_to_initializers(self):
inputs = [
ir.Value(
name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
),
ir.Value(
name="input_b",
type=ir.TensorType(ir.DataType.FLOAT),
shape=ir.Shape((2, 3)),
),
]

constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
attribute = ir.convenience.convert_attributes({"value": constant_tensor})
const_node = ir.Node("", "Constant", inputs=[], attributes=attribute, num_outputs=1)
add_node = ir.Node("", "Add", inputs=[inputs[0], const_node.outputs[0]])
mul_node = ir.Node("", "Mul", inputs=[add_node.outputs[0], inputs[1]])

model = ir.Model(
graph=ir.Graph(
inputs=inputs,
outputs=mul_node.outputs,
nodes=[const_node, add_node, mul_node],
opset_imports={"": 20},
),
ir_version=10,
)

# Check that the initializer is not in the graph yet
self.assertEqual(len(model.graph.initializers), 0)
# And 1 constant node
self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1)

# Perform lift constants to initializers
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
self.assertTrue(result.modified)
# Check that the constant node is lifted to an initializer
self.assertEqual(len(result.model.graph.initializers), 1)
# Check the value
self.assertEqual(
result.model.graph.initializers["val_0"].const_value, # name created by name_authority
constant_tensor,
)
# And 0 constant node
self.assertEqual(
len([node for node in result.model.graph if node.op_type == "Constant"]), 0
)
Comment thread Fixed
2 changes: 2 additions & 0 deletions onnxscript/optimizer/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import logging

import onnxscript.ir.passes.common.constant_manipulation
import onnxscript.ir.passes.common.unused_removal
import onnxscript.optimizer
from onnxscript import ir, rewriter
Expand Down Expand Up @@ -70,6 +71,7 @@ def optimize_ir(
early_stop=stop_if_no_change,
),
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass(),
)
assert optimizer_pass.in_place
result = optimizer_pass(model)
Expand Down
Loading