@@ -101,9 +101,28 @@ def attention(
     if mode == "torch":
         if attn_mask is not None and attn_mask.dtype != torch.bool:
             attn_mask = attn_mask.to(q.dtype)
-        x = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
-        )
+        if cu_seqlens_q is None:
+            x = F.scaled_dot_product_attention(
+                q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
+            )
+        else:
+            attn1 = F.scaled_dot_product_attention(
+                q[:, :, :cu_seqlens_q[1]],
+                k[:, :, :cu_seqlens_kv[1]],
+                v[:, :, :cu_seqlens_kv[1]],
+                attn_mask=attn_mask,
+                dropout_p=drop_rate,
+                is_causal=causal
+            )
+            attn2 = F.scaled_dot_product_attention(
+                q[:, :, cu_seqlens_q[1]:],
+                k[:, :, cu_seqlens_kv[1]:],
+                v[:, :, cu_seqlens_kv[1]:],
+                attn_mask=None,
+                dropout_p=drop_rate,
+                is_causal=False
+            )
+            x = torch.cat([attn1, attn2], dim=2)
     elif mode == "flash":
         x = flash_attn_varlen_func(
             q,
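The added torch branch splits the packed sequence at cu_seqlens_q[1] and runs SDPA on each segment separately, then re-joins the outputs along the sequence dimension, which reproduces variable-length (varlen) attention semantics for the two-segment case: queries in one segment never attend to keys in the other. A minimal, self-contained sketch of that idea, assuming a [B, H, S, D] layout and exactly two packed segments; all names and sizes below are illustrative, and masking/dropout are omitted for brevity:

    import torch
    import torch.nn.functional as F

    B, H, D = 1, 4, 64
    len_a, len_b = 8, 5                                   # two packed segments (illustrative sizes)
    cu_seqlens = torch.tensor([0, len_a, len_a + len_b])  # cumulative boundaries, as in cu_seqlens_q

    q = torch.randn(B, H, len_a + len_b, D)
    k = torch.randn(B, H, len_a + len_b, D)
    v = torch.randn(B, H, len_a + len_b, D)

    split = int(cu_seqlens[1])
    # Attend within each segment independently: tokens in one segment never see the other.
    out_a = F.scaled_dot_product_attention(q[:, :, :split], k[:, :, :split], v[:, :, :split])
    out_b = F.scaled_dot_product_attention(q[:, :, split:], k[:, :, split:], v[:, :, split:])
    out = torch.cat([out_a, out_b], dim=2)                # re-join along the sequence dimension
    assert out.shape == (B, H, len_a + len_b, D)

Note that the diff additionally forwards attn_mask and is_causal to the first segment only; the sketch drops both to keep the boundary logic in focus.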