Skip to content

Commit 79c2f39

Browse files
committed
Fix renaming
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent da82e42 commit 79c2f39

1 file changed

Lines changed: 23 additions & 10 deletions

File tree

onnxscript/_internal/builder.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -895,18 +895,19 @@ def call_inline(
895895
graph = function.graph()
896896
else:
897897
raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction")
898-
output_renaming: dict[str, str] = {}
899898
if _outputs is not None:
900899
if len(_outputs) != len(graph.outputs):
901900
raise ValueError(
902901
f"Number of provided output names {_outputs} does not match "
903902
f"number of function outputs {len(graph.outputs)}."
904903
)
905-
for output, name in zip(graph.outputs, _outputs):
906-
output_renaming[output.name] = self._qualify_value_name(name)
904+
# Compute desired output names before pushing prefix scope so they
905+
# are not affected by the prefix.
906+
desired_output_names: list[str] = [
907+
self._qualify_value_name(name) for name in _outputs
908+
]
907909
else:
908-
for output in graph.outputs:
909-
output_renaming[output.name] = self._qualify_value_name(output.name)
910+
desired_output_names = []
910911

911912
if _prefix:
912913
self.push_module(_prefix)
@@ -915,15 +916,27 @@ def call_inline(
915916
node_name_prefix = self._qualify_node_name(f"{function.name}_node_{count}/")
916917
nodes, outputs = _inliner.instantiate(graph, args, kwargs, prefix=node_name_prefix)
917918

919+
# Track final output values so we can rename them separately.
920+
# The inliner prefixes all names, which would prevent name-based lookup
921+
# from matching the original graph output names.
922+
output_value_ids = {id(v) for v in outputs if v is not None}
923+
918924
for node in nodes:
919925
for output in node.outputs:
920-
if output.name:
921-
if output.name in output_renaming:
922-
output.name = output_renaming[output.name]
923-
else:
924-
output.name = self._qualify_value_name(output.name)
926+
if output.name and id(output) not in output_value_ids:
927+
output.name = self._qualify_value_name(output.name)
925928
self.add_node(node)
926929

930+
# Apply names to final output values
931+
if desired_output_names:
932+
for output_val, name in zip(outputs, desired_output_names):
933+
if output_val is not None:
934+
output_val.name = name
935+
else:
936+
for output_val in outputs:
937+
if output_val is not None and output_val.name:
938+
output_val.name = self._qualify_value_name(output_val.name)
939+
927940
if _prefix:
928941
self.pop_module()
929942
return outputs if len(outputs) > 1 else outputs[0]

0 commit comments

Comments
 (0)