@@ -252,6 +252,8 @@ def prepare_latents(
252252 f"You have passed a list of generators of length { len (generator )} , but requested an effective batch"
253253 f" size of { batch_size } . Make sure the batch size matches the length of the generators."
254254 )
255+ if latents is not None :
256+ latents = latents .to (device )
255257 noise = randn_tensor (shape , generator = generator , device = device , dtype = self .base_dtype )
256258
257259 if freenoise :
@@ -310,11 +312,13 @@ def prepare_latents(
310312
311313 if frames_needed > current_frames :
312314 repeat_factor = frames_needed - current_frames
313- additional_frame = torch .randn ((latents .size ( 0 ), repeat_factor , latents .size ( 2 ), latents .size ( 3 ) , latents .size ( 4 ) ), dtype = latents .dtype , device = latents .device )
314- latents = torch .cat ((additional_frame , latents ), dim = 1 )
315- self . additional_frames = repeat_factor
315+ additional_frame = torch .randn ((latents .shape [ 0 ], latents .shape [ 1 ], repeat_factor , latents .shape [ 3 ] , latents .shape [ 4 ] ), dtype = latents .dtype , device = latents .device )
316+ latents = torch .cat ((additional_frame , latents ), dim = 2 )
317+ logger . info ( f"Frames needed more than current frames, adding { repeat_factor } frames" )
316318 elif frames_needed < current_frames :
317- latents = latents [:, :frames_needed , :, :, :]
319+ latents = latents [:, :, :frames_needed , :, :]
320+ logger .info (f"Frames needed less than current frames, cutting down to { frames_needed } " )
321+
318322 latents = latents * (1 - latent_timestep / 1000 ) + latent_timestep / 1000 * noise
319323 print ("latents shape:" , latents .shape )
320324 elif image_cond_latents is not None and i2v_stability :
@@ -734,7 +738,7 @@ def __call__(
734738 t_expand = t .repeat (latent_model_input .shape [0 ])
735739
736740 if leapfusion_img2vid :
737- latent_model_input [:, :, [0 , ], :, :] = original_latents [:, :, [0 , ], :, :].to (latent_model_input )
741+ latent_model_input [:, :, [0 ], :, :] = original_latents [:, :, [0 ], :, :].to (latent_model_input )
738742
739743 if image_cond_latents is not None and not use_context_schedule :
740744 if i2v_condition_type == "latent_concat" :
0 commit comments