@@ -564,16 +564,15 @@ def encode_images(self, images):
564564 """Encode images to vision states (for i2v)"""
565565 if self .vision_encoder is None :
566566 return None
567- if isinstance (images , torch .Tensor ):
568- images_np = (images .cpu ().permute (0 , 2 , 3 , 1 ).numpy () * 255 ).astype ("uint8" )
569- else :
570- images_np = images
567+ assert images .max () <= 1.0 and images .min () >= - 1.0 , f"Images must be in the range [-1, 1], but got { images .min ()} { images .max ()} "
568+ images = (images + 1 ) / 2 # [-1, 1] -> [0, 1]
569+ images_np = (images .cpu ().permute (0 , 2 , 3 , 1 ).numpy () * 255 ).clip (0 , 255 ).astype ("uint8" )
571570 vision_states = self .vision_encoder .encode_images (images_np )
572571 return vision_states .last_hidden_state .to (device = self .device , dtype = self .transformer .dtype )
573572
574573 def encode_vae (self , images : torch .Tensor ) -> torch .Tensor :
575- if images .max () > 1.0 :
576- images = images / 255.0
574+ if images .max () > 1.0 or images . min () < - 1.0 :
575+ raise ValueError ( f"Images must be in the range [-1, 1], but got { images . min () } { images . max () } " )
577576
578577 if images .ndim == 4 :
579578 images = images .unsqueeze (2 )
@@ -674,7 +673,8 @@ def prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
674673 byt5_text_mask = torch .cat (byt5_mask_list , dim = 0 )
675674
676675 vision_states = None
677- if task_type == "i2v" and images is not None :
676+ if task_type == "i2v" :
677+ assert images is not None , '`pixel_values` must be provided for i2v task'
678678 if images .ndim == 5 :
679679 first_frame = images [:, :, 0 , :, :]
680680 else :
0 commit comments