@@ -362,15 +362,18 @@ def test_fusion(self):
362362 assert_allclose (outputs3 , source_model_outputs )
363363
364364
365- @parameterized .parameterized_class ([
366- {"with_past" : True , "transpose_first" : True },
367- {"with_past" : True , "transpose_first" : False },
368- {"with_past" : False , "transpose_first" : True },
369- {"with_past" : False , "transpose_first" : False },
370- ])
365+ @parameterized .parameterized_class (
366+ [
367+ {"with_past" : True , "transpose_first" : True },
368+ {"with_past" : True , "transpose_first" : False },
369+ {"with_past" : False , "transpose_first" : True },
370+ {"with_past" : False , "transpose_first" : False },
371+ ]
372+ )
371373class GemmaGQAFusionTest (unittest .TestCase ):
372374 with_past = True
373375 transpose_first = True
376+
374377 def __init__ (self , * args , ** kwargs ):
375378 super ().__init__ (* args , ** kwargs )
376379
@@ -485,11 +488,15 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
485488 query_BSHDh_normalized = op .SimplifiedLayerNormalization (
486489 query_BSHDh , query_scale , axis = - 1 , epsilon = 1e-06 , stash_type = 1
487490 )
488- query_BHSDh_normalized = op .Transpose (query_BSHDh_normalized , perm = [0 , 2 , 1 , 3 ])
491+ query_BHSDh_normalized = op .Transpose (
492+ query_BSHDh_normalized , perm = [0 , 2 , 1 , 3 ]
493+ )
489494 key_BSHkvDh_normalized = op .SimplifiedLayerNormalization (
490495 key_BSHkvDh , key_scale , axis = - 1 , epsilon = 1e-06 , stash_type = 1
491496 )
492- key_BHkvSDh_normalized = op .Transpose (key_BSHkvDh_normalized , perm = [0 , 2 , 1 , 3 ])
497+ key_BHkvSDh_normalized = op .Transpose (
498+ key_BSHkvDh_normalized , perm = [0 , 2 , 1 , 3 ]
499+ )
493500
494501 value_BSHkvDh = op .Reshape (value , shape_BSHkvDh )
495502 value_BHkvSDh = op .Transpose (value_BSHkvDh , perm = [0 , 2 , 1 , 3 ])
0 commit comments