Skip to content

Commit 3c55666

Browse files
committed
fix(fuse_batchnorm): support convtranspose + bn fusion with group != 1
1 parent 1077da7 commit 3c55666

File tree

2 files changed

+107
-28
lines changed

2 files changed

+107
-28
lines changed

onnxscript/rewriter/rules/common/_fuse_batchnorm.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Licensed under the MIT License.
33
"""Fuses BatchNormalization nodes into preceding nodes. Supported fusion patterns:
44
- BatchNormalization ∘ Conv -> Conv
5-
- BatchNormalization ∘ ConvTranpose -> ConvTranpose
5+
- BatchNormalization ∘ ConvTranspose -> ConvTranspose
66
- BatchNormalization ∘ Gemm -> Gemm
77
88
Approach:
@@ -14,7 +14,7 @@
1414
- B_fused = (B - μ) * (gamma / std) + β
1515
"""
1616

17-
from abc import ABC, abstractmethod
17+
from abc import ABC
1818
from typing import ClassVar, Mapping
1919

2020
import numpy as np
@@ -33,9 +33,18 @@ def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarra
3333
class _FuseBatchNormBase(RewriteRuleClassBase, ABC):
3434
"""Interface for BatchNormalization nodes fusion."""
3535

36-
@abstractmethod
3736
def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
3837
"""Return the axis along which BatchNorm scale should be broadcasted."""
38+
raise NotImplementedError()
39+
40+
def _scale_weights(
41+
self,
42+
weights: np.ndarray,
43+
scale_factor: np.ndarray,
44+
attributes: Mapping[str, ir.Attr],
45+
) -> np.ndarray:
46+
axis = self.get_filters_axis(attributes)
47+
return weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis)
3948

4049
def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Value):
4150
batchnorm_node = batchnorm_out.producer()
@@ -56,10 +65,8 @@ def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Valu
5665
inbound_node = inbound_out.producer()
5766
weights = inbound_node.inputs[1].const_value.numpy()
5867

59-
# Reshape scale factor so it is broadcastable
60-
axis = self.get_filters_axis(inbound_node.attributes)
6168
fused_weights = ir.tensor(
62-
weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis)
69+
self._scale_weights(weights, scale_factor, inbound_node.attributes)
6370
)
6471

6572
# Update bias
@@ -127,8 +134,26 @@ class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase):
127134

128135
op_type: ClassVar = "ConvTranspose"
129136

130-
def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
131-
return 1
137+
def _scale_weights(
138+
self,
139+
weights: np.ndarray,
140+
scale_factor: np.ndarray,
141+
attributes: Mapping[str, ir.Attr],
142+
) -> np.ndarray:
143+
# ConvTranspose weight: (in_channels, out_channels/group, *kernel)
144+
# Reshape weights: [in_channels, out_channels/group, *kernel] → [group, in_channels/group, out_channels/group, *kernel]
145+
in_channels = weights.shape[0]
146+
out_channels_per_group = weights.shape[1]
147+
kernel_shape = weights.shape[2:]
148+
group = attributes.get("group", ir.AttrInt64("group", 1)).as_int()
149+
w = weights.reshape(group, in_channels // group, out_channels_per_group, *kernel_shape)
150+
151+
# Per group scale_factor (out_channels,) -> (group, out_channels/group) -> (group, 1, out_channels/group, 1, ..., 1)
152+
s = scale_factor.reshape((group, out_channels_per_group) + (1,) * len(kernel_shape))
153+
# insert in_channels/group axis -> (group, 1, out_channels/group, *ones)
154+
s = s[:, None, ...]
155+
156+
return (w * s).reshape(weights.shape)
132157

133158
def pattern(self, op, x):
134159
return op.BatchNormalization(
@@ -137,6 +162,25 @@ def pattern(self, op, x):
137162
_outputs=["batchnorm_out"],
138163
)
139164

165+
def check(self, context, x, inbound_out, batchnorm_out):
166+
check_result = super().check(context, x, inbound_out, batchnorm_out)
167+
if not check_result:
168+
return check_result
169+
170+
inbound_node = inbound_out.producer()
171+
172+
in_channels = inbound_node.inputs[1].const_value.numpy().shape[0]
173+
group = inbound_node.attributes.get("group", ir.AttrInt64("group", 1)).as_int()
174+
175+
# Check that in_channels is divisible by group, since the ONNX checker allows it
176+
# but this is an invalid case
177+
if in_channels % group != 0:
178+
return check_result.fail(
179+
f"ConvTranspose in_channels ({in_channels}) is not divisible by group ({group})."
180+
)
181+
182+
return check_result
183+
140184

141185
class FuseBatchNormIntoGemm(_FuseBatchNormBase):
142186
"""Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``."""

onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
import unittest
44

55
import numpy as np
6-
import onnx.checker
7-
import onnx.parser
6+
import onnx
87
import parameterized
98

109
from onnxscript import ir
@@ -31,29 +30,33 @@ def _create_batchnorm_params(self, size: int):
3130

3231
@parameterized.parameterized.expand(
3332
[
34-
("bias_false", False),
35-
("bias_true", True),
33+
("bias_false_group1", False, 1),
34+
("bias_true_group1", True, 1),
35+
("bias_false_group4", False, 4),
36+
("bias_true_group4", True, 4),
3637
]
3738
)
38-
def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool):
39+
def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool, group: int):
40+
# ConvTranspose weight: [in_channels, out_channels/group, kH, kW]
41+
out_channels = 64 * group
3942
convtranspose_inputs = "X, W"
4043
parameters = (
41-
"float[32, 64, 3, 3] W, "
42-
"float[64] gamma, "
43-
"float[64] beta, "
44-
"float[64] input_mean, "
45-
"float[64] input_var"
44+
f"float[32, 64, 3, 3] W, "
45+
f"float[{out_channels}] gamma, "
46+
f"float[{out_channels}] beta, "
47+
f"float[{out_channels}] input_mean, "
48+
f"float[{out_channels}] input_var"
4649
)
4750
if convtranspose_bias:
48-
parameters += ", float[64] B"
51+
parameters += f", float[{out_channels}] B"
4952
convtranspose_inputs += ", B"
5053

