99import unittest
1010
1111import numpy
12+ from parameterized import parameterized
1213
1314import onnxscript .ir as ir
1415import onnxscript .optimizer
S = 8  # sequence length
H = 128  # head size

# Attention scores are scaled by 1/sqrt(H). The constants below express that
# same scaling in the different forms exercised by the SDPA test scripts.
SCALE_FACTOR = math.sqrt(H)  # divisor form: scores / SCALE_FACTOR
MUL_SCALE_FACTOR = 1.0 / SCALE_FACTOR  # multiplier form: scores * MUL_SCALE_FACTOR

# Square roots of the two factors above. Applying one of these to two operands
# of a product scales the product by the full factor.
# NOTE(review): SQRT_SCALE_FACTOR is presumably the divisor used by the
# pre-div script, whose body is not fully visible here -- confirm.
SQRT_SCALE_FACTOR = math.sqrt(SCALE_FACTOR)
SQRT_MUL_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR)
2629
2730
2831@script ()
@@ -38,16 +41,55 @@ def _masked_pre_div_sdpa_script(query, key, value, mask):
3841 return attn_output
3942
4043
41- class _MaskedPreDivSDPATestCase :
# Masked SDPA variant, "pre-mul" form: the query and the transposed key are
# each multiplied by SQRT_MUL_SCALE_FACTOR *before* the score MatMul, so the
# resulting attention scores carry the full MUL_SCALE_FACTOR scaling.
# Inputs are (query, key, value, mask); the mask is added to the scores
# before the Softmax over the last axis.
@script()
def _masked_pre_mul_sdpa_script(query, key, value, mask):
    # Swap the last two axes of key so it can be matrix-multiplied with query.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    multiplier = op.Constant(value_float=SQRT_MUL_SCALE_FACTOR)
    scaled_query = op.Mul(query, multiplier)
    scaled_key = op.Mul(key_transposed, multiplier)
    attn_score = op.MatMul(scaled_query, scaled_key)
    masked_attn_score = op.Add(attn_score, mask)
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output
55+
56+
# Masked SDPA variant, "post-div" form: the raw query/key scores are computed
# first and then divided by SCALE_FACTOR.
# Inputs are (query, key, value, mask); the mask is added to the scaled scores
# before the Softmax over the last axis.
@script()
def _masked_post_div_sdpa_script(query, key, value, mask):
    # Swap the last two axes of key so it can be matrix-multiplied with query.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    divisor = op.Constant(value_float=SCALE_FACTOR)
    attn_score = op.MatMul(query, key_transposed)
    scaled_attn_score = op.Div(attn_score, divisor)
    masked_attn_score = op.Add(scaled_attn_score, mask)
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output
67+
68+
# Masked SDPA variant, "post-mul" form: the raw query/key scores are computed
# first and then multiplied by MUL_SCALE_FACTOR.
# Inputs are (query, key, value, mask); the mask is added to the scaled scores
# before the Softmax over the last axis.
@script()
def _masked_post_mul_sdpa_script(query, key, value, mask):
    # Swap the last two axes of key so it can be matrix-multiplied with query.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    multiplier = op.Constant(value_float=MUL_SCALE_FACTOR)
    attn_score = op.MatMul(query, key_transposed)
    scaled_attn_score = op.Mul(attn_score, multiplier)
    masked_attn_score = op.Add(scaled_attn_score, mask)
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output
79+
80+
class SDPATestCase:
    """Test fixture that wraps one SDPA script function.

    The wrapped script is converted to an IR model on first request and the
    result is cached on the instance.
    """

    def __init__(self, script_func):
        """Store the script function whose model this test case builds.

        Args:
            script_func: Object exposing ``to_model_proto(input_types=...,
                output_types=...)``; in this file, one of the ``@script``
                SDPA functions.
        """
        self.script_func = script_func

    def get_onnx_model(self):
        """Return the IR model for ``script_func``, building it on first use.

        The model is built with three FLOAT[B, N, S, H] inputs (query, key,
        value), one FLOAT[B, N, S, S] mask input, and one FLOAT[B, N, S, H]
        output. Later calls return the same cached object.
        """
        # The cache attribute is created lazily, so its absence means
        # "not built yet".
        if not hasattr(self, "_onnx_model"):
            qkv_type = FLOAT[B, N, S, H]
            mask_type = FLOAT[B, N, S, S]
            model_proto = self.script_func.to_model_proto(
                input_types=[qkv_type, qkv_type, qkv_type, mask_type], output_types=[qkv_type]
            )
            self._onnx_model = ir.serde.deserialize_model(model_proto)
        return self._onnx_model
5294
5395 def get_ort_inputs (self ):
@@ -63,12 +105,20 @@ def get_ort_inputs(self):
63105
64106
65107class TestSDPAFusion (unittest .TestCase ):
66- def test_sdpa_fusion (self ):
67- test = _MaskedPreDivSDPATestCase ()
68- model = test .get_onnx_model ()
108+ @parameterized .expand (
109+ [
110+ ("pre_div" , _masked_pre_div_sdpa_script ),
111+ ("pre_mul" , _masked_pre_mul_sdpa_script ),
112+ ("post_div" , _masked_post_div_sdpa_script ),
113+ ("post_mul" , _masked_post_mul_sdpa_script ),
114+ ]
115+ )
116+ def test_sdpa_fusion (self , name , script_func ):
117+ test_case = SDPATestCase (script_func )
118+ model = test_case .get_onnx_model ()
69119 onnxscript .optimizer .optimize (model )
70120
71- # inputs = test .get_ort_inputs()
121+ # inputs = test_case .get_ort_inputs()
72122 # original_outputs = ort_run("original", model, inputs)
73123
74124 count = fuse_sdpa (model )
0 commit comments