Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions onnxscript/_internal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,16 @@ def _input_to_ir_value(
# For simple scalar/sequence constants, use a cache to avoid duplicate initializers.
# These are shared across layers, so we don't qualify the name with context prefix.
if isinstance(value, (int, float, bool, str)):
# Normalize dtype: when None, use the default ONNX type for the
# Python scalar so that (value, None) and (value, default_dtype)
# share one cache entry and one initializer name.
if dtype is None:
dtype = _PYTHON_TYPE_TO_DTYPE.get(type(value))
cache_key = (value, dtype)
if cache_key in self._constant_cache:
ir_value = self._constant_cache[cache_key]
else:
type_suffix = (
_dtype_suffix(dtype) if dtype is not None else _type_suffix(type(value))
)
type_suffix = _dtype_suffix(dtype) if dtype is not None else ""
name = _constant_name(value, type_suffix, len(self._constant_cache))
tensor = ir.tensor(value, dtype=dtype, name=name)
ir_value = self.initializer(tensor, name=name, qualify=False)
Expand All @@ -363,13 +366,14 @@ def _input_to_ir_value(
and all(isinstance(v, type(value[0])) for v in value)
and isinstance(value[0], (int, float, bool, str))
):
# Same normalization for sequences of scalars.
if dtype is None:
dtype = _PYTHON_TYPE_TO_DTYPE.get(type(value[0]))
cache_key = (tuple(value), dtype)
if cache_key in self._constant_cache:
ir_value = self._constant_cache[cache_key]
else:
type_suffix = (
_dtype_suffix(dtype) if dtype is not None else _type_suffix(type(value[0]))
)
type_suffix = _dtype_suffix(dtype) if dtype is not None else ""
name = _constant_name(value, type_suffix, len(self._constant_cache))
tensor = ir.tensor(list(value), dtype=dtype, name=name)
ir_value = self.initializer(tensor, name=name, qualify=False)
Expand Down
60 changes: 60 additions & 0 deletions onnxscript/_internal/builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,66 @@ def test_int_constant_with_unknown_type_uses_cast_like(self):
# Add should use the CastLike output, not the raw constant
self.assertIs(add_node.inputs[1], cast_like_node.outputs[0])

def test_int_literal_no_clash_across_typed_and_untyped_contexts(self):
    """Verify that one int literal reused in typed and untyped positions
    yields a single shared initializer rather than a name collision.

    Regression test: previously, (1, None) and (1, INT64) were separate
    cache keys but generated the same name 'const_1_i64', causing
    register_initializer to raise ValueError.
    """
    # Build an empty graph with a single INT64 input of shape [3].
    test_graph = ir.Graph(
        name="test_model",
        inputs=[],
        outputs=[],
        nodes=[],
        opset_imports={"": _default_opset_version},
    )
    input_value = ir.Value(
        name="x", type=ir.TensorType(ir.DataType.INT64), shape=ir.Shape([3])
    )
    test_graph.inputs.append(input_value)

    op = builder.GraphBuilder(test_graph).op

    # Gather index: int literal in untyped position (Tind has no binding).
    _ = op.Gather(input_value, 1, axis=0)
    # Add: int literal in typed position (T bound to INT64 from the input).
    _ = op.Add(input_value, 1)

    # Both nodes must reference the very same initializer ir.Value.
    shared_const = test_graph.node(0).inputs[1]
    self.assertIs(shared_const, test_graph.node(1).inputs[1])
    self.assertEqual(shared_const.name, "const_1_i64")

def test_int_list_no_clash_across_typed_and_untyped_contexts(self):
    """Test that the same int list used in typed and untyped positions
    does not cause an initializer name collision (sequence variant).
    """
    # Empty graph with a single INT64 input of shape [2, 3].
    graph = ir.Graph(
        name="test_model",
        inputs=[],
        outputs=[],
        nodes=[],
        opset_imports={"": _default_opset_version},
    )
    x = ir.Value(name="x", type=ir.TensorType(ir.DataType.INT64), shape=ir.Shape([2, 3]))
    graph.inputs.append(x)

    graph_builder = builder.GraphBuilder(graph)
    op = graph_builder.op

    # Reshape target: int list in untyped position
    _ = op.Reshape(x, [3, 2])
    # Add with a constant tensor of same values in typed position
    _ = op.Add(x, [3, 2])

    # Should not raise; both should share the same initializer
    nodes = list(graph)
    reshape_node = nodes[0]
    add_node = nodes[1]
    self.assertIs(reshape_node.inputs[1], add_node.inputs[1])

def test_pop_module_raises_on_empty_stack(self):
"""Test that pop_module raises RuntimeError when no module has been pushed."""
op, _, _ = _create_builder_with_inputs()
Expand Down
Loading