Skip to content

Commit 864b785

Browse files
Copilot and justinchuby authored
Fix BatchNorm fusion producing invalid ONNX when Conv nodes share weight initializers (#2883)
When a model reuses `Conv2d`+`BatchNorm2d` blocks (same weights called twice), the BatchNorm fusion rewrite creates a new initializer with the same name as the shared weight, overwriting it in the graph's initializer dict. This sets `_is_initializer = False` on the original value still referenced by the second Conv node, producing an invalid model: ``` InvalidArgument: Node input 'conv.weight_1' is not a graph input, initializer, or output of a previous node. ``` Reproducer: ```python class MWE(nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv2d(1, 16, kernel_size=3, padding=1) self.bn = nn.BatchNorm2d(16) def forward(self, x): f_1 = self.bn(self.conv(x[:, 0:1])) f_2 = self.bn(self.conv(x[:, 1:2])) return f_1 + f_2 torch.onnx.export(MWE().eval(), ..., optimize=True) # Invalid model ``` ### Changes - **`_fuse_batchnorm.py`**: Added a check in `_FuseBatchNormBase.check()` that rejects the fusion when the inbound node's weight/bias initializers are used by nodes outside the matched pattern. This prevents the overwrite of shared initializers. - **`_fuse_batchnorm_test.py`**: Added regression test with two Conv+BN pairs sharing the same weight and bias initializers. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
1 parent 12234f8 commit 864b785

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

onnxscript/rewriter/rules/common/_fuse_batchnorm.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,23 @@ def check(self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value) -> M
103103
if initializer.is_graph_input():
104104
return check_result.fail(f"{initializer.name} is a graph input.")
105105

106+
# Check that the inbound node's weight and bias initializers are not shared
107+
# with other nodes outside this matched pattern. When the fusion creates new
108+
# initializers with the same name as the original shared weights, it overwrites
109+
# the original initializer in the graph, leaving other nodes that reference the
110+
# original value with an invalid (unregistered) input.
111+
matched_nodes = {inbound_node, batchnorm_node}
112+
inbound_initializers = [inbound_node.inputs[1]]
113+
if len(inbound_node.inputs) > 2:
114+
inbound_initializers.append(inbound_node.inputs[2])
115+
for init_value in inbound_initializers:
116+
for user, _ in init_value.uses():
117+
if user not in matched_nodes:
118+
return check_result.fail(
119+
f"Initializer '{init_value.name}' is used by another node "
120+
f"'{user.name}' outside the matched pattern."
121+
)
122+
106123
return check_result
107124

108125

onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,56 @@ def test_fuse_batchnorm_does_not_collide_names_with_same_parent_node(self):
311311
bias_names_2 = conv_nodes[1].inputs[2].name
312312
self.assertNotEqual(bias_names_1, bias_names_2)
313313

314+
def test_fuse_batchnorm_skips_shared_weight_initializers(self):
315+
"""Test that BatchNorm fusion is skipped when Conv nodes share weight initializers.
316+
317+
Regression test for https://github.com/microsoft/onnxscript/issues/2382.
318+
When two Conv+BatchNorm pairs share the same weight initializer, fusing the
319+
first pair would overwrite the shared initializer, leaving the second Conv
320+
node with an invalid (unregistered) weight reference.
321+
"""
322+
model_proto = onnx.parser.parse_model("""
323+
< ir_version: 7, opset_import: ["" : 17] >
324+
test_model (float[N, 32, 14, 16] X1, float[N, 32, 14, 16] X2)
325+
=> (float [N, ?, ?, ?] Y)
326+
{
327+
C1 = Conv(X1, W, B)
328+
BN1 = BatchNormalization(C1, gamma, beta, input_mean, input_var)
329+
C2 = Conv(X2, W, B)
330+
BN2 = BatchNormalization(C2, gamma, beta, input_mean, input_var)
331+
Y = Add(BN1, BN2)
332+
}
333+
""")
334+
initializers = [
335+
onnx.numpy_helper.from_array(
336+
np.random.randn(16, 32, 3, 3).astype(np.float32), name="W"
337+
),
338+
onnx.numpy_helper.from_array(np.random.randn(16).astype(np.float32), name="B"),
339+
*self._create_batchnorm_params(size=16),
340+
]
341+
model_proto.graph.initializer.extend(initializers)
342+
onnx.checker.check_model(model_proto, True)
343+
model = ir.serde.deserialize_model(model_proto)
344+
345+
count = _fuse_batchnorm.rules.apply_to_model(model)
346+
347+
# No fusion should be applied because the weight initializer is shared
348+
self.assertEqual(count, 0)
349+
350+
# The model should still be valid after the (non-)optimization
351+
output_model_proto = ir.serde.serialize_model(model)
352+
onnx.checker.check_model(output_model_proto, True)
353+
354+
# Check inference produces correct results
355+
testing.assert_numerically_equal(
356+
model_proto,
357+
model,
358+
(
359+
np.random.rand(1, 32, 14, 16).astype(np.float32),
360+
np.random.rand(1, 32, 14, 16).astype(np.float32),
361+
),
362+
)
363+
314364

315365
if __name__ == "__main__":
316366
unittest.main()

0 commit comments

Comments (0)