
[feat] Cosmos 2.5 training support#1224

Open
alexzms wants to merge 1 commit intohao-ai-lab:mainfrom
FoundationResearch:feat/train/cosmos

Conversation

@alexzms
Collaborator

@alexzms alexzms commented Apr 8, 2026

Summary

Add Cosmos 2.5 (Predict2.5-2B) model plugin for fastvideo/train framework
Add preprocessing pipeline and overfit config for Cosmos 2.5
Fix flow-matching noise schedule, FSDP dtype handling, and VAE normalization

Test

Overfit test: 1 GPU, 480×832, 93 frames, 1000 steps
Loss: 0.075 → 0.057, grad norm stable ~0.43
Validation videos verified clean at steps 0, 150, 300, 500, 1000

Add Cosmos 2.5 (Predict2.5-2B) model plugin, preprocessing pipeline,
and overfit config for the fastvideo/train framework.

- CosmosModel plugin inheriting from WanModel with flow-matching noise
  schedule and velocity prediction
- Preprocessing script for Cosmos 2.5 (Wan VAE + Reason1 text encoder)
- Overfit training config (480x832, 93 frames, single GPU)
- Fix FSDP dtype detection in Cosmos25DenoisingStage
- Extend normalize_dit_input for Cosmos25WanVAE
- Fix validation callback for Cosmos inference compatibility
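For readers unfamiliar with the objective mentioned above, the flow-matching noise schedule with velocity prediction can be sketched as follows. This is a minimal illustrative example of the rectified-flow convention, not the actual fastvideo API; all names here are hypothetical.

```python
import random

def flow_matching_target(x0, noise, t):
    """Interpolate clean latents toward noise and return the velocity target.

    x_t = (1 - t) * x0 + t * noise   (noised sample at time t in [0, 1])
    v   = noise - x0                  (velocity the model learns to predict)
    """
    x_t = [(1 - t) * a + t * b for a, b in zip(x0, noise)]
    v = [b - a for a, b in zip(x0, noise)]
    return x_t, v

# Toy usage on a 3-element "latent"
x0 = [0.5, -1.0, 2.0]
noise = [random.gauss(0.0, 1.0) for _ in x0]
x_t, v = flow_matching_target(x0, noise, 0.25)
```

The training loss is then typically the mean squared error between the model's predicted velocity at `(x_t, t)` and `v`.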
@alexzms alexzms requested a review from SolitaryThinker April 8, 2026 20:00
@mergify mergify bot added type: feat New feature or capability scope: training Training pipeline, methods, configs scope: inference Inference pipeline, serving, CLI scope: data Data preprocessing, datasets scope: model Model architecture (DiTs, encoders, VAEs) labels Apr 8, 2026
@mergify
Contributor

mergify bot commented Apr 8, 2026

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 PR merge requirements

Waiting for:

  • #approved-reviews-by>=1
  • check-success=full-suite-passed
  • check-success~=pre-commit
  • check-success=fastcheck-passed
  • title~=(?i)^\[(feat|feature|bugfix|fix|refactor|perf|ci|doc|docs|misc|chore|kernel|new.?model)\]

This rule is failing.

@mergify
Contributor

mergify bot commented Apr 8, 2026

Pre-commit checks failed

Hi @alexzms, the pre-commit checks have failed. To fix them locally:

# Install pre-commit if you haven't already
uv pip install pre-commit
pre-commit install

# Run all checks and auto-fix what's possible
pre-commit run --all-files

Common fixes:

  • yapf: yapf -i <file> (formatting)
  • ruff: ruff check --fix <file> (linting)
  • codespell: codespell --write-changes <file> (spelling)

After fixing, commit and push the changes. The checks will re-run automatically.

For future commits, pre-commit will run automatically on changed files before each commit.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces comprehensive support for Cosmos 2 and 2.5 models, including new training plugins, preprocessing scripts for overfitting tests, and specialized configuration files. Key architectural updates include adjusting input channels to accommodate condition masks, implementing manual EDM preconditioning within the denoising stage, and adding defensive dtype alignment in the transformer forward pass to ensure compatibility with FSDP-wrapped training. The review feedback identifies several areas for improvement: resolving redundant configuration parameters in YAML files, restoring or replacing error handling for unknown model attributes to prevent silent failures, refactoring verbose dtype alignment logic, and adhering to PEP 8 standards regarding imports and the use of constants instead of magic numbers.

_target_: fastvideo.train.models.cosmos.CosmosModel
init_from: KyleShao/Cosmos-Predict2.5-2B-Diffusers
trainable: true
enable_gradient_checkpointing_type: full
Contributor

medium

The parameter enable_gradient_checkpointing_type is specified here and also under training.model on line 64. This redundancy can be confusing and may lead to unexpected behavior if the values differ. It's best to define it in a single, authoritative location. I recommend removing this line and keeping the one under training.model.

@@ -48,8 +48,6 @@ def update_model_arch(self, source_model_dict: dict[str, Any]) -> None:
for key, value in source_model_dict.items():
if key in valid_fields:
setattr(arch_config, key, value)
Contributor

medium

By removing the error for unknown configuration keys, you risk that typos in parameter names go unnoticed, leading to silent failures where the model uses default values. This can be very hard to debug. Consider logging a warning for unknown keys to alert users of potential misconfigurations.
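A minimal sketch of the suggested warning, assuming the same `valid_fields`/`setattr` structure as the diff (the function signature here is illustrative, not the project's actual one):

```python
import logging
from types import SimpleNamespace
from typing import Any

logger = logging.getLogger(__name__)

def update_model_arch(arch_config, source_model_dict: dict[str, Any],
                      valid_fields: set[str]) -> None:
    """Copy known keys onto arch_config; warn (don't fail) on unknown ones."""
    for key, value in source_model_dict.items():
        if key in valid_fields:
            setattr(arch_config, key, value)
        else:
            # Surfaces typos like "hiden_size" that would otherwise
            # silently fall back to defaults.
            logger.warning("Ignoring unknown model arch key %r", key)

# Toy usage: the misspelled key triggers a warning instead of being applied
cfg = SimpleNamespace()
update_model_arch(cfg, {"hidden_size": 2048, "hiden_size": 9}, {"hidden_size"})
```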

Comment on lines +630 to +643
if hidden_states.dtype != _target_dtype:
hidden_states = hidden_states.to(_target_dtype)
if condition_mask is not None and condition_mask.dtype != _target_dtype:
condition_mask = condition_mask.to(_target_dtype)
if padding_mask is not None and padding_mask.dtype != _target_dtype:
padding_mask = padding_mask.to(_target_dtype)
if isinstance(encoder_hidden_states, torch.Tensor):
if encoder_hidden_states.dtype != _target_dtype:
encoder_hidden_states = encoder_hidden_states.to(_target_dtype)
else:
encoder_hidden_states = [
t.to(_target_dtype) if t.dtype != _target_dtype else t
for t in encoder_hidden_states
]
Contributor

medium

This block of code for defensive dtype alignment is a bit verbose. You can make it more concise by removing the explicit dtype checks, since to() is already a no-op (returning the same tensor) when the tensor is on the target dtype and device.

        hidden_states = hidden_states.to(_target_dtype)
        if condition_mask is not None:
            condition_mask = condition_mask.to(_target_dtype)
        if padding_mask is not None:
            padding_mask = padding_mask.to(_target_dtype)
        if isinstance(encoder_hidden_states, torch.Tensor):
            encoder_hidden_states = encoder_hidden_states.to(_target_dtype)
        else:
            encoder_hidden_states = [t.to(_target_dtype) for t in encoder_hidden_states]

Comment on lines +108 to +109
import glob
from safetensors.torch import load_file
Contributor

medium

Imports should be at the top of the file, not inside a function. This is a standard Python style convention (PEP 8) that improves readability and makes dependencies clear. Please move import glob and from safetensors.torch import load_file to the top of the file.

print(f"\nWrote {len(records)} records to {output_path}")

# Extract first frame from first video as V2W conditioning image
import cv2
Contributor

medium

cv2 is already imported at the top of the file (line 15). This re-import is redundant and should be removed.

Comment on lines +325 to +329
embed_dim = getattr(arch, "hidden_size", 100352)
else:
embed_dim = 100352

num_tokens = 512 # Reason1 default padding length
Contributor

medium

The values 100352 and 512 are hardcoded. It would be better to define them as constants at the module level, for example _REASON1_EMBED_DIM = 100352 and _REASON1_NUM_TOKENS = 512. This improves readability and maintainability, especially since these values are specific to the "Reason1" model.
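A sketch of the suggested refactor, with the two values hoisted into module-level constants (the values come from the diff above; the function and argument names are illustrative, not the project's actual API):

```python
# Reason1 text-encoder dimensions, per the diff under review.
_REASON1_EMBED_DIM = 100352   # fallback hidden size
_REASON1_NUM_TOKENS = 512     # default padding length

def reason1_embedding_shape(arch=None):
    """Return (num_tokens, embed_dim) for Reason1 text embeddings.

    Falls back to the module-level constant when the arch config is
    missing or does not define hidden_size.
    """
    if arch is not None:
        embed_dim = getattr(arch, "hidden_size", _REASON1_EMBED_DIM)
    else:
        embed_dim = _REASON1_EMBED_DIM
    return _REASON1_NUM_TOKENS, embed_dim
```

Named constants make the Reason1-specific values greppable and keep the fallback logic in one place.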
