@@ -709,7 +709,7 @@ def _prepare_vision_states(self, reference_image, target_resolution, latents, de
709709 torch.Tensor or None: Vision states tensor or None if vision encoder is unavailable.
710710 """
711711 if reference_image is None :
712- vision_states = torch .zeros (latents .shape [0 ], self .config .vision_num_semantic_tokens , self .config .vision_states_dim ). to ( latents .device )
712+ vision_states = torch .zeros (latents .shape [0 ], self .config .vision_num_semantic_tokens , self .config .vision_states_dim , device = latents .device , dtype = latents . dtype )
713713 else :
714714 reference_image = np .array (reference_image ) if isinstance (reference_image , Image .Image ) else reference_image
715715 if len (reference_image .shape ) == 4 :
@@ -753,11 +753,11 @@ def _prepare_cond_latents(self, task_type, cond_latents, latents, multitask_mask
753753 latents_concat = cond_latents .repeat (1 , 1 , latents .shape [2 ], 1 , 1 )
754754 latents_concat [:, :, 1 :, :, :] = 0.0
755755 else :
756- latents_concat = torch .zeros (latents . shape [ 0 ], latents . shape [ 1 ], latents . shape [ 2 ], latents . shape [ 3 ], latents . shape [ 4 ]). to ( latents . device )
756+ latents_concat = torch .zeros_like (latents )
757757
758- mask_zeros = torch .zeros (latents .shape [0 ], 1 , latents .shape [2 ], latents .shape [3 ], latents .shape [4 ])
759- mask_ones = torch .ones (latents .shape [0 ], 1 , latents .shape [2 ], latents .shape [3 ], latents .shape [4 ])
760- mask_concat = merge_tensor_by_mask (mask_zeros . cpu () , mask_ones . cpu () , mask = multitask_mask .cpu (), dim = 2 ). to (device = latents .device )
758+ mask_zeros = torch .zeros (latents .shape [0 ], 1 , latents .shape [2 ], latents .shape [3 ], latents .shape [4 ], device = latents . device )
759+ mask_ones = torch .ones (latents .shape [0 ], 1 , latents .shape [2 ], latents .shape [3 ], latents .shape [4 ], device = latents . device )
760+ mask_concat = merge_tensor_by_mask (mask_zeros , mask_ones , mask = multitask_mask .to (device = latents .device ), dim = 2 )
761761
762762 cond_latents = torch .concat ([latents_concat , mask_concat ], dim = 1 )
763763
@@ -1203,23 +1203,23 @@ def __call__(
12031203
12041204 latent_model_input = self .scheduler .scale_model_input (latent_model_input , t )
12051205
1206- t_expand = t .repeat (latent_model_input .shape [0 ])
1206+ t_expand = t .expand (latent_model_input .shape [0 ])
12071207 if self .use_meanflow :
12081208 if i == len (timesteps ) - 1 :
12091209 timesteps_r = torch .tensor ([0.0 ], device = self .execution_device )
12101210 else :
12111211 timesteps_r = timesteps [i + 1 ]
1212- timesteps_r = timesteps_r .repeat (latent_model_input .shape [0 ])
1212+ timesteps_r = timesteps_r .expand (latent_model_input .shape [0 ])
12131213 else :
12141214 timesteps_r = None
12151215
12161216 guidance_expand = (
1217- torch .tensor (
1218- [embedded_guidance_scale ] * latent_model_input .shape [0 ],
1219- dtype = torch .float32 ,
1217+ torch .full (
1218+ (latent_model_input .shape [0 ],),
1219+ embedded_guidance_scale * 1000.0 ,
1220+ dtype = self .target_dtype ,
12201221 device = device ,
1221- ).to (self .target_dtype )
1222- * 1000.0
1222+ )
12231223 if embedded_guidance_scale is not None
12241224 else None
12251225 )
0 commit comments