updates.patch
57 lines (50 loc) · 3.29 KB
diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py
index a895340bd..688d109e0 100644
--- a/src/diffusers/models/transformers/transformer_ltx.py
+++ b/src/diffusers/models/transformers/transformer_ltx.py
@@ -260,7 +260,7 @@ class LTXVideoTransformerBlock(nn.Module):
         hidden_states = hidden_states + attn_hidden_states
 
         norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
-        ff_output = self.ff(norm_hidden_states)
+        ff_output = self.ff(norm_hidden_states).to(norm_hidden_states.dtype)
         hidden_states = hidden_states + ff_output * gate_mlp
 
         return hidden_states
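
The cast added above guards against the feed-forward returning activations in a dtype other than the residual stream's — plausible once self.ff is backed by a quantized kernel that accumulates in float32 while the model runs in bfloat16; without the cast, the residual add would silently upcast the whole stream. A minimal sketch of the mismatch, using a hypothetical QuantLinear as a stand-in for the real Q8Linear:

import torch
import torch.nn as nn

class QuantLinear(nn.Module):
    """Stand-in for a quantized linear that returns its accumulation dtype."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x):
        # Quantized kernels often return float32, not x.dtype.
        return x.float() @ self.weight.t().float()

ff = QuantLinear(8, 8)
x = torch.randn(2, 8, dtype=torch.bfloat16)
out = ff(x)               # float32, not bfloat16
out = out.to(x.dtype)     # the patch's fix: cast back before the residual add
assert out.dtype == torch.bfloat16
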
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py
index 96d41bb32..6a6e2f761 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py
@@ -30,6 +30,11 @@ from ...video_processor import VideoProcessor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import LTXPipelineOutput
 
+try:
+    import q8_kernels  # noqa
+    from q8_kernels.modules.linear import Q8Linear
+except ImportError:
+    Q8Linear = None
 
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -186,11 +191,11 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
             scheduler=scheduler,
         )
 
-        self.vae_spatial_compression_ratio = self.vae.spatial_compression_ratio if hasattr(self, "vae") else 32
-        self.vae_temporal_compression_ratio = self.vae.temporal_compression_ratio if hasattr(self, "vae") else 8
-        self.transformer_spatial_patch_size = self.transformer.config.patch_size if hasattr(self, "transformer") else 1
+        self.vae_spatial_compression_ratio = self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
+        self.vae_temporal_compression_ratio = self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
+        self.transformer_spatial_patch_size = self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
         self.transformer_temporal_patch_size = (
-            self.transformer.config.patch_size_t if hasattr(self, "transformer") else 1
+            self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1
         )
 
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
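
The hasattr-to-getattr change in this hunk fixes a subtle bug rather than style: DiffusionPipeline.register_modules assigns every registered module as an instance attribute even when it is None, so hasattr(self, "vae") is true for a pipeline built without a VAE and the fallbacks (32, 8, 1) were unreachable. A toy illustration of the difference:

class Pipeline:
    def __init__(self, vae=None):
        self.vae = vae  # set unconditionally, mirroring register_modules

p = Pipeline(vae=None)
print(hasattr(p, "vae"))                    # True, despite the missing VAE
print(getattr(p, "vae", None) is not None)  # False, the intended check
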
@@ -649,6 +654,11 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
 
+        if Q8Linear is not None and isinstance(self.transformer.transformer_blocks[0].attn1.to_q, Q8Linear):
+            prompt_attention_mask = prompt_attention_mask.to(torch.int64)
+            prompt_attention_mask = prompt_attention_mask.argmin(-1).int().squeeze()
+            prompt_attention_mask[prompt_attention_mask == 0] = max_sequence_length
+
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
         latents = self.prepare_latents(
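
This last hunk converts the 2-D padding mask into per-sequence lengths when the transformer's projections are Q8Linear, which evidently expects lengths rather than a boolean mask: argmin(-1) finds the first zero in each right-padded row, and rows with no padding (where argmin returns 0) are reset to max_sequence_length. A small sketch of the conversion, assuming the right-padded 0/1 masks the T5 tokenizer produces:

import torch

max_sequence_length = 8
mask = torch.tensor([
    [1, 1, 1, 0, 0, 0, 0, 0],  # 3 real tokens, then padding
    [1, 1, 1, 1, 1, 1, 1, 1],  # no padding at all
])

lengths = mask.to(torch.int64).argmin(-1).int().squeeze()
lengths[lengths == 0] = max_sequence_length  # all-ones rows argmin to 0
print(lengths)  # tensor([3, 8], dtype=torch.int32)

To try the patch against a diffusers checkout, git apply updates.patch from the repository root should be all that is needed.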