Skip to content

Commit 5c5277f

Browse files
Copilot and justinchuby authored
Fix BatchNorm fusion creating invalid model when Conv nodes share weight initializers
Add a check in _FuseBatchNormBase.check() to verify that the inbound Conv/ConvTranspose/Gemm node's weight and bias initializers are not shared with other nodes outside the matched pattern. When two Conv+BatchNorm pairs share the same weight initializer, fusing the first pair overwrites the shared initializer in the graph with fused values, leaving the second Conv node with an invalid (unregistered) weight reference and producing an invalid ONNX model. Fixes #2382 Agent-Logs-Url: https://github.com/microsoft/onnxscript/sessions/10e4a5fd-e010-48dc-8a29-991b7b0a6ca7 Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
1 parent 58d2b41 commit 5c5277f

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

onnxscript/rewriter/rules/common/_fuse_batchnorm.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,23 @@ def check(self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value) -> M
103103
if initializer.is_graph_input():
104104
return check_result.fail(f"{initializer.name} is a graph input.")
105105

106+
# Check that the inbound node's weight and bias initializers are not shared
107+
# with other nodes outside this matched pattern. When the fusion creates new
108+
# initializers with the same name as the original shared weights, it overwrites
109+
# the original initializer in the graph, leaving other nodes that reference the
110+
# original value with an invalid (unregistered) input.
111+
matched_nodes = {inbound_node, batchnorm_node}
112+
inbound_initializers = [inbound_node.inputs[1]]
113+
if len(inbound_node.inputs) > 2:
114+
inbound_initializers.append(inbound_node.inputs[2])
115+
for init_value in inbound_initializers:
116+
for user, _ in init_value.uses():
117+
if user not in matched_nodes:
118+
return check_result.fail(
119+
f"Initializer '{init_value.name}' is used by another node "
120+
f"'{user.name}' outside the matched pattern."
121+
)
122+
106123
return check_result
107124

108125

onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,56 @@ def test_fuse_batchnorm_does_not_collide_names_with_same_parent_node(self):
311311
bias_names_2 = conv_nodes[1].inputs[2].name
312312
self.assertNotEqual(bias_names_1, bias_names_2)
313313

314+
def test_fuse_batchnorm_skips_shared_weight_initializers(self):
315+
"""Test that BatchNorm fusion is skipped when Conv nodes share weight initializers.
316+
317+
Regression test for https://github.com/microsoft/onnxscript/issues/2382.
318+
When two Conv+BatchNorm pairs share the same weight initializer, fusing the
319+
first pair would overwrite the shared initializer, leaving the second Conv
320+
node with an invalid (unregistered) weight reference.
321+
"""
322+
model_proto = onnx.parser.parse_model("""
323+
< ir_version: 7, opset_import: ["" : 17] >
324+
test_model (float[N, 32, 14, 16] X1, float[N, 32, 14, 16] X2)
325+
=> (float [N, ?, ?, ?] Y)
326+
{
327+
C1 = Conv(X1, W, B)
328+
BN1 = BatchNormalization(C1, gamma, beta, input_mean, input_var)
329+
C2 = Conv(X2, W, B)
330+
BN2 = BatchNormalization(C2, gamma, beta, input_mean, input_var)
331+
Y = Add(BN1, BN2)
332+
}
333+
""")
334+
initializers = [
335+
onnx.numpy_helper.from_array(
336+
np.random.randn(16, 32, 3, 3).astype(np.float32), name="W"
337+
),
338+
onnx.numpy_helper.from_array(np.random.randn(16).astype(np.float32), name="B"),
339+
*self._create_batchnorm_params(size=16),
340+
]
341+
model_proto.graph.initializer.extend(initializers)
342+
onnx.checker.check_model(model_proto, True)
343+
model = ir.serde.deserialize_model(model_proto)
344+
345+
count = _fuse_batchnorm.rules.apply_to_model(model)
346+
347+
# No fusion should be applied because the weight initializer is shared
348+
self.assertEqual(count, 0)
349+
350+
# The model should still be valid after the (non-)optimization
351+
output_model_proto = ir.serde.serialize_model(model)
352+
onnx.checker.check_model(output_model_proto, True)
353+
354+
# Check inference produces correct results
355+
testing.assert_numerically_equal(
356+
model_proto,
357+
model,
358+
(
359+
np.random.rand(1, 32, 14, 16).astype(np.float32),
360+
np.random.rand(1, 32, 14, 16).astype(np.float32),
361+
),
362+
)
363+
314364

315365
if __name__ == "__main__":
316366
unittest.main()

0 commit comments

Comments
 (0)