Skip to content

Commit 87b338c

Browse files
gramalingam and Copilot
committed
fix: normalize cache key dtype to prevent initializer name collisions
When a Python int literal (e.g. `1`) is used in both untyped positions (like a Gather index, where the ONNX schema type variable is unbound) and typed positions (like Add with an INT64 tensor), the constant cache created two entries: (1, None) and (1, INT64). Both generated the same initializer name 'const_1_i64' but as different ir.Value objects, causing register_initializer to raise ValueError. Fix: before cache lookup, normalize dtype=None to the default ONNX dtype for the Python type (_PYTHON_TYPE_TO_DTYPE: int->INT64, float->FLOAT). This merges both entries into a single cache key and reuses the same ir.Value. Applied to both scalar and sequence (list/tuple) branches. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: G Ramalingam <grama@microsoft.com>
1 parent 13f265c commit 87b338c

2 files changed

Lines changed: 71 additions & 6 deletions

File tree

onnxscript/_internal/builder.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -346,13 +346,16 @@ def _input_to_ir_value(
346346
# For simple scalar/sequence constants, use a cache to avoid duplicate initializers.
347347
# These are shared across layers, so we don't qualify the name with context prefix.
348348
if isinstance(value, (int, float, bool, str)):
349+
# Normalize dtype: when None, use the default ONNX type for the
350+
# Python scalar so that (value, None) and (value, default_dtype)
351+
# share one cache entry and one initializer name.
352+
if dtype is None:
353+
dtype = _PYTHON_TYPE_TO_DTYPE.get(type(value))
349354
cache_key = (value, dtype)
350355
if cache_key in self._constant_cache:
351356
ir_value = self._constant_cache[cache_key]
352357
else:
353-
type_suffix = (
354-
_dtype_suffix(dtype) if dtype is not None else _type_suffix(type(value))
355-
)
358+
type_suffix = _dtype_suffix(dtype) if dtype is not None else ""
356359
name = _constant_name(value, type_suffix, len(self._constant_cache))
357360
tensor = ir.tensor(value, dtype=dtype, name=name)
358361
ir_value = self.initializer(tensor, name=name, qualify=False)
@@ -363,13 +366,14 @@ def _input_to_ir_value(
363366
and all(isinstance(v, type(value[0])) for v in value)
364367
and isinstance(value[0], (int, float, bool, str))
365368
):
369+
# Same normalization for sequences of scalars.
370+
if dtype is None:
371+
dtype = _PYTHON_TYPE_TO_DTYPE.get(type(value[0]))
366372
cache_key = (tuple(value), dtype)
367373
if cache_key in self._constant_cache:
368374
ir_value = self._constant_cache[cache_key]
369375
else:
370-
type_suffix = (
371-
_dtype_suffix(dtype) if dtype is not None else _type_suffix(type(value[0]))
372-
)
376+
type_suffix = _dtype_suffix(dtype) if dtype is not None else ""
373377
name = _constant_name(value, type_suffix, len(self._constant_cache))
374378
tensor = ir.tensor(list(value), dtype=dtype, name=name)
375379
ir_value = self.initializer(tensor, name=name, qualify=False)

onnxscript/_internal/builder_test.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,67 @@ def test_int_constant_with_unknown_type_uses_cast_like(self):
514514
# Add should use the CastLike output, not the raw constant
515515
self.assertIs(add_node.inputs[1], cast_like_node.outputs[0])
516516

517+
def test_int_literal_no_clash_across_typed_and_untyped_contexts(self):
518+
"""Test that the same int literal used in typed and untyped positions
519+
does not cause an initializer name collision.
520+
521+
Regression test: previously, (1, None) and (1, INT64) were separate
522+
cache keys but generated the same name 'const_1_i64', causing
523+
register_initializer to raise ValueError.
524+
"""
525+
graph = ir.Graph(
526+
name="test_model",
527+
inputs=[],
528+
outputs=[],
529+
nodes=[],
530+
opset_imports={"": _default_opset_version},
531+
)
532+
x = ir.Value(name="x", type=ir.TensorType(ir.DataType.INT64), shape=ir.Shape([3]))
533+
graph.inputs.append(x)
534+
535+
graph_builder = builder.GraphBuilder(graph)
536+
op = graph_builder.op
537+
538+
# Gather index: int literal in untyped position (Tind has no binding)
539+
_ = op.Gather(x, 1, axis=0)
540+
# Add: int literal in typed position (T bound to INT64 from x)
541+
_ = op.Add(x, 1)
542+
543+
# Both ops should share the same initializer (same ir.Value object)
544+
gather_node = list(graph)[0]
545+
add_node = list(graph)[1]
546+
self.assertIs(gather_node.inputs[1], add_node.inputs[1])
547+
self.assertEqual(gather_node.inputs[1].name, "const_1_i64")
548+
549+
def test_int_list_no_clash_across_typed_and_untyped_contexts(self):
550+
"""Test that the same int list used in typed and untyped positions
551+
does not cause an initializer name collision (sequence variant)."""
552+
graph = ir.Graph(
553+
name="test_model",
554+
inputs=[],
555+
outputs=[],
556+
nodes=[],
557+
opset_imports={"": _default_opset_version},
558+
)
559+
x = ir.Value(
560+
name="x", type=ir.TensorType(ir.DataType.INT64), shape=ir.Shape([2, 3])
561+
)
562+
graph.inputs.append(x)
563+
564+
graph_builder = builder.GraphBuilder(graph)
565+
op = graph_builder.op
566+
567+
# Reshape target: int list in untyped position
568+
_ = op.Reshape(x, [3, 2])
569+
# Add with a constant tensor of same values in typed position
570+
_ = op.Add(x, [3, 2])
571+
572+
# Should not raise; both should share the same initializer
573+
nodes = list(graph)
574+
reshape_node = nodes[0]
575+
add_node = nodes[1]
576+
self.assertIs(reshape_node.inputs[1], add_node.inputs[1])
577+
517578
def test_pop_module_raises_on_empty_stack(self):
518579
"""Test that pop_module raises RuntimeError when no module has been pushed."""
519580
op, _, _ = _create_builder_with_inputs()

0 commit comments

Comments (0)