Skip to content

Commit 4fdf87f

Browse files
authored
Merge pull request #276 from sunyanguomt/sdpa_fix
Bugfix:Resolving image blur issues with SDPA
2 parents efabefb + 71fd817 commit 4fdf87f

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

hyvideo/modules/attenion.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,28 @@ def attention(
101101
if mode == "torch":
102102
if attn_mask is not None and attn_mask.dtype != torch.bool:
103103
attn_mask = attn_mask.to(q.dtype)
104-
x = F.scaled_dot_product_attention(
105-
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
106-
)
104+
if cu_seqlens_q is None:
105+
x = F.scaled_dot_product_attention(
106+
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
107+
)
108+
else:
109+
attn1 = F.scaled_dot_product_attention(
110+
q[:, :, :cu_seqlens_q[1]],
111+
k[:, :, :cu_seqlens_kv[1]],
112+
v[:, :, :cu_seqlens_kv[1]],
113+
attn_mask=attn_mask,
114+
dropout_p=drop_rate,
115+
is_causal=causal
116+
)
117+
attn2 = F.scaled_dot_product_attention(
118+
q[:, :, cu_seqlens_q[1]:],
119+
k[:, :, cu_seqlens_kv[1]:],
120+
v[:, :, cu_seqlens_kv[1]:],
121+
attn_mask=None,
122+
dropout_p=drop_rate,
123+
is_causal=False
124+
)
125+
x = torch.cat([attn1, attn2], dim=2)
107126
elif mode == "flash":
108127
x = flash_attn_varlen_func(
109128
q,

0 commit comments

Comments
 (0)