Conversation
…te_log_prob_for_timestep
Co-authored-by: Cursor <cursoragent@cursor.com>
Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 PR merge requirements — Waiting for:
This rule is failing.
Pre-commit checks failed

Hi @shijiew555, the pre-commit checks have failed. To fix them locally:

```shell
# Install pre-commit if you haven't already
uv pip install pre-commit
pre-commit install
# Run all checks and auto-fix what's possible
pre-commit run --all-files
```

Common fixes:

After fixing, commit and push the changes. The checks will re-run automatically. For future commits,
Code Review
This pull request introduces a reinforcement learning (RL) training pipeline for FastVideo, specifically implementing GRPO (Group Relative Policy Optimization) for video generation models. Key additions include a new RL pipeline structure, reward model infrastructure (with an OCR-based reward), and dataset handling for RL prompts. Several issues were identified, including a critical security vulnerability involving a hardcoded W&B API key, hardcoded validation parameters, and inconsistent configuration documentation. I have provided feedback to address these issues and improve the robustness of the implementation.
```shell
export WANDB_BASE_URL="https://api.wandb.ai"
export WANDB_MODE=online
export WANDB_API_KEY="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"
```
A hardcoded W&B API key has been committed. This is a critical security vulnerability. API keys and other secrets should never be hardcoded in the source code. Please remove the key and use a secure method for providing credentials, such as environment variables or a secrets management system.
Suggested change:

```diff
- export WANDB_API_KEY="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"
+ export WANDB_API_KEY="${WANDB_API_KEY}" # Key should be provided via environment variable
```
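A minimal sketch of the environment-variable approach, for illustration (the fallback value here is hypothetical, not from the PR):

```shell
# Read the key from the surrounding environment (e.g. a CI secret store or
# the developer's shell profile) instead of committing it to the repository.
WANDB_API_KEY="${WANDB_API_KEY:-placeholder-for-local-testing}"
export WANDB_API_KEY
```

Alternatively, running `wandb login` once per machine stores the credential outside the repository entirely.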
```python
    Returns:
        Reward tensor [B] with averaged OCR similarity scores across frames
    """
    prompts = [prompt.split('"')[1] for prompt in prompts]
```
The prompt parsing logic `prompt.split('"')[1]` is brittle. It assumes every prompt contains exactly one pair of double quotes and will raise an `IndexError` if a prompt does not follow this format. This could crash the reward computation. Consider using a more robust method, like regular expressions, to extract the quoted text, and include error handling for prompts that don't match the expected format.
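A hedged sketch of what such an extraction could look like (the helper name and fallback behavior are illustrative, not from the PR):

```python
import re
import logging

logger = logging.getLogger(__name__)

def extract_quoted_text(prompt: str) -> str:
    """Extract the first double-quoted span from a prompt.

    Falls back to returning the full prompt when no quoted text is
    present, instead of raising IndexError like split('"')[1] would.
    """
    match = re.search(r'"([^"]*)"', prompt)
    if match is None:
        logger.warning("No quoted text found in prompt: %r", prompt)
        return prompt
    return match.group(1)

prompts = ['Render the text "HELLO" on a sign', 'a prompt with no quotes']
print([extract_quoted_text(p) for p in prompts])
# → ['HELLO', 'a prompt with no quotes']
```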
```python
sampling_param.height = 480  # training_args.num_height
sampling_param.width = 832  # training_args.num_width
```
The height and width for the validation batch are hardcoded to 480 and 832, respectively. The commented-out code suggests these values should be taken from training_args. Hardcoding these values can lead to incorrect validation behavior if the training configuration changes. Please use the values from training_args as intended.
Suggested change:

```diff
- sampling_param.height = 480 #training_args.num_height
- sampling_param.width = 832 #training_args.num_width
+ sampling_param.height = training_args.num_height
+ sampling_param.width = training_args.num_width
```
```shell
NUM_GPUS=8
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
```
```python
# epoch = epoch*num_batches_per_epoch+i and skips the first 2 epochs, so first real batch
# uses seed + 2*num_batches_per_epoch; we use seed+step+4 so step 0,1,... matches that.
g = torch.Generator()
g.manual_seed(self.seed + self.step + 3)
```
The comment on lines 85-87 states that the seed is calculated as `seed + step + 4` to align with flow_grpo, but the code uses `self.seed + self.step + 3`. This discrepancy could lead to confusion and incorrect behavior if the comment is trusted. Please update either the code or the comment to ensure they are consistent.
Suggested change:

```diff
- g.manual_seed(self.seed + self.step + 3)
+ g.manual_seed(self.seed + self.step + 4)
```
```python
    test_num_workers: int = 8,
    num_replicas: int = 1,
    rank: int = 0,
) -> tuple[DataLoader, DataLoader]:
```
The return type hint for `build_rl_prompt_dataloader` is `-> tuple[DataLoader, DataLoader]`, but the function actually returns a tuple of five elements: `(train_dataloader, test_dataloader, train_dataset, test_dataset, train_sampler)`. Please update the type hint to match the implementation for clarity and correctness.
Suggested change:

```diff
- ) -> tuple[DataLoader, DataLoader]:
+ ) -> tuple[DataLoader, DataLoader, Dataset, Dataset, Sampler]:
```
```python
for f in TRAINING_BATCH_SAMPLE_TENSOR_FIELDS:
    t = getattr(batch, f)
    if t is not None:
        setattr(sub, f, t[s:e].clone() if t.is_cuda else t[s:e])
```
There's an inconsistency in how tensors are handled based on their device. CUDA tensors are cloned, while CPU tensors are not, creating a view. This can lead to unexpected side effects if the sliced CPU tensor is modified elsewhere. For safety and consistency, it's better to always clone the tensor regardless of its device.
Suggested change:

```diff
- setattr(sub, f, t[s:e].clone() if t.is_cuda else t[s:e])
+ setattr(sub, f, t[s:e].clone())
```
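For illustration, a minimal sketch of the aliasing hazard the review describes (assuming PyTorch; the field name is hypothetical). Slicing a CPU tensor returns a view that shares storage with the original, so mutating the slice silently mutates the source batch:

```python
import torch

batch_latents = torch.arange(8, dtype=torch.float32)  # stand-in for a batch field

view = batch_latents[2:5]   # no copy: shares memory with batch_latents
view[0] = -1.0              # mutating the slice mutates the source batch
print(batch_latents[2].item())  # → -1.0

safe = batch_latents[2:5].clone()  # always-clone: independent storage
safe[1] = 99.0                     # source batch is unaffected
print(batch_latents[3].item())     # → 3.0
```

The per-element cost of an unconditional `.clone()` is small compared to the risk of a sub-batch silently corrupting the parent batch.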
```python
        except Exception as e:
            dist=len(prompt)
```
Catching a broad Exception without logging the specific error can hide underlying issues and make debugging difficult. It's better to log the exception to aid in troubleshooting.
Suggested change:

```diff
  except Exception as e:
-     dist=len(prompt)
+     logger.warning(f"OCR failed for a frame: {e}")
+     dist=len(prompt)
```
```python
assert abs(sum(reward_weights) - 1.0) < 1e-6, \
    f"Reward weights must sum to 1.0, got {sum(reward_weights)}"
```
Purpose
Implements an RL training pipeline for the Wan 2.1 1.3B model, porting the T2V GRPO pipeline from flow_grpo.
Changes
Code structure:
Test Plan
To run RL training pipeline on Wan 2.1 1.3B model:
Test Results
Reward curve over training steps:

Wandb link to this run:
https://wandb.ai/irmchen-ucsd/wan_t2v_grpo/runs/9gbax186?nw=nwusershijiew21
Checklist
- Ran `pre-commit run --all-files` and fixed all issues

For model/pipeline changes, also check: