Skip to content

Commit a20d994

Browse files
committed
Fix leapfusion and V2V
1 parent 3c230c7 commit a20d994

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py

Lines changed: 9 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -252,6 +252,8 @@ def prepare_latents(
252252
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
253253
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
254254
)
255+
if latents is not None:
256+
latents = latents.to(device)
255257
noise = randn_tensor(shape, generator=generator, device=device, dtype=self.base_dtype)
256258

257259
if freenoise:
@@ -310,11 +312,13 @@ def prepare_latents(
310312

311313
if frames_needed > current_frames:
312314
repeat_factor = frames_needed - current_frames
313-
additional_frame = torch.randn((latents.size(0), repeat_factor, latents.size(2), latents.size(3), latents.size(4)), dtype=latents.dtype, device=latents.device)
314-
latents = torch.cat((additional_frame, latents), dim=1)
315-
self.additional_frames = repeat_factor
315+
additional_frame = torch.randn((latents.shape[0], latents.shape[1], repeat_factor, latents.shape[3], latents.shape[4]), dtype=latents.dtype, device=latents.device)
316+
latents = torch.cat((additional_frame, latents), dim=2)
317+
logger.info(f"Frames needed more than current frames, adding {repeat_factor} frames")
316318
elif frames_needed < current_frames:
317-
latents = latents[:, :frames_needed, :, :, :]
319+
latents = latents[:, :, :frames_needed, :, :]
320+
logger.info(f"Frames needed less than current frames, cutting down to {frames_needed}")
321+
318322
latents = latents * (1 - latent_timestep / 1000) + latent_timestep / 1000 * noise
319323
print("latents shape:", latents.shape)
320324
elif image_cond_latents is not None and i2v_stability:
@@ -734,7 +738,7 @@ def __call__(
734738
t_expand = t.repeat(latent_model_input.shape[0])
735739

736740
if leapfusion_img2vid:
737-
latent_model_input[:, :, [0,], :, :] = original_latents[:, :, [0,], :, :].to(latent_model_input)
741+
latent_model_input[:, :, [0], :, :] = original_latents[:, :, [0], :, :].to(latent_model_input)
738742

739743
if image_cond_latents is not None and not use_context_schedule:
740744
if i2v_condition_type == "latent_concat":

nodes.py

Lines changed: 3 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -322,11 +322,11 @@ def loadmodel(self, model, base_precision, load_device, quantization,
322322
sd = load_torch_file(model_path, device=transformer_load_device, safe_load=True)
323323

324324
in_channels = sd["img_in.proj.weight"].shape[1]
325-
print("In channels: ", in_channels)
326-
if in_channels == 16:
325+
if in_channels == 16 and "i2v" in model.lower():
327326
i2v_condition_type = "token_replace"
328-
elif in_channels == 33 or in_channels == 32:
327+
else:
329328
i2v_condition_type = "latent_concat"
329+
log.info(f"Condition type: {i2v_condition_type}")
330330

331331
guidance_embed = sd.get("guidance_in.mlp.0.weight", False) is not False
332332

0 commit comments

Comments (0)