Skip to content

Commit 70b5298

Browse files
committed
Run lint
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent cfc0a47 commit 70b5298

1 file changed

Lines changed: 15 additions & 8 deletions

File tree

onnxscript/rewriter/ort_fusions/gqa_test.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -362,15 +362,18 @@ def test_fusion(self):
362362
assert_allclose(outputs3, source_model_outputs)
363363

364364

365-
@parameterized.parameterized_class([
366-
{"with_past": True, "transpose_first": True},
367-
{"with_past": True, "transpose_first": False},
368-
{"with_past": False, "transpose_first": True},
369-
{"with_past": False, "transpose_first": False},
370-
])
365+
@parameterized.parameterized_class(
366+
[
367+
{"with_past": True, "transpose_first": True},
368+
{"with_past": True, "transpose_first": False},
369+
{"with_past": False, "transpose_first": True},
370+
{"with_past": False, "transpose_first": False},
371+
]
372+
)
371373
class GemmaGQAFusionTest(unittest.TestCase):
372374
with_past = True
373375
transpose_first = True
376+
374377
def __init__(self, *args, **kwargs):
375378
super().__init__(*args, **kwargs)
376379

@@ -485,11 +488,15 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
485488
query_BSHDh_normalized = op.SimplifiedLayerNormalization(
486489
query_BSHDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1
487490
)
488-
query_BHSDh_normalized = op.Transpose(query_BSHDh_normalized, perm=[0, 2, 1, 3])
491+
query_BHSDh_normalized = op.Transpose(
492+
query_BSHDh_normalized, perm=[0, 2, 1, 3]
493+
)
489494
key_BSHkvDh_normalized = op.SimplifiedLayerNormalization(
490495
key_BSHkvDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1
491496
)
492-
key_BHkvSDh_normalized = op.Transpose(key_BSHkvDh_normalized, perm=[0, 2, 1, 3])
497+
key_BHkvSDh_normalized = op.Transpose(
498+
key_BSHkvDh_normalized, perm=[0, 2, 1, 3]
499+
)
493500

494501
value_BSHkvDh = op.Reshape(value, shape_BSHkvDh)
495502
value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3])

0 commit comments

Comments
 (0)