
[feat] Add Wan2.1 RL Pipeline#1222

Open
shijiew555 wants to merge 51 commits into hao-ai-lab:main from Gary-ChenJL:align_debug

Conversation

Collaborator

@shijiew555 shijiew555 commented Apr 8, 2026

Purpose

Implements an RL training pipeline for the Wan 2.1 1.3B model by porting the T2V GRPO pipeline from flow_grpo.

Changes

Code structure:

FastVideo/
├── examples/training/rl/           # Runscripts
│   ├── finetune_t2v_grpo.sh        # Single-GPU
│   ├── finetune_t2v_grpo_4gpu.sh   # Multi-GPU (4)
│   └── validation.json
│
├── data/ocr/                       # RL prompt dataset (train.txt, test.txt)
│
├── fastvideo/
│   ├── training/
│   │   ├── wan_rl_training_pipeline.py   # Wan RL entry → RLPipeline
│   │   └── rl/                            # RL core
│   │       ├── rl_pipeline.py            # RLPipeline: collect rollouts, reward, advantage, GRPO loss
│   │       ├── rl_utils.py
│   │       ├── stat_tracking.py          # Per-prompt advantage normalization
│   │       ├── wan_grpo_utils.py
│   │       └── rewards/                  # Reward models
│   │           ├── rewards.py            # MultiRewardAggregator, create_reward_models
│   │           ├── ocr.py                # OCR reward (current)
│   │           └── base.py
│   │
│   ├── dataset/
│   │   └── rl_prompt_dataset.py    # RL prompt dataloader (text / geneval), KRepeatSampler
│   │
│   └── pipelines/stages/
│       └── denoising.py            # Rollout generation: logprob + trajectory in inference path
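For context, the per-prompt advantage normalization that GRPO relies on (the role stat_tracking.py plays here) can be sketched as follows. This is an illustrative assumption about the computation, not the PR's actual API: the function name and the flat, prompt-grouped tensor layout are invented for the example.

```python
import torch

def per_prompt_advantages(rewards: torch.Tensor, k: int,
                          eps: float = 1e-6) -> torch.Tensor:
    """Standardize rewards within each prompt's group of k rollouts.

    rewards: flat tensor of shape [num_prompts * k], grouped by prompt.
    Returns a flat tensor of the same shape with per-group zero mean
    and (approximately) unit variance.
    """
    grouped = rewards.view(-1, k)                       # [num_prompts, k]
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    return ((grouped - mean) / (std + eps)).view(-1)    # flat advantages
```

Normalizing within each prompt's rollout group (rather than across the whole batch) is what makes the advantages comparable when different prompts have very different baseline reward levels.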

Test Plan

To run RL training pipeline on Wan 2.1 1.3B model:

bash examples/training/rl/finetune_t2v_grpo_4gpu.sh

Test Results

Reward curve over training steps:

Wandb link to this run:
https://wandb.ai/irmchen-ucsd/wan_t2v_grpo/runs/9gbax186?nw=nwusershijiew21

Checklist

  • I ran pre-commit run --all-files and fixed all issues
  • I added or updated tests for my changes
  • I updated documentation if needed
  • I considered GPU memory impact of my changes

For model/pipeline changes, also check:

  • I verified SSIM regression tests pass
  • I updated the support matrix if adding a new model

@mergify mergify bot added type: feat New feature or capability scope: training Training pipeline, methods, configs scope: inference Inference pipeline, serving, CLI scope: data Data preprocessing, datasets scope: infra CI, tests, Docker, build scope: model Model architecture (DiTs, encoders, VAEs) labels Apr 8, 2026
Contributor

mergify bot commented Apr 8, 2026

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 PR merge requirements

Waiting for:

  • #approved-reviews-by>=1
  • check-success=full-suite-passed
  • check-success=fastcheck-passed
  • check-success~=pre-commit
  • title~=(?i)^\[(feat|feature|bugfix|fix|refactor|perf|ci|doc|docs|misc|chore|kernel|new.?model)\]

This rule is failing.

Contributor

mergify bot commented Apr 8, 2026

Pre-commit checks failed

Hi @shijiew555, the pre-commit checks have failed. To fix them locally:

# Install pre-commit if you haven't already
uv pip install pre-commit
pre-commit install

# Run all checks and auto-fix what's possible
pre-commit run --all-files

Common fixes:

  • yapf: yapf -i <file> (formatting)
  • ruff: ruff check --fix <file> (linting)
  • codespell: codespell --write-changes <file> (spelling)

After fixing, commit and push the changes. The checks will re-run automatically.

For future commits, pre-commit will run automatically on changed files before each commit.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a reinforcement learning (RL) training pipeline for FastVideo, specifically implementing GRPO (Group Relative Policy Optimization) for video generation models. Key additions include a new RL pipeline structure, reward model infrastructure (with an OCR-based reward), and dataset handling for RL prompts. Several issues were identified, including a critical security vulnerability involving a hardcoded W&B API key, hardcoded validation parameters, and inconsistent configuration documentation. I have provided feedback to address these issues and improve the robustness of the implementation.


export WANDB_BASE_URL="https://api.wandb.ai"
export WANDB_MODE=online
export WANDB_API_KEY="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"


critical (security)

A hardcoded W&B API key has been committed. This is a critical security vulnerability. API keys and other secrets should never be hardcoded in the source code. Please remove the key and use a secure method for providing credentials, such as environment variables or a secrets management system.

Suggested change
export WANDB_API_KEY="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"
export WANDB_API_KEY="${WANDB_API_KEY}" # Key should be provided via environment variable

Returns:
Reward tensor [B] with averaged OCR similarity scores across frames
"""
prompts = [prompt.split('"')[1] for prompt in prompts]


high

The prompt parsing logic prompt.split('"')[1] is brittle. It assumes every prompt contains exactly one pair of double quotes and will raise an IndexError if a prompt does not follow this format. This could crash the reward computation. Consider using a more robust method, like regular expressions, to extract the quoted text, and include error handling for prompts that don't match the expected format.
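A minimal sketch of the more robust extraction the reviewer suggests, assuming the goal is to pull out the first double-quoted span and fall back to the full prompt when no quoted text is present. The helper name is illustrative, not part of the PR:

```python
import re

def extract_quoted_text(prompt: str) -> str:
    """Return the first double-quoted span in the prompt.

    Falls back to the whole prompt (instead of raising IndexError,
    as prompt.split('"')[1] would) when no quoted span exists.
    """
    match = re.search(r'"([^"]*)"', prompt)
    return match.group(1) if match else prompt
```

With this fallback, a malformed prompt degrades to comparing OCR output against the full prompt string rather than crashing the reward computation mid-rollout.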

Comment thread fastvideo/training/rl/rl_pipeline.py Outdated
Comment on lines +474 to +475
sampling_param.height = 480 #training_args.num_height
sampling_param.width = 832 #training_args.num_width


high

The height and width for the validation batch are hardcoded to 480 and 832, respectively. The commented-out code suggests these values should be taken from training_args. Hardcoding these values can lead to incorrect validation behavior if the training configuration changes. Please use the values from training_args as intended.

Suggested change
sampling_param.height = 480 #training_args.num_height
sampling_param.width = 832 #training_args.num_width
sampling_param.height = training_args.num_height
sampling_param.width = training_args.num_width

Comment on lines +20 to +22
NUM_GPUS=8

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7


medium

The script name finetune_t2v_grpo_4gpu.sh and the comment on line 2 suggest this is for 4 GPUs, but NUM_GPUS is set to 8 and CUDA_VISIBLE_DEVICES is set to use 8 GPUs (0-7). This is inconsistent and misleading. Please align the script's configuration with its name and documentation.

# epoch = epoch*num_batches_per_epoch+i and skips the first 2 epochs, so first real batch
# uses seed + 2*num_batches_per_epoch; we use seed+step+4 so step 0,1,... matches that.
g = torch.Generator()
g.manual_seed(self.seed + self.step + 3)


medium

The comment on lines 85-87 states that the seed is calculated as seed+step+4 to align with flow_grpo, but the code uses self.seed + self.step + 3. This discrepancy could lead to confusion and incorrect behavior if the comment is trusted. Please update either the code or the comment to ensure they are consistent.

Suggested change
g.manual_seed(self.seed + self.step + 3)
g.manual_seed(self.seed + self.step + 4)

test_num_workers: int = 8,
num_replicas: int = 1,
rank: int = 0,
) -> tuple[DataLoader, DataLoader]:


medium

The return type hint for build_rl_prompt_dataloader is -> tuple[DataLoader, DataLoader], but the function actually returns a tuple of five elements: (train_dataloader, test_dataloader, train_dataset, test_dataset, train_sampler). Please update the type hint to match the implementation for clarity and correctness.

Suggested change
) -> tuple[DataLoader, DataLoader]:
) -> tuple[DataLoader, DataLoader, Dataset, Dataset, Sampler]:

for f in TRAINING_BATCH_SAMPLE_TENSOR_FIELDS:
t = getattr(batch, f)
if t is not None:
setattr(sub, f, t[s:e].clone() if t.is_cuda else t[s:e])


medium

There's an inconsistency in how tensors are handled based on their device. CUDA tensors are cloned, while CPU tensors are not, creating a view. This can lead to unexpected side effects if the sliced CPU tensor is modified elsewhere. For safety and consistency, it's better to always clone the tensor regardless of its device.

Suggested change
setattr(sub, f, t[s:e].clone() if t.is_cuda else t[s:e])
setattr(sub, f, t[s:e].clone())

Comment thread fastvideo/training/rl/rewards/ocr.py Outdated
Comment on lines +85 to +86
except Exception as e:
dist=len(prompt)


medium

Catching a broad Exception without logging the specific error can hide underlying issues and make debugging difficult. It's better to log the exception to aid in troubleshooting.

Suggested change
except Exception as e:
dist=len(prompt)
except Exception as e:
logger.warning(f"OCR failed for a frame: {e}")
dist=len(prompt)

Comment on lines +57 to +59
assert abs(sum(reward_weights) - 1.0) < 1e-6, \
f"Reward weights must sum to 1.0, got {sum(reward_weights)}"



medium

The code asserts that reward weights must sum to 1.0, which places the burden of normalization on the user. It would be more robust and user-friendly to normalize the weights internally if they don't already sum to 1.0. This would prevent unexpected crashes and simplify configuration.
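A minimal sketch of the internal normalization proposed above, assuming plain float weights. The helper name is illustrative and not part of the PR:

```python
def normalize_weights(weights: list[float]) -> list[float]:
    """Scale reward weights so they sum to 1.0.

    Rejects non-positive totals (which would make the scaled weights
    meaningless) instead of asserting that the user pre-normalized.
    """
    total = sum(weights)
    if total <= 0:
        raise ValueError(f"Reward weights must have a positive sum, got {total}")
    return [w / total for w in weights]
```

Calling this once in the aggregator's constructor would let users pass unnormalized weights like `[2.0, 1.0, 1.0]` without tripping the current assertion.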

@shijiew555 shijiew555 changed the title [feat] RL Pipeline [feat] Add Wan2.1 RL Pipeline Apr 9, 2026
Contributor

mergify bot commented Apr 9, 2026

Pre-commit checks failed


@mergify mergify bot added the scope: docs Documentation label Apr 10, 2026

Labels

  • scope: data Data preprocessing, datasets
  • scope: docs Documentation
  • scope: inference Inference pipeline, serving, CLI
  • scope: infra CI, tests, Docker, build
  • scope: model Model architecture (DiTs, encoders, VAEs)
  • scope: training Training pipeline, methods, configs
  • type: feat New feature or capability

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants