Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions onnxscript/rewriter/_rewrite_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import onnxscript.rewriter._ir_utils as _ir_utils
import onnxscript.rewriter._matcher as _matcher
import onnxscript.rewriter._pattern_ir as _pattern_ir
import onnxscript.utils.metadata_merger as metadata_merger
from onnxscript import ir
from onnxscript.ir import _tape, convenience

Expand Down Expand Up @@ -614,6 +615,14 @@ def _get_new_overload(model: ir.Model, domain: str, name: str) -> str:
overload += 1


# TODO(rama): Make this user-configurable. Perhaps allowing RewriteRuleSet to accept
# a MetadataMerger object should be sufficient. Using None will avoid expensive
# metadata merging when efficiency is important.
_default_metadata_merger: metadata_merger.MetadataMerger | None = (
metadata_merger.MetadataMerger({RULE_NAME_TAG: metadata_merger.comma_separator_merger})
)


class RewriteRuleSet:
def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None:
if not rules:
Expand Down Expand Up @@ -740,6 +749,10 @@ def _apply_to_graph_or_function(
delta.new_outputs,
)

merger = _default_metadata_merger
if merger is not None:
merger.copy_merged_metadata(delta.match.nodes, delta.new_nodes)

count += 1
break

Expand Down
99 changes: 99 additions & 0 deletions onnxscript/utils/metadata_merger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Merging metadata_props"""

from __future__ import annotations

from typing import Callable, Iterable

import onnx_ir as ir

# Utilities for merging metadata properties, represented as strings.
# The merging-logic will take care of special cases like missing metadata or
# empty string metadata, and so the functions defined below need not handle
# special cases like empty string. (This does assume that an empty string is
# the same as no metadata, which is a reasonable assumption for most metadata.)

StringMerger = Callable[[str, str], str]


def overwrite(_: str, new: str) -> str:
return new


def join(separator: str) -> StringMerger:
"""Creates a StringMerger that joins two strings with the given separator.

Args:
separator (str): The separator to use when joining the strings.

Returns:
StringMerger: A function that joins two strings with the specified separator.
"""

def merger(first: str, second: str) -> str:
return f"{first}{separator}{second}"

return merger


comma_separator_merger = join(", ")


class MetadataMerger:
"""Merges metadata properties using specified merging logic.

Attributes:
mergers: A mapping from metadata property keys to their corresponding merging functions.
default: The default merging function to use when a specific key does not have a defined merger.
If None, the first value is used. (Specify `overwrite` to always use the second value.)
"""

def __init__(
self, mergers: dict[str, StringMerger], default: StringMerger | None = None
) -> None:
self.mergers = mergers
self.default = default

def update_dict(self, updated: dict[str, str], updates: dict[str, str]) -> None:
"""Updates the first metadata property dictionary with values from the second.

Args:
updated: The metadata dictionary to be updated.
updates: The updates metadata dictionary.
"""
for key, new_value in updates.items():
if new_value == "":
continue
if (key in updated) and ((updated_value := updated[key]) != ""):
merger = self.mergers.get(key, self.default)
if merger is not None:
updated[key] = merger(updated_value, new_value)
else:
updated[key] = new_value

def copy_merged_metadata(
Comment thread Fixed
self, from_nodes: Iterable[ir.Node], to: ir.Node | Iterable[ir.Node]
) -> None:
"""Merges metadata from multiple nodes and assigns it to a target node.

Args:
from_nodes: The source nodes from which to merge metadata.
to_node: The target node to which the merged metadata will be assigned.
"""
if isinstance(to, ir.Node):
updated = to.metadata_props
for node in from_nodes:
self.update_dict(updated, node.metadata_props)
elif len(to) == 1:
# Handle single node in iterable case
target_node = next(iter(to))
updated = target_node.metadata_props
for node in from_nodes:
self.update_dict(updated, node.metadata_props)
else:
merged_metadata: dict[str, str] = {}
for node in from_nodes:
self.update_dict(merged_metadata, node.metadata_props)
for target_node in to:
self.update_dict(target_node.metadata_props, merged_metadata)
Loading