Skip to content

Commit c7d13fb

Browse files
authored
Add input() and add_output() methods to GraphBuilder (#2828)
This pull request adds new methods to the `GraphBuilder` class to simplify the creation and management of graph inputs and outputs, and introduces corresponding unit tests to ensure their correct behavior. The changes improve the usability and reliability of the graph-building API. **Enhancements to the GraphBuilder API:** * Added a new `input` method to `GraphBuilder` for creating and registering graph input values with support for specifying name, dtype, shape, type, constant value, and metadata properties. * Added a new `add_output` method to `GraphBuilder` to append an output value to the graph and optionally rename it. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 864b785 commit c7d13fb

File tree

2 files changed

+144
-0
lines changed

2 files changed

+144
-0
lines changed

onnxscript/_internal/builder.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,54 @@ def initializer(
276276
self._graph.register_initializer(value)
277277
return value
278278

279+
def input(
280+
self,
281+
name: str,
282+
dtype: ir.DataType | None = None,
283+
shape: ir.Shape | Sequence[int | str | None] | None = None,
284+
*,
285+
type: ir.TypeProtocol | None = None,
286+
const_value: ir.TensorProtocol | None = None,
287+
metadata_props: dict[str, str] | None = None,
288+
) -> ir.Value:
289+
"""Create an input to the graph and return the corresponding ir.Value.
290+
291+
Args:
292+
name: The name of the value.
293+
dtype: The data type of the TensorType of the value. This is used only when type is None.
294+
shape: The shape of the value.
295+
type: The type of the value. Only one of dtype and type can be specified.
296+
const_value: The constant tensor that initializes the value. Supply this argument
297+
when you want to create an initializer. The type and shape can be obtained from the tensor.
298+
metadata_props: The metadata properties that will be serialized to the ONNX proto.
299+
300+
Returns:
301+
A Value object.
302+
"""
303+
value = ir.val(
304+
name=name,
305+
dtype=dtype,
306+
shape=shape,
307+
type=type,
308+
const_value=const_value,
309+
metadata_props=metadata_props,
310+
)
311+
self._graph.inputs.append(value)
312+
if const_value is not None:
313+
self._graph.register_initializer(value)
314+
return value
315+
316+
def add_output(self, value: ir.Value, name: str | None) -> None:
317+
"""Add an output to the graph.
318+
319+
Args:
320+
value: The ir.Value to add as an output.
321+
name: The name to assign to the output value. If None, no renaming is done.
322+
"""
323+
if name:
324+
value.name = name
325+
self._graph.outputs.append(value)
326+
279327
def _input_to_ir_value(
280328
self, value: VALUE_LIKE, like_type: ir.Value | None = None
281329
) -> ir.Value | None:

onnxscript/_internal/builder_test.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,102 @@ def test_output_names_are_unique_for_same_op_type(self):
545545
names = [t1.name, t2.name, t3.name]
546546
self.assertEqual(len(set(names)), 3)
547547

548+
def test_input_creates_and_registers_graph_input(self):
549+
"""Test that GraphBuilder.input creates and appends a graph input value."""
550+
graph = ir.Graph(
551+
name="test_model",
552+
inputs=[],
553+
outputs=[],
554+
nodes=[],
555+
opset_imports={"": _default_opset_version},
556+
)
557+
graph_builder = builder.GraphBuilder(graph)
558+
559+
value = graph_builder.input("data", dtype=ir.DataType.FLOAT, shape=[2, 3])
560+
561+
self.assertEqual(value.name, "data")
562+
self.assertEqual(value.type.dtype, ir.DataType.FLOAT)
563+
self.assertEqual(list(value.shape), [2, 3])
564+
self.assertEqual(len(graph.inputs), 1)
565+
self.assertIs(graph.inputs[0], value)
566+
567+
def test_input_with_const_value_registers_initializer(self):
568+
"""Test that GraphBuilder.input registers initializer when const_value is provided."""
569+
graph = ir.Graph(
570+
name="test_model",
571+
inputs=[],
572+
outputs=[],
573+
nodes=[],
574+
opset_imports={"": _default_opset_version},
575+
)
576+
graph_builder = builder.GraphBuilder(graph)
577+
578+
const_tensor = ir.tensor([1.0, 2.0], dtype=ir.DataType.FLOAT, name="const_data")
579+
value = graph_builder.input("const_input", const_value=const_tensor)
580+
581+
self.assertEqual(len(graph.inputs), 1)
582+
self.assertIs(graph.inputs[0], value)
583+
self.assertIn("const_input", graph.initializers)
584+
self.assertIs(graph.initializers["const_input"], value)
585+
self.assertIs(value.const_value, const_tensor)
586+
587+
def test_input_without_const_value_does_not_register_initializer(self):
588+
"""Test that GraphBuilder.input does not register initializer without const_value."""
589+
graph = ir.Graph(
590+
name="test_model",
591+
inputs=[],
592+
outputs=[],
593+
nodes=[],
594+
opset_imports={"": _default_opset_version},
595+
)
596+
graph_builder = builder.GraphBuilder(graph)
597+
598+
value = graph_builder.input("regular_input", dtype=ir.DataType.FLOAT, shape=[2])
599+
600+
self.assertEqual(len(graph.inputs), 1)
601+
self.assertIs(graph.inputs[0], value)
602+
self.assertNotIn("regular_input", graph.initializers)
603+
604+
def test_add_output_renames_and_registers_output(self):
605+
"""Test that GraphBuilder.add_output renames (optionally) and appends outputs."""
606+
graph = ir.Graph(
607+
name="test_model",
608+
inputs=[],
609+
outputs=[],
610+
nodes=[],
611+
opset_imports={"": _default_opset_version},
612+
)
613+
graph_builder = builder.GraphBuilder(graph)
614+
615+
output = ir.Value(name="old_name")
616+
graph_builder.add_output(output, "new_name")
617+
618+
self.assertEqual(output.name, "new_name")
619+
self.assertEqual(len(graph.outputs), 1)
620+
self.assertIs(graph.outputs[0], output)
621+
622+
def test_initializer_qualification_behavior(self):
623+
"""Test that GraphBuilder.initializer qualifies names unless explicitly disabled."""
624+
graph = ir.Graph(
625+
name="test_model",
626+
inputs=[],
627+
outputs=[],
628+
nodes=[],
629+
opset_imports={"": _default_opset_version},
630+
)
631+
graph_builder = builder.GraphBuilder(graph)
632+
633+
graph_builder.push_module("layer1")
634+
qualified = graph_builder.initializer(ir.tensor([1.0], name="w"), name="weight")
635+
unqualified = graph_builder.initializer(
636+
ir.tensor([2.0], name="b"), name="bias", qualify=False
637+
)
638+
639+
self.assertEqual(qualified.name, "layer1.weight")
640+
self.assertEqual(unqualified.name, "bias")
641+
self.assertIn("layer1.weight", graph.initializers)
642+
self.assertIn("bias", graph.initializers)
643+
548644
def test_multi_output_names_are_unique(self):
549645
"""Test that multi-output ops produce unique names with counter suffix."""
550646
op, x, y = _create_builder_with_inputs()

0 commit comments

Comments
 (0)