1515
1616
1717class FuseBiasMHA (pattern .RewriteRuleClassBase ):
18- def __init__ (
19- self ,
20- name ,
21- * ,
22- q_no_bias : bool ,
23- k_no_bias : bool ,
24- v_no_bias : bool ,
25- ):
26- super ().__init__ (name )
27- self ._q_no_bias = q_no_bias
28- self ._k_no_bias = k_no_bias
29- self ._v_no_bias = v_no_bias
30-
3118 def pattern (
3219 self ,
3320 op ,
@@ -43,18 +30,21 @@ def pattern(
4330 num_heads ,
4431 # scale,
4532 ):
46- if not self ._q_no_bias :
47- query_BSD = op .Add (query_matmul , q_bias )
48- else :
49- query_BSD = query_matmul
50- if not self ._k_no_bias :
51- key_BSD = op .Add (key_matmul , k_bias )
52- else :
53- key_BSD = key_matmul
54- if not self ._v_no_bias :
55- value_BSD = op .Add (value_matmul , v_bias )
56- else :
57- value_BSD = value_matmul
33+ query_BSD = pattern .OrValue (
34+ [op .Add (query_matmul , q_bias ), query_matmul ],
35+ tag_var = "has_q_bias" ,
36+ tag_values = [True , False ],
37+ )
38+ key_BSD = pattern .OrValue (
39+ [op .Add (key_matmul , k_bias ), key_matmul ],
40+ tag_var = "has_k_bias" ,
41+ tag_values = [True , False ],
42+ )
43+ value_BSD = pattern .OrValue (
44+ [op .Add (value_matmul , v_bias ), value_matmul ],
45+ tag_var = "has_v_bias" ,
46+ tag_values = [True , False ],
47+ )
5848
5949 return op .MultiHeadAttention (
6050 query_BSD ,
@@ -72,14 +62,20 @@ def pattern(
7262
7363 def check (
7464 self ,
75- op ,
65+ context ,
7666 query_matmul ,
7767 key_matmul ,
7868 value_matmul ,
69+ has_q_bias ,
70+ has_k_bias ,
71+ has_v_bias ,
7972 ** _ ,
8073 ) -> pattern .MatchResult : # type: ignore[name-defined]
8174 check_result = pattern .MatchResult ()
8275
76+ if not (has_q_bias or has_k_bias or has_v_bias ):
77+ return check_result .fail ("None of query, key, or value have a bias." )
78+
8379 self .bindings : dict [str , Dim ] = {}
8480
8581 def no_match (val : ir .Value , dims : Sequence [str ]) -> bool :
@@ -139,15 +135,15 @@ def rewrite(
139135 # scale,
140136 ** _ ,
141137 ):
142- if self . _q_no_bias :
138+ if q_bias is None :
143139 q_bias = op .Constant (
144140 value = ir .tensor (numpy .zeros ((self .Dh_q ,), dtype = query_matmul .dtype .numpy ()))
145141 )
146- if self . _k_no_bias :
142+ if k_bias is None :
147143 k_bias = op .Constant (
148144 value = ir .tensor (numpy .zeros ((self .Dh_k ,), dtype = key_matmul .dtype .numpy ()))
149145 )
150- if self . _v_no_bias :
146+ if v_bias is None :
151147 v_bias = op .Constant (
152148 value = ir .tensor (numpy .zeros ((self .Dh_v ,), dtype = value_matmul .dtype .numpy ()))
153149 )
@@ -167,30 +163,7 @@ def rewrite(
167163 )
168164
169165
170- parameter_combinations = [
171- {
172- "q_no_bias" : q_no_bias ,
173- "k_no_bias" : k_no_bias ,
174- "v_no_bias" : v_no_bias ,
175- }
176- for q_no_bias in [False , True ]
177- for k_no_bias in [False , True ]
178- for v_no_bias in [False , True ]
179- ]
180-
181- # Dynamically create the rules
182- fuse_mha_bias_rules = pattern .RewriteRuleSet (
183- [
184- FuseBiasMHA .rule (
185- f"MHABias{ '_NoQBias' if params ['q_no_bias' ] else '' } "
186- f"{ '_NoKBias' if params ['k_no_bias' ] else '' } "
187- f"{ '_NoVBias' if params ['v_no_bias' ] else '' } " ,
188- ** params ,
189- )
190- # Exclude (True, True, True) as it is an unnecessary case
191- for params in parameter_combinations [:- 1 ]
192- ]
193- )
166+ fuse_mha_bias_rules = pattern .RewriteRuleSet ([FuseBiasMHA .rule ()])
194167
195168
196169fuse_mha_bias = _fusion_utils .apply_fusion_rules (fuse_mha_bias_rules )
0 commit comments