@@ -87,19 +87,17 @@ def pattern(
8787 shape_B111 ,
8888 ):
8989 # Reshape query from (B, S, D) to (B, S, H, D/H)
90- query_BSHDh = op .Reshape (query_BSD , _allow_other_inputs = True , _outputs = ["query_BSHDh" ])
90+ query_BSHDh = op .Reshape (query_BSD , pattern . ANY_VALUE , _outputs = ["query_BSHDh" ])
9191 # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
9292 query_BHSDh = op .Transpose (query_BSHDh , perm = [0 , 2 , 1 , 3 ])
9393
9494 # Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H)
95- key_BSHkvDh = op .Reshape (key_BSDkv , _allow_other_inputs = True , _outputs = ["key_BSHkvDh" ])
95+ key_BSHkvDh = op .Reshape (key_BSDkv , pattern . ANY_VALUE , _outputs = ["key_BSHkvDh" ])
9696 # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
9797 key_BHkvSDh = op .Transpose (key_BSHkvDh , perm = [0 , 2 , 1 , 3 ])
9898
9999 # Reshape value from (B, S, Dkv) to (B, S, Hkv, D/H)
100- value_BSHkvDh = op .Reshape (
101- value_BSDkv , _allow_other_inputs = True , _outputs = ["value_BSHkvDh" ]
102- )
100+ value_BSHkvDh = op .Reshape (value_BSDkv , pattern .ANY_VALUE , _outputs = ["value_BSHkvDh" ])
103101 # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
104102 value_BHkvSDh = op .Transpose (value_BSHkvDh , perm = [0 , 2 , 1 , 3 ])
105103
@@ -129,18 +127,18 @@ def pattern(
129127
130128 key_seq_BHkvTDh = op .Concat (past_key , key_BHkvSDh_rope , axis = - 2 )
131129 key_seq_BHkv1TDh = op .Unsqueeze (key_seq_BHkvTDh , 2 )
132- key_seq_BHkvGTDh = op .Expand (key_seq_BHkv1TDh , _allow_other_inputs = True )
130+ key_seq_BHkvGTDh = op .Expand (key_seq_BHkv1TDh , pattern . ANY_VALUE )
133131 key_seq_BHTDh = op .Reshape (
134- key_seq_BHkvGTDh , _allow_other_inputs = True , _outputs = ["key_seq_BHTDh" ]
132+ key_seq_BHkvGTDh , pattern . ANY_VALUE , _outputs = ["key_seq_BHTDh" ]
135133 )
136134
137135 # Concatenate past_value cache and current value, expand across heads
138136 # that share key/value.
139137 value_seq_BHkvTDh = op .Concat (past_value , value_BHkvSDh , axis = - 2 )
140138 value_seq_BHkv1TDh = op .Unsqueeze (value_seq_BHkvTDh , 2 )
141- value_seq_BHkvGTDh = op .Expand (value_seq_BHkv1TDh , _allow_other_inputs = True )
139+ value_seq_BHkvGTDh = op .Expand (value_seq_BHkv1TDh , pattern . ANY_VALUE )
142140 value_seq_BHTDh = op .Reshape (
143- value_seq_BHkvGTDh , _allow_other_inputs = True , _outputs = ["value_seq_BHTDh" ]
141+ value_seq_BHkvGTDh , pattern . ANY_VALUE , _outputs = ["value_seq_BHTDh" ]
144142 )
145143
146144 mask = causal_mask_pattern (op , input_ids , some_kv_cache , shape_B111 )
@@ -158,7 +156,7 @@ def pattern(
158156 attention_BSHDh = op .Transpose (attention_BHSDh , perm = [0 , 2 , 1 , 3 ])
159157 # Reshape back to (B, S, D)
160158 attention_BSD = op .Reshape (
161- attention_BSHDh , _allow_other_inputs = True , _outputs = ["attention_BSD" ]
159+ attention_BSHDh , pattern . ANY_VALUE , _outputs = ["attention_BSD" ]
162160 )
163161 return attention_BSD , key_seq_BHkvTDh , value_seq_BHkvTDh
164162
0 commit comments