|
24 | 24 | - Replace the `create_dummy_dataloader()` function with your own implementation |
25 | 25 | - Your dataloader should return batches with the following format: |
26 | 26 | * "pixel_values": torch.Tensor - Video: [B, C, F, H, W] or Image: [B, C, H, W] |
| 27 | + Pixel values must be in range [-1, 1] |
27 | 28 | Note: For video data, temporal dimension F must be 4n+1 (e.g., 1, 5, 9, 13, 17, ...) |
28 | 29 | * "text": List[str] - Text prompts for each sample |
29 | 30 | * "data_type": str - "video" or "image" |
@@ -564,16 +565,15 @@ def encode_images(self, images): |
564 | 565 | """Encode images to vision states (for i2v)""" |
565 | 566 | if self.vision_encoder is None: |
566 | 567 | return None |
567 | | - if isinstance(images, torch.Tensor): |
568 | | - images_np = (images.cpu().permute(0, 2, 3, 1).numpy() * 255).astype("uint8") |
569 | | - else: |
570 | | - images_np = images |
| 568 | + assert images.max() <= 1.0 and images.min() >= -1.0, f"Images must be in the range [-1, 1], but got {images.min()} {images.max()}" |
| 569 | + images = (images + 1) / 2 # [-1, 1] -> [0, 1] |
| 570 | + images_np = (images.cpu().permute(0, 2, 3, 1).numpy() * 255).clip(0, 255).astype("uint8") |
571 | 571 | vision_states = self.vision_encoder.encode_images(images_np) |
572 | 572 | return vision_states.last_hidden_state.to(device=self.device, dtype=self.transformer.dtype) |
573 | 573 |
|
574 | 574 | def encode_vae(self, images: torch.Tensor) -> torch.Tensor: |
575 | | - if images.max() > 1.0: |
576 | | - images = images / 255.0 |
| 575 | + if images.max() > 1.0 or images.min() < -1.0: |
| 576 | + raise ValueError(f"Images must be in the range [-1, 1], but got {images.min()} {images.max()}") |
577 | 577 |
|
578 | 578 | if images.ndim == 4: |
579 | 579 | images = images.unsqueeze(2) |
@@ -623,7 +623,8 @@ def prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]: |
623 | 623 | |
624 | 624 | Expected batch format: |
625 | 625 | { |
626 | | - "pixel_values": torch.Tensor, # [B, C, F, H, W] for video or [B, C, H, W] for image |
| 626 | + "pixel_values": torch.Tensor, # [B, C, F, H, W] for video or [B, C, H, W] for image |
| 627 | + # Pixel values must be in range [-1, 1] |
627 | 628 | "text": List[str], |
628 | 629 | "data_type": str, # "image" or "video" |
629 | 630 | "byt5_text_ids": Optional[torch.Tensor], |
@@ -674,7 +675,8 @@ def prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]: |
674 | 675 | byt5_text_mask = torch.cat(byt5_mask_list, dim=0) |
675 | 676 |
|
676 | 677 | vision_states = None |
677 | | - if task_type == "i2v" and images is not None: |
| 678 | + if task_type == "i2v": |
| 679 | + assert images is not None, '`pixel_values` must be provided for i2v task' |
678 | 680 | if images.ndim == 5: |
679 | 681 | first_frame = images[:, :, 0, :, :] |
680 | 682 | else: |
@@ -1003,8 +1005,7 @@ def create_dummy_dataloader(config: TrainingConfig): |
1003 | 1005 | - "pixel_values": torch.Tensor |
1004 | 1006 | * For video: shape [B, C, F, H, W] where F is the number of frames |
1005 | 1007 | * For image: shape [B, C, H, W] (will be automatically expanded to [B, C, 1, H, W]) |
1006 | | - * Pixel values should be in range [0, 255] (will be normalized to [0, 1] then [-1, 1]) |
1007 | | - * Or in range [0, 1] (will be normalized to [-1, 1]) |
| 1008 | + * Pixel values must be in range [-1, 1] |
1008 | 1009 | * Data type: torch.float32 |
1009 | 1010 | * Note: For video data, temporal dimension F must be 4n+1 (e.g., 1, 5, 9, 13, 17, 21, ...) |
1010 | 1011 | to satisfy VAE requirements. The dataset should ensure this before returning data. |
@@ -1065,7 +1066,8 @@ def __len__(self): |
1065 | 1066 |
|
1066 | 1067 | def __getitem__(self, idx): |
1067 | 1068 | # Video: temporal dimension must be 4n+1, using 17 frames |
1068 | | - data = torch.randn(3, 17, 64, 64) |
| 1069 | + # Generate data in range [-1, 1] |
| 1070 | + data = torch.rand(3, 17, 64, 64) * 2.0 - 1.0 # [0, 1] -> [-1, 1] |
1069 | 1071 | data_type = "video" |
1070 | 1072 |
|
1071 | 1073 | return { |
|
0 commit comments