Skip to content

Commit 1fdfb1b

Browse files
Copilotjustinchuby
andauthored
Update nn.Sequential signature to match PyTorch (*modules varargs) (#2827)
`nn.Sequential.__init__` accepted a `list[Module] | None` parameter, diverging from PyTorch's `torch.nn.Sequential` which takes variadic positional args. ## Changes - **`nn/_sequential.py`**: Added `__init__(self, *modules: Module)` override that accepts variadic args and delegates to `ModuleList.__init__`; updated docstring example. - **`nn/_module_test.py`**: Migrated all `Sequential([...])` call sites to `Sequential(...)`. ## Before / After ```python # Before seq = Sequential([SiLU(), Linear(4, 4)]) seq_empty = Sequential([]) # After — matches torch.nn.Sequential seq = Sequential(SiLU(), Linear(4, 4)) seq_empty = Sequential() ``` <!-- START COPILOT ORIGINAL PROMPT --> <details> <summary>Original prompt</summary> > > ---- > > *This section details on the original issue you should resolve* > > <issue_title>Update nn.Sequential signature to match pytorch</issue_title> > <issue_description>`def __init__(self, modules: list[Module] | None = None) -> None` should become `def __init__(self, *modules: Module = None) -> None:` to match pytorch signature.</issue_description> > > ## Comments on the Issue (you are @copilot in this section) > > <comments> > </comments> > </details> <!-- START COPILOT CODING AGENT SUFFIX --> - Fixes #2826 <!-- START COPILOT CODING AGENT TIPS --> --- ✨ Let Copilot coding agent [set things up for you](https://github.com/microsoft/onnxscript/issues/new?title=✨+Set+up+Copilot+instructions&body=Configure%20instructions%20for%20this%20repository%20as%20documented%20in%20%5BBest%20practices%20for%20Copilot%20coding%20agent%20in%20your%20repository%5D%28https://gh.io/copilot-coding-agent-tips%29%2E%0A%0A%3COnboard%20this%20repo%3E&assignees=copilot) — coding agent works faster and does higher quality work when set up for your repo. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent ef2bc22 commit 1fdfb1b

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

onnxscript/nn/_module_test.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,7 @@ def forward(self, op, x):
854854
return op.Add(x, op.Constant(value_float=1.0))
855855

856856
graph, op, x = self._make_input()
857-
seq = Sequential([AddOne(), AddOne(), AddOne()])
857+
seq = Sequential(AddOne(), AddOne(), AddOne())
858858
result = seq(op, x)
859859

860860
self.assertIsInstance(result, ir.Value)
@@ -869,7 +869,7 @@ def forward(self, op, x):
869869
return op.Identity(x)
870870

871871
_, op, x = self._make_input()
872-
seq = Sequential([PassThrough()])
872+
seq = Sequential(PassThrough())
873873
result = seq(op, x)
874874
self.assertIsInstance(result, ir.Value)
875875

@@ -899,7 +899,7 @@ def forward(self, op, pair):
899899
return op.Add(a, b)
900900

901901
graph, op, x = self._make_input()
902-
seq = Sequential([SplitTwo(), UnpackAndAdd()])
902+
seq = Sequential(SplitTwo(), UnpackAndAdd())
903903
result = seq(op, x)
904904

905905
self.assertIsInstance(result, ir.Value)
@@ -920,7 +920,7 @@ def forward(self, op, pair):
920920
return op.Add(a, b)
921921

922922
_, op, x = self._make_input()
923-
seq = Sequential([SplitTwoList(), UnpackAndAdd()])
923+
seq = Sequential(SplitTwoList(), UnpackAndAdd())
924924
result = seq(op, x)
925925
self.assertIsInstance(result, ir.Value)
926926

@@ -938,7 +938,7 @@ def forward(self, op, pair):
938938
return pair
939939

940940
_, op, x = self._make_input()
941-
seq = Sequential([ReturnPair(), TupleIdentity()])
941+
seq = Sequential(ReturnPair(), TupleIdentity())
942942
result = seq(op, x)
943943
self.assertIsInstance(result, tuple)
944944
self.assertEqual(len(result), 2)
@@ -951,7 +951,7 @@ def forward(self, op, x):
951951
return (op.Identity(x), op.Identity(x))
952952

953953
_, op, x = self._make_input()
954-
seq = Sequential([ReturnPair()])
954+
seq = Sequential(ReturnPair())
955955
result = seq(op, x)
956956
self.assertIsInstance(result, tuple)
957957
self.assertEqual(len(result), 2)
@@ -968,7 +968,7 @@ def forward(self, op, x):
968968
return (op.Identity(x), op.Identity(x), op.Identity(x))
969969

970970
_, op, x = self._make_input()
971-
seq = Sequential([Identity(), SplitThree()])
971+
seq = Sequential(Identity(), SplitThree())
972972
result = seq(op, x)
973973
self.assertIsInstance(result, tuple)
974974
self.assertEqual(len(result), 3)
@@ -987,7 +987,7 @@ def forward(self, op, x):
987987

988988
_, op = _create_graph_and_op()
989989
accept = AcceptNone()
990-
seq = Sequential([ReturnNone(), accept])
990+
seq = Sequential(ReturnNone(), accept)
991991
result = seq(op, "anything")
992992
self.assertIsNone(result)
993993
self.assertIsNone(accept.received)
@@ -1008,7 +1008,7 @@ def forward(self, op, x):
10081008
class Model(Module):
10091009
def __init__(self):
10101010
super().__init__("model")
1011-
self.layers = Sequential([Linear(4, 4), Linear(4, 4)])
1011+
self.layers = Sequential(Linear(4, 4), Linear(4, 4))
10121012

10131013
def forward(self, op, x):
10141014
return self.layers(op, x)
@@ -1035,7 +1035,7 @@ def __init__(self, size):
10351035
def forward(self, op, x):
10361036
return op.MatMul(x, op.Transpose(self.weight, perm=[1, 0]))
10371037

1038-
seq = Sequential([SiLU(), Linear(4)])
1038+
seq = Sequential(SiLU(), Linear(4))
10391039
named = dict(seq.named_parameters())
10401040
# SiLU at index 0 has no params; Linear at index 1 has weight
10411041
self.assertIn("1.weight", named)
@@ -1061,7 +1061,7 @@ def forward(self, op, x):
10611061
class Model(Module):
10621062
def __init__(self):
10631063
super().__init__("model")
1064-
self.blocks = Sequential([])
1064+
self.blocks = Sequential()
10651065
# Append AFTER __setattr__ has set Sequential._name = "blocks"
10661066
self.blocks.append(Linear(4))
10671067
self.blocks.append(Linear(4))

onnxscript/nn/_sequential.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,16 @@ def forward(self, op, x):
2424
2525
# Produces parameter names: "mod.0.weight", "mod.0.bias"
2626
# SiLU at index 0 has no parameters.
27-
mod = Sequential([SiLU(), Linear(4, 4)])
27+
mod = Sequential(SiLU(), Linear(4, 4))
2828
2929
# Calling mod(op, x) is equivalent to:
3030
# x = silu(op, x)
3131
# x = linear(op, x)
3232
"""
3333

34+
def __init__(self, *modules: _module_list.Module) -> None:
35+
super().__init__(modules)
36+
3437
def _set_name(self, name: str) -> None:
3538
"""Set this container's name. Children keep simple ``"0"``, ``"1"`` names.
3639

0 commit comments

Comments
 (0)