Skip to content

Commit e64e01a

Browse files
committed
restrict data range
1 parent d3c43b0 commit e64e01a

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

train.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -564,16 +564,15 @@ def encode_images(self, images):
564564
"""Encode images to vision states (for i2v)"""
565565
if self.vision_encoder is None:
566566
return None
567-
if isinstance(images, torch.Tensor):
568-
images_np = (images.cpu().permute(0, 2, 3, 1).numpy() * 255).astype("uint8")
569-
else:
570-
images_np = images
567+
assert images.max() <= 1.0 and images.min() >= -1.0, f"Images must be in the range [-1, 1], but got {images.min()} {images.max()}"
568+
images = (images + 1) / 2 # [-1, 1] -> [0, 1]
569+
images_np = (images.cpu().permute(0, 2, 3, 1).numpy() * 255).clip(0, 255).astype("uint8")
571570
vision_states = self.vision_encoder.encode_images(images_np)
572571
return vision_states.last_hidden_state.to(device=self.device, dtype=self.transformer.dtype)
573572

574573
def encode_vae(self, images: torch.Tensor) -> torch.Tensor:
575-
if images.max() > 1.0:
576-
images = images / 255.0
574+
if images.max() > 1.0 or images.min() < -1.0:
575+
raise ValueError(f"Images must be in the range [-1, 1], but got {images.min()} {images.max()}")
577576

578577
if images.ndim == 4:
579578
images = images.unsqueeze(2)
@@ -674,7 +673,8 @@ def prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
674673
byt5_text_mask = torch.cat(byt5_mask_list, dim=0)
675674

676675
vision_states = None
677-
if task_type == "i2v" and images is not None:
676+
if task_type == "i2v":
677+
assert images is not None, '`pixel_values` must be provided for i2v task'
678678
if images.ndim == 5:
679679
first_frame = images[:, :, 0, :, :]
680680
else:

0 commit comments

Comments
 (0)