5154
model_proto = onnx.parser.parse_model(f"""
5255
< ir_version: 7, opset_import: ["" : 17] >
5356
test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] Y)
5457
<{parameters}>
5558
{{
56-
X1 = ConvTranspose({convtranspose_inputs})
59+
X1 = ConvTranspose<group={group}>({convtranspose_inputs})
5760
Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)
5861
}}
5962
""")
@@ -62,11 +65,13 @@ def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool):
6265
onnx.numpy_helper.from_array(
6366
np.random.randn(32, 64, 3, 3).astype(np.float32), name="W"
6467
),
65-
*self._create_batchnorm_params(size=64),
68+
*self._create_batchnorm_params(size=out_channels),
6669
]
6770
if convtranspose_bias:
6871
initializers.append(
69-
onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="B")
72+
onnx.numpy_helper.from_array(
73+
np.random.randn(out_channels).astype(np.float32), name="B"
74+
)
7075
)
7176
model_proto.graph.initializer.extend(initializers)
7277

@@ -90,14 +95,18 @@ def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool):
9095

9196
@parameterized.parameterized.expand(
9297
[
93-
("bias_false", False),
94-
("bias_true", True),
98+
("bias_false_group1", False, 1),
99+
("bias_true_group1", True, 1),
100+
("bias_false_group2", False, 2),
101+
("bias_true_group2", True, 2),
95102
]
96103
)
97-
def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool):
104+
def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool, group: int):
105+
# Conv weight: [out_channels, in_channels/group, kH, kW]
106+
in_channels_per_group = 32 // group
98107
conv_inputs = "X, W"
99108
parameters = (
100-
"float[64, 32, 3, 3] W, "
109+
f"float[64, {in_channels_per_group}, 3, 3] W, "
101110
"float[64] gamma, "
102111
"float[64] beta, "
103112
"float[64] input_mean, "
@@ -112,14 +121,14 @@ def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool):
112121
test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] Y)
113122
<{parameters}>
114123
{{
115-
X1 = Conv({conv_inputs})
124+
X1 = Conv<group={group}>({conv_inputs})
116125
Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)
117126
}}
118127
""")
119128
# Add initializers
120129
initializers = [
121130
onnx.numpy_helper.from_array(
122-
np.random.randn(64, 32, 3, 3).astype(np.float32), name="W"
131+
np.random.randn(64, in_channels_per_group, 3, 3).astype(np.float32), name="W"
123132
),
124133
*self._create_batchnorm_params(size=64),
125134
]
@@ -211,6 +220,32 @@ def test_fuse_batchnorm_gemm(self, _: str, gemm_bias: bool, transB: int):
211220
output_model_proto = ir.serde.serialize_model(model)
212221
onnx.checker.check_model(output_model_proto, True)
213222

223+
def test_fuse_batchnorm_convtranspose_grouped_invalid_skipped(self):
224+
"""Fusion is skipped when in_channels is not divisible by group (semantically invalid model)."""
225+
# in_channels=32 is not divisible by group=3; the ONNX checker won't catch this.
226+
model_proto = onnx.parser.parse_model("""
227+
< ir_version: 7, opset_import: ["" : 17] >
228+
test_model (float[N, 32, 14, 14] X) => (float[N, ?, ?, ?] Y)
229+
<float[32, 64, 3, 3] W,
230+
float[192] gamma, float[192] beta, float[192] input_mean, float[192] input_var>
231+
{
232+
X1 = ConvTranspose<group=3>(X, W)
233+
Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)
234+
}
235+
""")
236+
initializers = [
237+
onnx.numpy_helper.from_array(
238+
np.random.randn(32, 64, 3, 3).astype(np.float32), name="W"
239+
),
240+
*self._create_batchnorm_params(size=192),
241+
]
242+
model_proto.graph.initializer.extend(initializers)
243+
model = ir.serde.deserialize_model(model_proto)
244+
count = _fuse_batchnorm.rules.apply_to_model(model)
245+
246+
# Fusion must be skipped; applying it would crash on the invalid dimensions.
247+
self.assertEqual(count, 0)
248+
214249
def test_fuse_batchnorm_non_initializers(self):
215250
model_proto = onnx.parser.parse_model("""
216251
< ir_version: 7, opset_import: ["" : 17] >

0 commit comments

Comments
 (0)