Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1924,6 +1924,10 @@ def __init__(
# Be sure to initialize the name authority before extending the nodes
# because it is used to name the nodes and their outputs
self._name_authority = _name_authority.NameAuthority()
# NOTE: user-provided input and initializer value names could collide with
# auto-generated value names from the name authority and crash ONNX Runtime
# https://github.com/microsoft/onnxruntime/blob/bc7b07dbb41a2f441dbed1a91855563ba0dd8a31/onnxruntime/core/graph/graph.cc#L1536
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
self._set_input_and_initializer_value_names_into_name_authority()
# Call self.extend not self._nodes.extend so the graph reference is added to the nodes
self.extend(nodes)

Expand Down Expand Up @@ -1999,6 +2003,12 @@ def __iter__(self) -> Iterator[Node]:
def __reversed__(self) -> Iterator[Node]:
return reversed(self._nodes)

def _set_input_and_initializer_value_names_into_name_authority(self):
    """Pre-register graph input and initializer names with the name authority.

    Registering these user-provided names up front prevents the name
    authority from later auto-generating a colliding name for a node output.
    """
    for existing_value in (*self.inputs, *self.initializers.values()):
        self._name_authority.register_or_name_value(existing_value)

def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
"""Set the graph reference for the node and assign names to it and its outputs if they don't have one."""
if node.graph is not None and node.graph is not self:
Expand Down
52 changes: 52 additions & 0 deletions onnxscript/ir/passes/common/lift_constants_to_initializers.py
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Lift constants to initializers."""

from __future__ import annotations

__all__ = [
"LiftConstantsToInitializersPass",
]

import logging

from onnxscript import ir

logger = logging.getLogger(__name__)


class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
    """Convert ``Constant`` nodes in the main graph to graph initializers.

    Only Constant nodes carrying a ``value`` attribute are lifted; the other
    ``value_*`` attribute forms are left untouched. Nodes whose output name
    already exists as an initializer are skipped to avoid clobbering it.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Lift constant nodes in the main graph to initializers.

        Returns:
            A ``PassResult`` whose ``modified`` flag is True when at least
            one constant node was lifted.
        """
        count = 0
        for node in model.graph:
            # The standard ONNX domain is "" or "ai.onnx" ("onnx.ai" is wrong
            # and would have rejected valid Constant nodes).
            if node.op_type != "Constant" or node.domain not in ("", "ai.onnx"):
                continue
            if "value" not in node.attributes:
                logger.debug("Constant node '%s' has no 'value' attribute", node.name)
                continue
            # The attribute can only be ir.Attr here, as ir.RefAttr is only
            # defined inside Functions.
            tensor = node.attributes["value"].as_tensor()  # type: ignore[union-attr]
            initializer_name = node.outputs[0].name
            assert initializer_name is not None
            if initializer_name in model.graph.initializers:
                # Do not silently overwrite an existing initializer.
                logger.warning(
                    "Initializer '%s' already exists; skipping constant node '%s'",
                    initializer_name,
                    node.name,
                )
                continue
            # Register an initializer holding the tensor value.
            initializer = ir.Value(
                name=initializer_name,
                shape=tensor.shape,  # type: ignore[arg-type]
                type=ir.TensorType(tensor.dtype),
                const_value=tensor,
            )
            model.graph.initializers[initializer_name] = initializer
            # Rewire all consumers to the initializer, then drop the node.
            ir.convenience.replace_all_uses_with(node.outputs[0], initializer)
            model.graph.remove(node, safe=True)
            count += 1
            # Per-node messages are debug-level to keep INFO output compact.
            logger.debug(
                "Converted constant node '%s' to initializer '%s'", node.name, initializer_name
            )
        if count:
            logger.info("Lifted %s constants to initializers", count)
        return ir.passes.PassResult(model, modified=bool(count))
55 changes: 55 additions & 0 deletions onnxscript/ir/passes/common/lift_constants_to_initializers_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import numpy as np

from onnxscript import ir
from onnxscript.ir.passes.common import lift_constants_to_initializers


class TestLiftConstantsToInitializersPass(unittest.TestCase):
    def test_pass_with_lifting_constants_to_initializers(self):
        """A Constant node feeding Add/Mul is lifted into a graph initializer."""
        input_a = ir.Value(
            name="input_a",
            type=ir.TensorType(ir.DataType.FLOAT),
            shape=ir.Shape((2, 3)),
        )
        input_b = ir.Value(
            name="input_b",
            type=ir.TensorType(ir.DataType.FLOAT),
            shape=ir.Shape((2, 3)),
        )

        const_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
        const_node = ir.Node(
            "",
            "Constant",
            inputs=[],
            attributes=ir.convenience.convert_attributes({"value": const_tensor}),
            num_outputs=1,
        )
        add_node = ir.Node("", "Add", inputs=[input_a, const_node.outputs[0]])
        mul_node = ir.Node("", "Mul", inputs=[add_node.outputs[0], input_b])

        model = ir.Model(
            graph=ir.Graph(
                inputs=[input_a, input_b],
                outputs=mul_node.outputs,
                nodes=[const_node, add_node, mul_node],
                opset_imports={"": 20},
            ),
            ir_version=10,
        )

        # Before the pass: no initializers, exactly one Constant node.
        self.assertEqual(len(model.graph.initializers), 0)
        self.assertEqual(
            len([node for node in model.graph if node.op_type == "Constant"]), 1
        )

        result = lift_constants_to_initializers.LiftConstantsToInitializersPass()(model)

        # After the pass: one initializer, zero Constant nodes.
        self.assertTrue(result.modified)
        self.assertEqual(len(result.model.graph.initializers), 1)
        self.assertEqual(
            sum(1 for node in result.model.graph if node.op_type == "Constant"), 0
        )
2 changes: 2 additions & 0 deletions onnxscript/optimizer/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import logging

import onnxscript.ir.passes.common.lift_constants_to_initializers
import onnxscript.ir.passes.common.unused_removal
import onnxscript.optimizer
from onnxscript import ir, rewriter
Expand Down Expand Up @@ -70,6 +71,7 @@ def optimize_ir(
early_stop=stop_if_no_change,
),
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
onnxscript.ir.passes.common.lift_constants_to_initializers.LiftConstantsToInitializersPass(),
Comment thread Fixed
)
assert optimizer_pass.in_place
result = optimizer_pass(model)
Expand Down
Loading