
[feat] Add diffusion forcing / ode init for MatrixGame2.0 #1179

Open
H1yori233 wants to merge 57 commits into hao-ai-lab:main from H1yori233:feat/kaiqin/add-mg-df

Conversation

@H1yori233
Contributor

Purpose

Add diffusion forcing and ODE init for MatrixGame 2.0.

Changes

  • Add fastvideo/models/schedulers/scheduling_diffusion_forcing.py, a scheduler for diffusion forcing.
  • Add fastvideo/training/matrixgame_ar_diffusion_pipeline.py, a diffusion forcing training pipeline.
  • Add a schema in fastvideo/dataset/dataloader/schema.py and a pipeline, fastvideo/training/matrixgame_ode_causal_pipeline.py, for ODE init.
  • fastvideo/training/training_pipeline.py can now upload reference videos to WandB for easier comparison.

Others

This PR is based on the previous training pipeline; changes for the new trainer will be delivered in a separate PR.

Checklist

  • I ran pre-commit run --all-files and fixed all issues

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers get up to speed quickly!

This pull request significantly expands the training capabilities for MatrixGame 2.0 by integrating advanced diffusion forcing and ODE initialization techniques. These additions provide new methodologies for model training, aiming to improve stability and performance. Furthermore, the validation process has been enhanced with the ability to log reference videos, offering clearer insights into model progression and enabling more direct comparisons.

Highlights

  • Diffusion Forcing Scheduler: Introduced a new DiffusionForcingScheduler to enable diffusion forcing in training workflows.
  • MatrixGame 2.0 Diffusion Forcing Pipeline: Added a dedicated matrixgame_ar_diffusion_pipeline.py for MatrixGame 2.0 diffusion forcing training.
  • ODE Initialization Schema and Pipeline: Implemented a new PyArrow schema and matrixgame_ode_causal_pipeline.py for ODE initialization, supporting new training methodologies.
  • Enhanced Validation Logging: Updated training_pipeline.py to upload reference videos to WandB, facilitating easier comparison during validation.


Contributor

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a new ODE trajectory preprocessing pipeline and a corresponding training pipeline for the MatrixGame 2.0 model. Key changes include defining new data schemas, implementing a DiffusionForcingScheduler, and integrating these components into the existing framework. Utility functions have been updated to support the new scheduler, and the validation logging mechanism now includes reference videos. Feedback from the review suggests refactoring repeated serialization logic, correcting docstring placements, removing dead code and commented-out sections, and addressing a hardcoded value in the training pipeline.

Comment on lines +232 to +295
record.update({
    "clip_feature_bytes": clip_feature.tobytes(),
    "clip_feature_shape": list(clip_feature.shape),
    "clip_feature_dtype": str(clip_feature.dtype),
})

record.update({
    "first_frame_latent_bytes": first_frame_latent.tobytes(),
    "first_frame_latent_shape": list(first_frame_latent.shape),
    "first_frame_latent_dtype": str(first_frame_latent.dtype),
})

# Optional PIL Image
if pil_image is not None:
    record.update({
        "pil_image_bytes": pil_image.tobytes(),
        "pil_image_shape": list(pil_image.shape),
        "pil_image_dtype": str(pil_image.dtype),
    })
else:
    record.update({
        "pil_image_bytes": b"",
        "pil_image_shape": [],
        "pil_image_dtype": "",
    })

# Actions
if keyboard_cond is not None:
    record.update({
        "keyboard_cond_bytes": keyboard_cond.tobytes(),
        "keyboard_cond_shape": list(keyboard_cond.shape),
        "keyboard_cond_dtype": str(keyboard_cond.dtype),
    })
else:
    record.update({
        "keyboard_cond_bytes": b"",
        "keyboard_cond_shape": [],
        "keyboard_cond_dtype": "",
    })

if mouse_cond is not None:
    record.update({
        "mouse_cond_bytes": mouse_cond.tobytes(),
        "mouse_cond_shape": list(mouse_cond.shape),
        "mouse_cond_dtype": str(mouse_cond.dtype),
    })
else:
    record.update({
        "mouse_cond_bytes": b"",
        "mouse_cond_shape": [],
        "mouse_cond_dtype": "",
    })

record.update({
    "trajectory_latents_bytes": trajectory_latents.tobytes(),
    "trajectory_latents_shape": list(trajectory_latents.shape),
    "trajectory_latents_dtype": str(trajectory_latents.dtype),
})

record.update({
    "trajectory_timesteps_bytes": trajectory_timesteps.tobytes(),
    "trajectory_timesteps_shape": list(trajectory_timesteps.shape),
    "trajectory_timesteps_dtype": str(trajectory_timesteps.dtype),
})
Contributor

medium

The logic for serializing numpy arrays and adding them to the record dictionary is repeated for clip_feature, first_frame_latent, pil_image, keyboard_cond, mouse_cond, trajectory_latents, and trajectory_timesteps. This can be refactored into a helper function to reduce code duplication and improve maintainability.

    def _add_array(name: str, array: np.ndarray | None):
        if array is not None:
            record.update({
                f"{name}_bytes": array.tobytes(),
                f"{name}_shape": list(array.shape),
                f"{name}_dtype": str(array.dtype),
            })
        else:
            record.update({
                f"{name}_bytes": b"",
                f"{name}_shape": [],
                f"{name}_dtype": "",
            })

    # I2V features
    _add_array("clip_feature", clip_feature)
    _add_array("first_frame_latent", first_frame_latent)

    # Optional PIL Image
    _add_array("pil_image", pil_image)

    # Actions
    _add_array("keyboard_cond", keyboard_cond)
    _add_array("mouse_cond", mouse_cond)

    _add_array("trajectory_latents", trajectory_latents)
    _add_array("trajectory_timesteps", trajectory_timesteps)

Contributor Author

I think there's no need to do this?

Comment thread fastvideo/training/matrixgame_ode_causal_pipeline.py Outdated
Comment on lines +190 to +215
def _prepare_dit_inputs(self, training_batch: TrainingBatch) -> TrainingBatch:
    """Override to properly handle I2V concatenation - call parent first, then concatenate image conditioning."""

    # First, call parent method to prepare noise, timesteps, etc. for video latents
    training_batch = super()._prepare_dit_inputs(training_batch)

    assert isinstance(training_batch.image_latents, torch.Tensor)
    image_latents = training_batch.image_latents.to(get_local_torch_device(), dtype=torch.bfloat16)

    temporal_compression_ratio = 4
    num_frames = (self.training_args.num_latent_t - 1) * temporal_compression_ratio + 1
    batch_size, num_channels, _, latent_height, latent_width = image_latents.shape
    mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
    mask_lat_size[:, :, 1:] = 0

    first_frame_mask = mask_lat_size[:, :, :1]
    first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=temporal_compression_ratio)
    mask_lat_size = torch.cat([first_frame_mask, mask_lat_size[:, :, 1:]], dim=2)
    mask_lat_size = mask_lat_size.view(batch_size, -1, temporal_compression_ratio, latent_height, latent_width)
    mask_lat_size = mask_lat_size.transpose(1, 2)
    mask_lat_size = mask_lat_size.to(image_latents.device).to(dtype=torch.bfloat16)

    training_batch.noisy_model_input = torch.cat([training_batch.noisy_model_input, mask_lat_size, image_latents],
                                                 dim=1)

    return training_batch
Contributor

medium

The _prepare_dit_inputs method appears to be unused within this class, as train_one_step does not call it. This seems to be dead code and should be removed to improve maintainability.

Comment thread fastvideo/training/matrixgame_ode_causal_pipeline.py Outdated
@H1yori233
Contributor Author

/gemini review

Contributor

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a new ODE trajectory preprocessing pipeline and a corresponding training pipeline for MatrixGame models. It adds a DiffusionForcingScheduler and integrates it into the training process, along with updates to utility functions to support this new scheduler. The changes also include a new shell script for finetuning, a PyArrow schema for ODE trajectory data, and enhancements to the validation logging to include reference videos. Review comments suggest refactoring duplicated code, optimizing timestep calculation, and correcting misleading comments in the code.

Comment on lines +244 to +283
# Optional PIL Image
if pil_image is not None:
    record.update({
        "pil_image_bytes": pil_image.tobytes(),
        "pil_image_shape": list(pil_image.shape),
        "pil_image_dtype": str(pil_image.dtype),
    })
else:
    record.update({
        "pil_image_bytes": b"",
        "pil_image_shape": [],
        "pil_image_dtype": "",
    })

# Actions
if keyboard_cond is not None:
    record.update({
        "keyboard_cond_bytes": keyboard_cond.tobytes(),
        "keyboard_cond_shape": list(keyboard_cond.shape),
        "keyboard_cond_dtype": str(keyboard_cond.dtype),
    })
else:
    record.update({
        "keyboard_cond_bytes": b"",
        "keyboard_cond_shape": [],
        "keyboard_cond_dtype": "",
    })

if mouse_cond is not None:
    record.update({
        "mouse_cond_bytes": mouse_cond.tobytes(),
        "mouse_cond_shape": list(mouse_cond.shape),
        "mouse_cond_dtype": str(mouse_cond.dtype),
    })
else:
    record.update({
        "mouse_cond_bytes": b"",
        "mouse_cond_shape": [],
        "mouse_cond_dtype": "",
    })
Contributor

medium

The logic for handling optional numpy arrays (pil_image, keyboard_cond, mouse_cond) is duplicated. This can be refactored into a loop to improve maintainability and reduce boilerplate code.

Suggested change

# Optional PIL Image
if pil_image is not None:
    record.update({
        "pil_image_bytes": pil_image.tobytes(),
        "pil_image_shape": list(pil_image.shape),
        "pil_image_dtype": str(pil_image.dtype),
    })
else:
    record.update({
        "pil_image_bytes": b"",
        "pil_image_shape": [],
        "pil_image_dtype": "",
    })
# Actions
if keyboard_cond is not None:
    record.update({
        "keyboard_cond_bytes": keyboard_cond.tobytes(),
        "keyboard_cond_shape": list(keyboard_cond.shape),
        "keyboard_cond_dtype": str(keyboard_cond.dtype),
    })
else:
    record.update({
        "keyboard_cond_bytes": b"",
        "keyboard_cond_shape": [],
        "keyboard_cond_dtype": "",
    })
if mouse_cond is not None:
    record.update({
        "mouse_cond_bytes": mouse_cond.tobytes(),
        "mouse_cond_shape": list(mouse_cond.shape),
        "mouse_cond_dtype": str(mouse_cond.dtype),
    })
else:
    record.update({
        "mouse_cond_bytes": b"",
        "mouse_cond_shape": [],
        "mouse_cond_dtype": "",
    })

# Optional PIL Image and Actions
for prefix, array in [("pil_image", pil_image), ("keyboard_cond", keyboard_cond), ("mouse_cond", mouse_cond)]:
    if array is not None:
        record.update({
            f"{prefix}_bytes": array.tobytes(),
            f"{prefix}_shape": list(array.shape),
            f"{prefix}_dtype": str(array.dtype),
        })
    else:
        record.update({
            f"{prefix}_bytes": b"",
            f"{prefix}_shape": [],
            f"{prefix}_dtype": "",
        })

Comment on lines +97 to +100
timestep_id = torch.argmin(
(self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(),
dim=1,
)
Contributor

medium

Using torch.argmin to find the timestep_id can be inefficient, especially if self.timesteps is large, as it has a time complexity of O(N). Since self.timesteps is a sorted tensor (monotonically decreasing), you can use torch.searchsorted for a more performant O(log N) lookup.

To use torch.searchsorted, you'll need to work with a monotonically increasing tensor. You could either flip self.timesteps before searching or adjust its creation in set_timesteps to be ascending and then flip it for the parts of the code that expect a descending order.
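The flip-then-search approach described above can be sketched as follows, assuming self.timesteps is a 1-D, monotonically decreasing float tensor. The standalone helper find_timestep_ids and its signature are illustrative, not part of the actual scheduler:

```python
import torch

def find_timestep_ids(timesteps: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
    """Map each query timestep to the index of its nearest entry in a
    monotonically decreasing `timesteps` tensor, in O(log N) per query."""
    ascending = timesteps.flip(0)                   # searchsorted needs ascending order
    pos = torch.searchsorted(ascending, query)      # insertion points in the ascending view
    pos = pos.clamp(max=ascending.numel() - 1)
    # The nearest value is either at `pos` or `pos - 1`; pick whichever is closer.
    prev = (pos - 1).clamp(min=0)
    pick_prev = (query - ascending[prev]).abs() < (ascending[pos] - query).abs()
    nearest = torch.where(pick_prev, prev, pos)
    # Convert back to indices into the original descending tensor.
    return timesteps.numel() - 1 - nearest
```

One subtlety: searchsorted returns an insertion point rather than a nearest neighbor, so the sketch compares the two adjacent candidates explicitly before mapping back to descending-order indices.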


features["clip_feature"] = clip_features
"""Get VAE features from the first frame of each video"""
# Get CLIP features from the first frame of each video.
Contributor

medium

This comment is misleading. The code block that follows is responsible for calculating VAE features, not CLIP features. CLIP features are calculated in the block just before this comment.

Suggested change
# Get CLIP features from the first frame of each video.
# Get VAE features from the first frame of each video.

Collaborator

Yes, this should be VAE; you can fix it.

self.get_module("vae").to(get_local_torch_device())

features = {}
"""Get CLIP features from the first frame of each video."""
Contributor

medium

This appears to be a docstring placed in the middle of a function body. According to Python style guides (like PEP 257), docstrings should only appear as the first statement in a module, function, class, or method definition. For comments within a function, please use the # syntax.

Suggested change
"""Get CLIP features from the first frame of each video."""
# Get CLIP features from the first frame of each video.


features["clip_feature"] = clip_features
features["pil_image"] = first_frame
# Get CLIP features from the first frame of each video.
Contributor

medium

This comment is misleading. The code block that follows is responsible for preparing video_conditions to be encoded by the VAE, not for getting CLIP features. CLIP features are calculated in the block before this.

Suggested change
# Get CLIP features from the first frame of each video.
# Get VAE features from the first frame of each video.

logger.info("relevant_traj_latents: %s", relevant_traj_latents.shape)

indexes = self._get_timestep(  # [B, num_frames]
    0, len(self.dmd_denoising_steps), B, num_frames, 3, uniform_timestep=False)
Contributor

medium

The num_frame_per_block argument is hardcoded as 3. This should be a configurable parameter, similar to how it's handled in MatrixGameARDiffusionPipeline, to improve flexibility and maintain consistency across pipelines. Consider making it a class attribute initialized from the training arguments.
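A minimal sketch of making that value configurable, assuming the pipeline is constructed with a training_args namespace. The class skeleton and the num_frame_per_block attribute/argument names below are hypothetical, chosen to mirror the reviewer's suggestion:

```python
from types import SimpleNamespace

class MatrixGameODECausalPipeline:
    """Hypothetical skeleton; the real pipeline has many more responsibilities."""

    def __init__(self, training_args):
        # Read the block size from the training arguments, falling back to the
        # previously hardcoded value of 3 when it is not provided.
        self.num_frame_per_block = getattr(training_args, "num_frame_per_block", 3)

# The call site would then pass self.num_frame_per_block instead of the literal 3:
#     indexes = self._get_timestep(
#         0, len(self.dmd_denoising_steps), B, num_frames,
#         self.num_frame_per_block, uniform_timestep=False)

pipeline = MatrixGameODECausalPipeline(SimpleNamespace(num_frame_per_block=4))
```

This keeps the default behavior unchanged while letting configs override the block size, matching how MatrixGameARDiffusionPipeline reportedly handles it.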

Collaborator

How about just using global_step == self.init_steps to check whether the ref video has been logged, instead of setting self.validation_ref_videos_logged = False at init?

@mergify bot added the type: feat (New feature or capability), scope: training (Training pipeline, methods, configs), scope: data (Data preprocessing, datasets), and scope: model (Model architecture (DiTs, encoders, VAEs)) labels on Mar 30, 2026
@mergify
Contributor

mergify bot commented Mar 30, 2026

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 PR merge requirements

Waiting for:

  • #approved-reviews-by>=1
  • check-success=full-suite-passed
  • check-success~=pre-commit
This rule is failing.
  • #approved-reviews-by>=1
  • check-success=full-suite-passed
  • check-success~=pre-commit
  • check-success=fastcheck-passed
  • title~=(?i)^\[(feat|feature|bugfix|fix|refactor|perf|ci|doc|docs|misc|chore|kernel|new.?model)\]

@mergify
Contributor

mergify bot commented Mar 30, 2026

Buildkite CI tests failed

Hi @H1yori233, some Buildkite CI tests have failed. Check the build for details:
View Buildkite build →

Common causes:

  • Test failures: Check the failing step's output for assertion errors or tracebacks
  • Import errors: Make sure new dependencies are added to pyproject.toml
  • GPU memory: Some tests require specific GPU types (L40S, H100 NVL)
  • Kernel build: If you changed fastvideo-kernel/, the build may have failed

If the failure is unrelated to your changes, leave a comment explaining why.

Comment thread fastvideo/models/utils.py
Collaborator

Why is sigma_from_timestep needed? It seems the timesteps are all precomputed, and the sigmas are stored in scheduler.sigmas based on self.timesteps.

@alexzms alexzms self-requested a review March 31, 2026 22:45
alexzms and others added 5 commits March 31, 2026 23:04
  • Reverts the DiffusionForcingScheduler dependency in dfsft.py. The original design using the student's scheduler is more generic and works with both flow matching and DDPM models.
  • Add per-frame gaussian loss weighting matching Causal-Forcing's bsmntw scheme. Mid-noise timesteps get higher weight; extremes (near-clean and pure-noise) get lower weight.
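The bell-shaped weighting described in that commit message could look roughly like the sketch below. The normalization to [0, 1] and the sigma value are assumptions for illustration, not the exact Causal-Forcing "bsmntw" parameters:

```python
import torch

def gaussian_timestep_weight(timesteps: torch.Tensor,
                             num_train_timesteps: int = 1000,
                             sigma: float = 0.25) -> torch.Tensor:
    """Bell-shaped per-frame loss weight: timesteps near the middle of the
    noise schedule get weight ~1; near-clean and pure-noise frames get
    exponentially smaller weight. `sigma` controls how sharply the
    weight falls off toward the extremes (assumed value, not from the PR)."""
    t = timesteps.float() / num_train_timesteps   # normalize to [0, 1]
    return torch.exp(-0.5 * ((t - 0.5) / sigma) ** 2)
```

Multiplying the per-frame loss by such a weight emphasizes mid-noise timesteps, where the denoising signal is most informative, while down-weighting the nearly trivial extremes.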