@@ -118,14 +118,14 @@ def test_all_three_biases_fused(self):
118118 model = self ._build (
119119 _mha_all_biases ,
120120 input_types = [
121- FLOAT [_B , _S , _D ],
122- FLOAT [_B , _S , _Dk ],
123- FLOAT [_B , _S , _Dv ],
121+ FLOAT ["B" , "S" , _D ],
122+ FLOAT ["B" , "S" , _Dk ],
123+ FLOAT ["B" , "S" , _Dv ],
124124 FLOAT [_D ],
125125 FLOAT [_Dk ],
126126 FLOAT [_Dv ],
127127 ],
128- output_types = [FLOAT [_B , _S , _D ]],
128+ output_types = [FLOAT ["B" , "S" , _D ]],
129129 )
130130 inputs = {
131131 "query_matmul" : np .random .randn (_B , _S , _D ).astype (np .float32 ),
@@ -145,12 +145,12 @@ def test_only_q_bias(self):
145145 model = self ._build (
146146 _mha_q_bias_only ,
147147 input_types = [
148- FLOAT [_B , _S , _D ],
149- FLOAT [_B , _S , _Dk ],
150- FLOAT [_B , _S , _Dv ],
148+ FLOAT ["B" , "S" , _D ],
149+ FLOAT ["B" , "S" , _Dk ],
150+ FLOAT ["B" , "S" , _Dv ],
151151 FLOAT [_D ],
152152 ],
153- output_types = [FLOAT [_B , _S , _D ]],
153+ output_types = [FLOAT ["B" , "S" , _D ]],
154154 )
155155 inputs = {
156156 "query_matmul" : np .random .randn (_B , _S , _D ).astype (np .float32 ),
@@ -167,12 +167,12 @@ def test_only_k_bias(self):
167167 model = self ._build (
168168 _mha_k_bias_only ,
169169 input_types = [
170- FLOAT [_B , _S , _D ],
171- FLOAT [_B , _S , _Dk ],
172- FLOAT [_B , _S , _Dv ],
170+ FLOAT ["B" , "S" , _D ],
171+ FLOAT ["B" , "S" , _Dk ],
172+ FLOAT ["B" , "S" , _Dv ],
173173 FLOAT [_Dk ],
174174 ],
175- output_types = [FLOAT [_B , _S , _D ]],
175+ output_types = [FLOAT ["B" , "S" , _D ]],
176176 )
177177 inputs = {
178178 "query_matmul" : np .random .randn (_B , _S , _D ).astype (np .float32 ),
@@ -188,12 +188,12 @@ def test_only_v_bias(self):
188188 model = self ._build (
189189 _mha_v_bias_only ,
190190 input_types = [
191- FLOAT [_B , _S , _D ],
192- FLOAT [_B , _S , _Dk ],
193- FLOAT [_B , _S , _Dv ],
191+ FLOAT ["B" , "S" , _D ],
192+ FLOAT ["B" , "S" , _Dk ],
193+ FLOAT ["B" , "S" , _Dv ],
194194 FLOAT [_Dv ],
195195 ],
196- output_types = [FLOAT [_B , _S , _D ]],
196+ output_types = [FLOAT ["B" , "S" , _D ]],
197197 )
198198 inputs = {
199199 "query_matmul" : np .random .randn (_B , _S , _D ).astype (np .float32 ),
@@ -209,13 +209,13 @@ def test_q_and_k_bias_only(self):
209209 model = self ._build (
210210 _mha_qk_biases ,
211211 input_types = [
212- FLOAT [_B , _S , _D ],
213- FLOAT [_B , _S , _Dk ],
214- FLOAT [_B , _S , _Dv ],
212+ FLOAT ["B" , "S" , _D ],
213+ FLOAT ["B" , "S" , _Dk ],
214+ FLOAT ["B" , "S" , _Dv ],
215215 FLOAT [_D ],
216216 FLOAT [_Dk ],
217217 ],
218- output_types = [FLOAT [_B , _S , _D ]],
218+ output_types = [FLOAT ["B" , "S" , _D ]],
219219 )
220220 inputs = {
221221 "query_matmul" : np .random .randn (_B , _S , _D ).astype (np .float32 ),
@@ -232,8 +232,8 @@ def test_no_biases_no_fusion(self):
232232 """No bias Adds at all → rule should not apply."""
233233 model = self ._build (
234234 _mha_no_biases ,
235- input_types = [FLOAT [_B , _S , _D ], FLOAT [_B , _S , _Dk ], FLOAT [_B , _S , _Dv ]],
236- output_types = [FLOAT [_B , _S , _D ]],
235+ input_types = [FLOAT ["B" , "S" , _D ], FLOAT ["B" , "S" , _Dk ], FLOAT ["B" , "S" , _Dv ]],
236+ output_types = [FLOAT ["B" , "S" , _D ]],
237237 )
238238 count = self ._apply (model )
239239 self .assertEqual (count , 0 )
@@ -244,12 +244,12 @@ def test_int32_dtype_no_fusion(self):
244244 model = self ._build (
245245 _mha_int32_with_bias ,
246246 input_types = [
247- INT32 [_B , _S , _D ],
248- INT32 [_B , _S , _Dk ],
249- INT32 [_B , _S , _Dv ],
247+ INT32 ["B" , "S" , _D ],
248+ INT32 ["B" , "S" , _Dk ],
249+ INT32 ["B" , "S" , _Dv ],
250250 INT32 [_D ],
251251 ],
252- output_types = [INT32 [_B , _S , _D ]],
252+ output_types = [INT32 ["B" , "S" , _D ]],
253253 )
254254 count = self ._apply (model )
255255 self .assertEqual (count , 0 )
@@ -259,12 +259,12 @@ def test_shape_mismatch_no_fusion(self):
259259 model = self ._build (
260260 _mha_rank2_query_with_bias ,
261261 input_types = [
262- FLOAT [_S , _D ],
263- FLOAT [_B , _S , _Dk ],
264- FLOAT [_B , _S , _Dv ],
262+ FLOAT ["S" , _D ],
263+ FLOAT ["B" , "S" , _Dk ],
264+ FLOAT ["B" , "S" , _Dv ],
265265 FLOAT [_D ],
266266 ],
267- output_types = [FLOAT [_B , _S , _D ]],
267+ output_types = [FLOAT ["B" , "S" , _D ]],
268268 )
269269 count = self ._apply (model )
270270 self .assertEqual (count , 0 )
0 commit comments