Skip to content

Commit 8b79816

Browse files
Authored by justinchuby, co-authored by Copilot
authored
Add onnxscript.nn module with Module and Parameter classes (#2819)
This pull request introduces a new PyTorch-like module interface for building ONNX graphs, enabling users to define reusable neural network components and manage parameters in a structured way. The main changes add the `onnxscript.nn` package, expose its API, and implement core classes for module and parameter management. Addition of ONNX neural network module interface: * Added new `onnxscript.nn` package and exposed it in the main `onnxscript` API, allowing users to access neural network module functionality. (`onnxscript/__init__.py`, [[1]](diffhunk://#diff-a1562cf8b37bd59e756ae3802f64fb4b0712845c4dd5a747d7db2a4212f1bfa5R8) [[2]](diffhunk://#diff-a1562cf8b37bd59e756ae3802f64fb4b0712845c4dd5a747d7db2a4212f1bfa5L130-R131)) * Introduced `Module` and `Parameter` classes in `onnxscript/nn/_module.py` and `onnxscript/nn/_parameter.py`, providing a PyTorch-like interface for defining ONNX graph modules, registering parameters, and managing module hierarchies. (`onnxscript/nn/_module.py`, [[1]](diffhunk://#diff-26fd1ef9987845a4613d3d4b2e8d0c6d276b4946ba1a6a3e5e77b6e156e224beR1-R206); `onnxscript/nn/_parameter.py`, [[2]](diffhunk://#diff-94ffcabca05fa631e1902fbd4a2e967c00f9db0dca95baae367daef60cd8a7e1R1-R66)) * Created `onnxscript/nn/__init__.py` to expose `Module` and `Parameter` as the public API of the new package. (`onnxscript/nn/__init__.py`, [onnxscript/nn/__init__.pyR1-R9](diffhunk://#diff-db47492cd4dbb1a882423614592cb391e31abdec6b5aff85d6f1d2e2716507b4R1-R9)) Core module and parameter functionality: * `Module` class supports automatic registration of parameters and child modules, implements methods for iterating over parameters/modules, and provides `state_dict`/`load_state_dict` for parameter serialization/deserialization, mirroring PyTorch's API.
(`onnxscript/nn/_module.py`, [onnxscript/nn/_module.pyR1-R206](diffhunk://#diff-26fd1ef9987845a4613d3d4b2e8d0c6d276b4946ba1a6a3e5e77b6e156e224beR1-R206)) * `Parameter` class subclasses `ir.Value`, allowing direct use in ONNX ops and supporting initialization, realization, and representation of parameter tensors. (`onnxscript/nn/_parameter.py`, [onnxscript/nn/_parameter.pyR1-R66](diffhunk://#diff-94ffcabca05fa631e1902fbd4a2e967c00f9db0dca95baae367daef60cd8a7e1R1-R66)) --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent f69cf91 commit 8b79816

File tree

8 files changed

+1419
-59
lines changed

8 files changed

+1419
-59
lines changed

onnxscript/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"script",
66
"graph",
77
"ir",
8+
"nn",
89
"optimizer",
910
"rewriter",
1011
"version_converter",
@@ -127,7 +128,7 @@
127128

128129
# isort: on
129130

130-
from . import ir, optimizer, rewriter, version_converter
131+
from . import ir, nn, optimizer, rewriter, version_converter
131132
from ._internal.builder import GraphBuilder
132133
from ._internal.utils import external_tensor
133134
from ._internal.values import OnnxFunction, TracedOnnxFunction

onnxscript/_internal/builder.py

Lines changed: 79 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ def __init__(self, graph: ir.Graph) -> None:
8080

8181
self._op_builder = self.opset("", opset_version)
8282

83-
# Context stack to manage hierarchical naming. Each module/layer can push a new context, and pop it when done.
84-
# The current context is used as a prefix for naming values and nodes.
85-
# This allows us to generate names like "layer1.attention.query"
86-
self._context_stack: list[str] = [""]
83+
# Module scope stack. Each entry is (name, class_name) where name is
84+
# the module attribute name (e.g. "layers.0", "self_attn") and
85+
# class_name is the qualified class name (e.g. "Gemma3DecoderLayer").
86+
self._scope_stack: list[tuple[str, str]] = []
8787

8888
# Cache for constant initializers (scalars and sequences), keyed by (value, dtype).
8989
# This avoids creating duplicate initializers for the same constant
@@ -109,7 +109,7 @@ def initializer(
109109
if name is None:
110110
name = tensor.name
111111
if qualify:
112-
name = self.qualify_name(name)
112+
name = self._qualify_initializer_name(name)
113113
shape = ir.Shape(tensor.shape)
114114
value = ir.Value(
115115
name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor
@@ -180,24 +180,26 @@ def _adapt_outputs(
180180
self, outputs: int | Sequence[str | ir.Value], op_type: str = ""
181181
) -> Sequence[ir.Value]:
182182
if isinstance(outputs, int):
183+
count = self.graph.num_nodes()
183184
if outputs < 0:
184185
raise ValueError(f"Number of outputs must be non-negative, got {outputs}")
185186
if outputs == 1:
186-
name = f"{op_type}_output" if op_type else "output"
187-
return [ir.Value(name=self.qualify_name(name))]
187+
name = f"{op_type}_{count}" if op_type else f"{count}"
188+
return [ir.Value(name=self._qualify_value_name(name))]
188189
else:
189190
names = [
190-
f"{op_type}_output{i}" if op_type else f"output{i}" for i in range(outputs)
191+
(f"{op_type}_{count}_{i}" if op_type else f"{count}_{i}")
192+
for i in range(outputs)
191193
]
192-
return [ir.Value(name=self.qualify_name(n)) for n in names]
194+
return [ir.Value(name=self._qualify_value_name(n)) for n in names]
193195
adapted_outputs = []
194196
for output in outputs:
195197
if isinstance(output, ir.Value):
196198
if output.name:
197-
output.name = self.qualify_name(output.name)
199+
output.name = self._qualify_value_name(output.name)
198200
adapted_outputs.append(output)
199201
elif isinstance(output, str):
200-
adapted_outputs.append(ir.Value(name=self.qualify_name(output)))
202+
adapted_outputs.append(ir.Value(name=self._qualify_value_name(output)))
201203
else:
202204
raise TypeError("Output type not supported.")
203205
return adapted_outputs
@@ -304,7 +306,7 @@ def call_op(
304306
outputs = kwargs.pop("_outputs", 1)
305307

306308
count = self.graph.num_nodes()
307-
node_name = self.qualify_name(f"{op_type}_node_{count}")
309+
node_name = self._qualify_node_name(f"{op_type}_node_{count}")
308310

309311
output_values = self._adapt_outputs(outputs, op_type)
310312

@@ -322,33 +324,80 @@ def call_op(
322324
version=version,
323325
name=node_name,
324326
)
327+
328+
# Attach scope metadata to the node
329+
node.metadata_props["namespace"] = self._build_namespace()
330+
node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr(self._scope_classes())
331+
node.metadata_props["pkg.onnxscript.name_scopes"] = repr(self._scope_names())
332+
325333
self.add_node(node)
326334

327335
return node.outputs if len(node.outputs) > 1 else node.outputs[0]
328336

329-
def push_module(self, module: str) -> None:
330-
"""Push a new naming context onto the stack (e.g. a layer or module name)."""
331-
current = self.context_name()
332-
if module:
333-
new_context = f"{current}.{module}" if current else module
334-
else:
335-
new_context = current
336-
self._context_stack.append(new_context)
337+
def push_module(self, module: str, class_name: str = "") -> None:
    """Enter a new module scope by pushing it onto the scope stack.

    Args:
        module: Attribute name of the module (e.g. ``"layers.0"``).
        class_name: Qualified class name (e.g. ``"Gemma3DecoderLayer"``).
    """
    entry = (module, class_name)
    self._scope_stack.append(entry)
337345

338346
def pop_module(self) -> None:
    """Leave the most recently entered module scope.

    Raises:
        RuntimeError: If no module scope is currently active.
    """
    if len(self._scope_stack) == 0:
        raise RuntimeError("Cannot pop_module: no module context has been pushed.")
    del self._scope_stack[-1]
351+
352+
def _scope_names(self) -> list[str]:
353+
"""Return the list of module attribute names in the current scope."""
354+
return [name for name, _ in self._scope_stack]
355+
356+
def _scope_classes(self) -> list[str]:
357+
"""Return the list of class names in the current scope."""
358+
return [cls for _, cls in self._scope_stack]
343359

344-
def context_name(self) -> str:
345-
"""Return the current dot-separated naming context prefix."""
346-
return self._context_stack[-1] if self._context_stack else ""
360+
def _scope_name_parts(self) -> list[str]:
361+
"""Return non-empty module names for qualifying names."""
362+
return [name for name, _ in self._scope_stack if name]
347363

348-
def qualify_name(self, name: str) -> str:
349-
"""Prepend the current hierarchical context prefix to the given name."""
350-
prefix = self.context_name()
351-
return f"{prefix}.{name}" if prefix else name
364+
def _qualify_initializer_name(self, name: str) -> str:
365+
"""Prepend the current hierarchical context prefix to the given name.
366+
367+
Uses ``.`` as separator, appropriate for parameter and initializer names.
368+
"""
369+
parts = self._scope_name_parts()
370+
if parts:
371+
return ".".join(parts) + "." + name
372+
return name
373+
374+
def _qualify_value_name(self, name: str) -> str:
375+
"""Qualify a value name with the current scope using ``.`` separator.
376+
377+
The name is prefixed with ``v_`` to distinguish values from parameters.
378+
"""
379+
parts = self._scope_name_parts()
380+
if parts:
381+
return "v_" + ".".join(parts) + "." + name
382+
return f"v_{name}"
383+
384+
def _qualify_node_name(self, name: str) -> str:
385+
"""Qualify a node name with the current scope using ``/`` separator."""
386+
parts = self._scope_name_parts()
387+
if parts:
388+
return "/".join(parts) + "/" + name
389+
return name
390+
391+
def _build_namespace(self) -> str:
392+
"""Build the namespace string for a node.
393+
394+
Each scope entry is formatted as ``name: class_name`` joined by ``/``.
395+
"""
396+
parts = []
397+
for name, cls in self._scope_stack:
398+
if name or cls:
399+
parts.append(f"{name}: {cls}" if cls else name)
400+
return "/".join(parts)
352401

353402

354403
class OpBuilder:

0 commit comments

Comments
 (0)