33import unittest
44
55import numpy as np
6- import onnx .checker
7- import onnx .parser
6+ import onnx
87import parameterized
98
109from onnxscript import ir
@@ -31,29 +30,33 @@ def _create_batchnorm_params(self, size: int):
3130
3231 @parameterized .parameterized .expand (
3332 [
34- ("bias_false" , False ),
35- ("bias_true" , True ),
33+ ("bias_false_group1" , False , 1 ),
34+ ("bias_true_group1" , True , 1 ),
35+ ("bias_false_group4" , False , 4 ),
36+ ("bias_true_group4" , True , 4 ),
3637 ]
3738 )
38- def test_fuse_batchnorm_convtranspose (self , _ : str , convtranspose_bias : bool ):
39+ def test_fuse_batchnorm_convtranspose (self , _ : str , convtranspose_bias : bool , group : int ):
40+ # ConvTranspose weight: [in_channels, out_channels/group, kH, kW]
41+ out_channels = 64 * group
3942 convtranspose_inputs = "X, W"
4043 parameters = (
41- "float[32, 64, 3, 3] W, "
42- "float[64 ] gamma, "
43- "float[64 ] beta, "
44- "float[64 ] input_mean, "
45- "float[64 ] input_var"
44+ f "float[32, 64, 3, 3] W, "
45+ f "float[{ out_channels } ] gamma, "
46+ f "float[{ out_channels } ] beta, "
47+ f "float[{ out_channels } ] input_mean, "
48+ f "float[{ out_channels } ] input_var"
4649 )
4750 if convtranspose_bias :
48- parameters += ", float[64 ] B"
51+ parameters += f ", float[{ out_channels } ] B"
4952 convtranspose_inputs += ", B"
5053
5154 model_proto = onnx .parser .parse_model (f"""
5255 < ir_version: 7, opset_import: ["" : 17] >
5356 test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] Y)
5457 <{ parameters } >
5558 {{
56- X1 = ConvTranspose({ convtranspose_inputs } )
59+ X1 = ConvTranspose<group= { group } > ({ convtranspose_inputs } )
5760 Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)
5861 }}
5962 """ )
@@ -62,11 +65,13 @@ def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool):
6265 onnx .numpy_helper .from_array (
6366 np .random .randn (32 , 64 , 3 , 3 ).astype (np .float32 ), name = "W"
6467 ),
65- * self ._create_batchnorm_params (size = 64 ),
68+ * self ._create_batchnorm_params (size = out_channels ),
6669 ]
6770 if convtranspose_bias :
6871 initializers .append (
69- onnx .numpy_helper .from_array (np .random .randn (64 ).astype (np .float32 ), name = "B" )
72+ onnx .numpy_helper .from_array (
73+ np .random .randn (out_channels ).astype (np .float32 ), name = "B"
74+ )
7075 )
7176 model_proto .graph .initializer .extend (initializers )
7277
@@ -90,14 +95,18 @@ def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool):
9095
9196 @parameterized .parameterized .expand (
9297 [
93- ("bias_false" , False ),
94- ("bias_true" , True ),
98+ ("bias_false_group1" , False , 1 ),
99+ ("bias_true_group1" , True , 1 ),
100+ ("bias_false_group2" , False , 2 ),
101+ ("bias_true_group2" , True , 2 ),
95102 ]
96103 )
97- def test_fuse_batchnorm_conv (self , _ : str , conv_bias : bool ):
104+ def test_fuse_batchnorm_conv (self , _ : str , conv_bias : bool , group : int ):
105+ # Conv weight: [out_channels, in_channels/group, kH, kW]
106+ in_channels_per_group = 32 // group
98107 conv_inputs = "X, W"
99108 parameters = (
100- "float[64, 32 , 3, 3] W, "
109+ f "float[64, { in_channels_per_group } , 3, 3] W, "
101110 "float[64] gamma, "
102111 "float[64] beta, "
103112 "float[64] input_mean, "
@@ -112,14 +121,14 @@ def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool):
112121 test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] Y)
113122 <{ parameters } >
114123 {{
115- X1 = Conv({ conv_inputs } )
124+ X1 = Conv<group= { group } > ({ conv_inputs } )
116125 Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)
117126 }}
118127 """ )
119128 # Add initializers
120129 initializers = [
121130 onnx .numpy_helper .from_array (
122- np .random .randn (64 , 32 , 3 , 3 ).astype (np .float32 ), name = "W"
131+ np .random .randn (64 , in_channels_per_group , 3 , 3 ).astype (np .float32 ), name = "W"
123132 ),
124133 * self ._create_batchnorm_params (size = 64 ),
125134 ]
@@ -211,6 +220,32 @@ def test_fuse_batchnorm_gemm(self, _: str, gemm_bias: bool, transB: int):
211220 output_model_proto = ir .serde .serialize_model (model )
212221 onnx .checker .check_model (output_model_proto , True )
213222
223+ def test_fuse_batchnorm_convtranspose_grouped_invalid_skipped (self ):
224+ """Fusion is skipped when in_channels is not divisible by group (semantically invalid model)."""
225+ # in_channels=32 is not divisible by group=3, the ONNX checker won't catch this.
226+ model_proto = onnx .parser .parse_model ("""
227+ < ir_version: 7, opset_import: ["" : 17] >
228+ test_model (float[N, 32, 14, 14] X) => (float[N, ?, ?, ?] Y)
229+ <float[32, 64, 3, 3] W,
230+ float[192] gamma, float[192] beta, float[192] input_mean, float[192] input_var>
231+ {
232+ X1 = ConvTranspose<group=3>(X, W)
233+ Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)
234+ }
235+ """ )
236+ initializers = [
237+ onnx .numpy_helper .from_array (
238+ np .random .randn (32 , 64 , 3 , 3 ).astype (np .float32 ), name = "W"
239+ ),
240+ * self ._create_batchnorm_params (size = 192 ),
241+ ]
242+ model_proto .graph .initializer .extend (initializers )
243+ model = ir .serde .deserialize_model (model_proto )
244+ count = _fuse_batchnorm .rules .apply_to_model (model )
245+
246+ # Fusion must be skipped, applying it would crash on the invalid dimensions.
247+ self .assertEqual (count , 0 )
248+
214249 def test_fuse_batchnorm_non_initializers (self ):
215250 model_proto = onnx .parser .parse_model ("""
216251 < ir_version: 7, opset_import: ["" : 17] >
0 commit comments