Skip to content

Commit b9cbe29

Browse files
committed
restrict data range
1 parent d3c43b0 commit b9cbe29

File tree

3 files changed

+15
-11
lines changed

3 files changed

+15
-11
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,7 @@ Replace the `create_dummy_dataloader()` function in `train.py` with your own imp
431431

432432
- **Required fields:**
433433
- `"pixel_values"`: `torch.Tensor` - Video: `[B, C, F, H, W]` or Image: `[B, C, H, W]`
434+
- Pixel values must be in range `[-1, 1]`
434435
- Note: For video data, temporal dimension F must be `4n+1` (e.g., 1, 5, 9, 13, 17, ...)
435436
- `"text"`: `List[str]` - Text prompts for each sample
436437
- `"data_type"`: `str` - `"video"` or `"image"`

README_CN.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,7 @@ HunyuanVideo-1.5 使用 **Muon 优化器**进行训练,该优化器能够加
430430

431431
- **必需字段:**
432432
- `"pixel_values"`: `torch.Tensor` - 视频:`[B, C, F, H, W]` 或图像:`[B, C, H, W]`
433+
- 像素值必须在 `[-1, 1]` 范围内
433434
- 注意:对于视频数据,时间维度 F 必须是 `4n+1`(例如:1, 5, 9, 13, 17, ...)
434435
- `"text"`: `List[str]` - 每个样本的文本提示词
435436
- `"data_type"`: `str` - `"video"` 或 `"image"`

train.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
- Replace the `create_dummy_dataloader()` function with your own implementation
2525
- Your dataloader should return batches with the following format:
2626
* "pixel_values": torch.Tensor - Video: [B, C, F, H, W] or Image: [B, C, H, W]
27+
Pixel values must be in range [-1, 1]
2728
Note: For video data, temporal dimension F must be 4n+1 (e.g., 1, 5, 9, 13, 17, ...)
2829
* "text": List[str] - Text prompts for each sample
2930
* "data_type": str - "video" or "image"
@@ -564,16 +565,15 @@ def encode_images(self, images):
564565
"""Encode images to vision states (for i2v)"""
565566
if self.vision_encoder is None:
566567
return None
567-
if isinstance(images, torch.Tensor):
568-
images_np = (images.cpu().permute(0, 2, 3, 1).numpy() * 255).astype("uint8")
569-
else:
570-
images_np = images
568+
assert images.max() <= 1.0 and images.min() >= -1.0, f"Images must be in the range [-1, 1], but got {images.min()} {images.max()}"
569+
images = (images + 1) / 2 # [-1, 1] -> [0, 1]
570+
images_np = (images.cpu().permute(0, 2, 3, 1).numpy() * 255).clip(0, 255).astype("uint8")
571571
vision_states = self.vision_encoder.encode_images(images_np)
572572
return vision_states.last_hidden_state.to(device=self.device, dtype=self.transformer.dtype)
573573

574574
def encode_vae(self, images: torch.Tensor) -> torch.Tensor:
575-
if images.max() > 1.0:
576-
images = images / 255.0
575+
if images.max() > 1.0 or images.min() < -1.0:
576+
raise ValueError(f"Images must be in the range [-1, 1], but got {images.min()} {images.max()}")
577577

578578
if images.ndim == 4:
579579
images = images.unsqueeze(2)
@@ -623,7 +623,8 @@ def prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
623623
624624
Expected batch format:
625625
{
626-
"pixel_values": torch.Tensor, # [B, C, F, H, W] for video or [B, C, H, W] for image
626+
"pixel_values": torch.Tensor, # [B, C, F, H, W] for video or [B, C, H, W] for image
627+
# Pixel values must be in range [-1, 1]
627628
"text": List[str],
628629
"data_type": str, # "image" or "video"
629630
"byt5_text_ids": Optional[torch.Tensor],
@@ -674,7 +675,8 @@ def prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
674675
byt5_text_mask = torch.cat(byt5_mask_list, dim=0)
675676

676677
vision_states = None
677-
if task_type == "i2v" and images is not None:
678+
if task_type == "i2v":
679+
assert images is not None, '`pixel_values` must be provided for i2v task'
678680
if images.ndim == 5:
679681
first_frame = images[:, :, 0, :, :]
680682
else:
@@ -1003,8 +1005,7 @@ def create_dummy_dataloader(config: TrainingConfig):
10031005
- "pixel_values": torch.Tensor
10041006
* For video: shape [B, C, F, H, W] where F is the number of frames
10051007
* For image: shape [B, C, H, W] (will be automatically expanded to [B, C, 1, H, W])
1006-
* Pixel values should be in range [0, 255] (will be normalized to [0, 1] then [-1, 1])
1007-
* Or in range [0, 1] (will be normalized to [-1, 1])
1008+
* Pixel values must be in range [-1, 1]
10081009
* Data type: torch.float32
10091010
* Note: For video data, temporal dimension F must be 4n+1 (e.g., 1, 5, 9, 13, 17, 21, ...)
10101011
to satisfy VAE requirements. The dataset should ensure this before returning data.
@@ -1065,7 +1066,8 @@ def __len__(self):
10651066

10661067
def __getitem__(self, idx):
10671068
# Video: temporal dimension must be 4n+1, using 17 frames
1068-
data = torch.randn(3, 17, 64, 64)
1069+
# Generate data in range [-1, 1]
1070+
data = torch.rand(3, 17, 64, 64) * 2.0 - 1.0 # [0, 1] -> [-1, 1]
10691071
data_type = "video"
10701072

10711073
return {

0 commit comments

Comments (0)