[feat] Add diffusion forcing / ode init for MatrixGame2.0 #1179
H1yori233 wants to merge 57 commits into hao-ai-lab:main from
Conversation
Feat/kaiqin/mg overlay validation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly expands the training capabilities for MatrixGame 2.0 by integrating diffusion forcing and ODE initialization. These additions provide new training methodologies aimed at improving stability and performance. The validation process has also been enhanced with the ability to log reference videos, offering clearer insight into model progression and enabling more direct comparisons.
Code Review
This pull request introduces a new ODE trajectory preprocessing pipeline and a corresponding training pipeline for the MatrixGame 2.0 model. Key changes include defining new data schemas, implementing a DiffusionForcingScheduler, and integrating these components into the existing framework. Utility functions have been updated to support the new scheduler, and the validation logging mechanism now includes reference videos. Feedback from the review suggests refactoring repeated serialization logic, correcting docstring placements, removing dead code and commented-out sections, and addressing a hardcoded value in the training pipeline.
```python
record.update({
    "clip_feature_bytes": clip_feature.tobytes(),
    "clip_feature_shape": list(clip_feature.shape),
    "clip_feature_dtype": str(clip_feature.dtype),
})

record.update({
    "first_frame_latent_bytes": first_frame_latent.tobytes(),
    "first_frame_latent_shape": list(first_frame_latent.shape),
    "first_frame_latent_dtype": str(first_frame_latent.dtype),
})

# Optional PIL Image
if pil_image is not None:
    record.update({
        "pil_image_bytes": pil_image.tobytes(),
        "pil_image_shape": list(pil_image.shape),
        "pil_image_dtype": str(pil_image.dtype),
    })
else:
    record.update({
        "pil_image_bytes": b"",
        "pil_image_shape": [],
        "pil_image_dtype": "",
    })

# Actions
if keyboard_cond is not None:
    record.update({
        "keyboard_cond_bytes": keyboard_cond.tobytes(),
        "keyboard_cond_shape": list(keyboard_cond.shape),
        "keyboard_cond_dtype": str(keyboard_cond.dtype),
    })
else:
    record.update({
        "keyboard_cond_bytes": b"",
        "keyboard_cond_shape": [],
        "keyboard_cond_dtype": "",
    })

if mouse_cond is not None:
    record.update({
        "mouse_cond_bytes": mouse_cond.tobytes(),
        "mouse_cond_shape": list(mouse_cond.shape),
        "mouse_cond_dtype": str(mouse_cond.dtype),
    })
else:
    record.update({
        "mouse_cond_bytes": b"",
        "mouse_cond_shape": [],
        "mouse_cond_dtype": "",
    })

record.update({
    "trajectory_latents_bytes": trajectory_latents.tobytes(),
    "trajectory_latents_shape": list(trajectory_latents.shape),
    "trajectory_latents_dtype": str(trajectory_latents.dtype),
})

record.update({
    "trajectory_timesteps_bytes": trajectory_timesteps.tobytes(),
    "trajectory_timesteps_shape": list(trajectory_timesteps.shape),
    "trajectory_timesteps_dtype": str(trajectory_timesteps.dtype),
})
```
The logic for serializing numpy arrays and adding them to the record dictionary is repeated for clip_feature, first_frame_latent, pil_image, keyboard_cond, mouse_cond, trajectory_latents, and trajectory_timesteps. This can be refactored into a helper function to reduce code duplication and improve maintainability.
```python
def _add_array(name: str, array: np.ndarray | None):
    if array is not None:
        record.update({
            f"{name}_bytes": array.tobytes(),
            f"{name}_shape": list(array.shape),
            f"{name}_dtype": str(array.dtype),
        })
    else:
        record.update({
            f"{name}_bytes": b"",
            f"{name}_shape": [],
            f"{name}_dtype": "",
        })

# I2V features
_add_array("clip_feature", clip_feature)
_add_array("first_frame_latent", first_frame_latent)
# Optional PIL Image
_add_array("pil_image", pil_image)
# Actions
_add_array("keyboard_cond", keyboard_cond)
_add_array("mouse_cond", mouse_cond)
_add_array("trajectory_latents", trajectory_latents)
_add_array("trajectory_timesteps", trajectory_timesteps)
```
I think there is no need to do this?
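For context, the `_bytes`/`_shape`/`_dtype` triple used throughout these records can be round-tripped on the reader side with a small helper. This is a sketch only; the `read_array` helper and the dummy record are hypothetical and not part of this PR:

```python
import numpy as np

def read_array(record: dict, name: str):
    """Reconstruct an optional numpy array from a serialized record.

    Returns None when the field was stored with the empty sentinel
    (b"" bytes, [] shape, "" dtype) used for optional arrays.
    """
    raw = record[f"{name}_bytes"]
    dtype = record[f"{name}_dtype"]
    if raw == b"" and not dtype:
        return None
    arr = np.frombuffer(raw, dtype=np.dtype(dtype))
    return arr.reshape(record[f"{name}_shape"])

# Round-trip example with a dummy record
feat = np.arange(6, dtype=np.float32).reshape(2, 3)
record = {
    "clip_feature_bytes": feat.tobytes(),
    "clip_feature_shape": list(feat.shape),
    "clip_feature_dtype": str(feat.dtype),
}
restored = read_array(record, "clip_feature")
```

Note that `np.frombuffer` returns a read-only view over the bytes; copy the result if the consumer mutates it.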
```python
def _prepare_dit_inputs(self, training_batch: TrainingBatch) -> TrainingBatch:
    """Override to properly handle I2V concatenation - call parent first, then concatenate image conditioning."""
    # First, call parent method to prepare noise, timesteps, etc. for video latents
    training_batch = super()._prepare_dit_inputs(training_batch)

    assert isinstance(training_batch.image_latents, torch.Tensor)
    image_latents = training_batch.image_latents.to(get_local_torch_device(), dtype=torch.bfloat16)

    temporal_compression_ratio = 4
    num_frames = (self.training_args.num_latent_t - 1) * temporal_compression_ratio + 1
    batch_size, num_channels, _, latent_height, latent_width = image_latents.shape
    mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
    mask_lat_size[:, :, 1:] = 0

    first_frame_mask = mask_lat_size[:, :, :1]
    first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=temporal_compression_ratio)
    mask_lat_size = torch.cat([first_frame_mask, mask_lat_size[:, :, 1:]], dim=2)
    mask_lat_size = mask_lat_size.view(batch_size, -1, temporal_compression_ratio, latent_height, latent_width)
    mask_lat_size = mask_lat_size.transpose(1, 2)
    mask_lat_size = mask_lat_size.to(image_latents.device).to(dtype=torch.bfloat16)

    training_batch.noisy_model_input = torch.cat(
        [training_batch.noisy_model_input, mask_lat_size, image_latents], dim=1)

    return training_batch
```
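The first-frame mask construction can be illustrated in isolation with toy dimensions (a sketch; the sizes are illustrative and no pipeline code is involved):

```python
import torch

# Toy dimensions; the real pipeline derives num_frames from num_latent_t
batch_size, latent_height, latent_width = 2, 4, 4
temporal_compression_ratio = 4
num_latent_t = 3
num_frames = (num_latent_t - 1) * temporal_compression_ratio + 1  # 9

mask = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
mask[:, :, 1:] = 0  # only the first frame is conditioned

# Repeat the first frame to fill one full temporal-compression window
first = mask[:, :, :1].repeat_interleave(temporal_compression_ratio, dim=2)
mask = torch.cat([first, mask[:, :, 1:]], dim=2)  # 4 + 8 = 12 frames
# Fold the compression window into a channel-like dimension
mask = mask.view(batch_size, -1, temporal_compression_ratio, latent_height, latent_width)
mask = mask.transpose(1, 2)
print(mask.shape)  # torch.Size([2, 4, 3, 4, 4])
```

The result has `temporal_compression_ratio` mask channels over `num_latent_t` latent frames, matching the latent layout that the mask is concatenated with.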
/gemini review
Code Review
This pull request introduces a new ODE trajectory preprocessing pipeline and a corresponding training pipeline for MatrixGame models. It adds a DiffusionForcingScheduler and integrates it into the training process, along with updates to utility functions to support this new scheduler. The changes also include a new shell script for finetuning, a PyArrow schema for ODE trajectory data, and enhancements to the validation logging to include reference videos. Review comments suggest refactoring duplicated code, optimizing timestep calculation, and correcting misleading comments in the code.
```python
# Optional PIL Image
if pil_image is not None:
    record.update({
        "pil_image_bytes": pil_image.tobytes(),
        "pil_image_shape": list(pil_image.shape),
        "pil_image_dtype": str(pil_image.dtype),
    })
else:
    record.update({
        "pil_image_bytes": b"",
        "pil_image_shape": [],
        "pil_image_dtype": "",
    })

# Actions
if keyboard_cond is not None:
    record.update({
        "keyboard_cond_bytes": keyboard_cond.tobytes(),
        "keyboard_cond_shape": list(keyboard_cond.shape),
        "keyboard_cond_dtype": str(keyboard_cond.dtype),
    })
else:
    record.update({
        "keyboard_cond_bytes": b"",
        "keyboard_cond_shape": [],
        "keyboard_cond_dtype": "",
    })

if mouse_cond is not None:
    record.update({
        "mouse_cond_bytes": mouse_cond.tobytes(),
        "mouse_cond_shape": list(mouse_cond.shape),
        "mouse_cond_dtype": str(mouse_cond.dtype),
    })
else:
    record.update({
        "mouse_cond_bytes": b"",
        "mouse_cond_shape": [],
        "mouse_cond_dtype": "",
    })
```
The logic for handling optional numpy arrays (pil_image, keyboard_cond, mouse_cond) is duplicated. This can be refactored into a loop to improve maintainability and reduce boilerplate code.
Suggested change (replacing the three if/else blocks quoted above):

```python
# Optional PIL Image and Actions
for prefix, array in [("pil_image", pil_image), ("keyboard_cond", keyboard_cond), ("mouse_cond", mouse_cond)]:
    if array is not None:
        record.update({
            f"{prefix}_bytes": array.tobytes(),
            f"{prefix}_shape": list(array.shape),
            f"{prefix}_dtype": str(array.dtype),
        })
    else:
        record.update({
            f"{prefix}_bytes": b"",
            f"{prefix}_shape": [],
            f"{prefix}_dtype": "",
        })
```
```python
timestep_id = torch.argmin(
    (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(),
    dim=1,
)
```
Using torch.argmin to find the timestep_id can be inefficient when self.timesteps is large, since it is an O(N) scan. Because self.timesteps is sorted (monotonically decreasing), you can use torch.searchsorted for an O(log N) lookup.
To use torch.searchsorted, you need a monotonically increasing tensor: either flip self.timesteps before searching, or build it ascending in set_timesteps and flip it wherever descending order is expected.
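The flip-then-search approach can be sketched as follows. This is an illustration of the reviewer's suggestion, not code from this PR; the tensor values are made up, and the nearest-neighbour tie-breaking mirrors argmin semantics:

```python
import torch

# Descending timesteps, as a typical scheduler's set_timesteps produces
timesteps = torch.tensor([999., 749., 499., 249., 0.])
query = torch.tensor([500., 10.])

# searchsorted requires ascending input, so search the flipped tensor
# and map the resulting indices back to the descending order.
asc = timesteps.flip(0)
idx = torch.searchsorted(asc, query).clamp(max=len(asc) - 1)
prev = (idx - 1).clamp(min=0)
# pick the nearer of the two neighbours for nearest-match semantics
nearer = torch.where((asc[idx] - query).abs() < (query - asc[prev]).abs(), idx, prev)
timestep_id = len(timesteps) - 1 - nearer  # index into the descending tensor
```

The result agrees with the O(N) `torch.argmin` formulation for every query, while each lookup is a binary search.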
```python
features["clip_feature"] = clip_features
"""Get VAE features from the first frame of each video"""
# Get CLIP features from the first frame of each video.
```
This comment is misleading. The code block that follows is responsible for calculating VAE features, not CLIP features. CLIP features are calculated in the block just before this comment.
Suggested change:

```diff
-# Get CLIP features from the first frame of each video.
+# Get VAE features from the first frame of each video.
```
Yes, this should be VAE. You can fix it.
```python
self.get_module("vae").to(get_local_torch_device())

features = {}
"""Get CLIP features from the first frame of each video."""
```
This appears to be a docstring placed in the middle of a function body. According to Python style guides (like PEP 257), docstrings should only appear as the first statement in a module, function, class, or method definition. For comments within a function, please use the # syntax.
| """Get CLIP features from the first frame of each video.""" | |
| # Get CLIP features from the first frame of each video. |
```python
features["clip_feature"] = clip_features
features["pil_image"] = first_frame
# Get CLIP features from the first frame of each video.
```
This comment is misleading. The code block that follows is responsible for preparing video_conditions to be encoded by the VAE, not for getting CLIP features. CLIP features are calculated in the block before this.
Suggested change:

```diff
-# Get CLIP features from the first frame of each video.
+# Get VAE features from the first frame of each video.
```
```python
logger.info("relevant_traj_latents: %s", relevant_traj_latents.shape)

indexes = self._get_timestep(  # [B, num_frames]
    0, len(self.dmd_denoising_steps), B, num_frames, 3, uniform_timestep=False)
```
How about just using global_step == self.init_steps to check whether the ref video has been logged, instead of setting self.validation_ref_videos_logged = False at init?
Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 PR merge requirements: this rule is failing.
Buildkite CI tests failed

Hi @H1yori233, some Buildkite CI tests have failed. Check the build for details. If the failure is unrelated to your changes, leave a comment explaining why.
Why is sigma_from_timestep needed? Seems timesteps are all precomputed, and the sigmas are stored in scheduler.sigmas based on self.timesteps.
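For reference, a flow-matching style scheduler typically precomputes sigmas alongside timesteps, so the per-step sigma is an index lookup rather than a recomputation. A toy sketch under that assumption (not FastVideo's actual scheduler code; the linear schedule is illustrative):

```python
import torch

# Sigmas precomputed once in set_timesteps; timesteps derived from them,
# so sigmas[i] is always the sigma paired with timesteps[i].
num_steps = 5
sigmas = torch.linspace(1.0, 0.0, num_steps + 1)[:-1]  # descending noise levels
timesteps = sigmas * 1000                              # matching timestep values

step_idx = 2
sigma = sigmas[step_idx]  # direct lookup, no sigma_from_timestep() needed
```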
Reverts the DiffusionForcingScheduler dependency in dfsft.py. The original design, using the student's scheduler, is more generic and works with both flow matching and DDPM models.
Add per-frame Gaussian loss weighting matching Causal-Forcing's bsmntw scheme: mid-noise timesteps get higher weight, while the extremes (near-clean and pure-noise) get lower weight.
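The bell-shaped weighting described above can be sketched as follows. This is a minimal illustration, not the PR's implementation: the function name, normalization, and the mean/std values are assumptions, not the actual bsmntw parameters:

```python
import torch

def gaussian_timestep_weight(timesteps: torch.Tensor,
                             num_train_timesteps: int = 1000,
                             mean: float = 0.5,
                             std: float = 0.25) -> torch.Tensor:
    """Bell-shaped per-frame loss weight over normalized timesteps.

    Mid-noise timesteps (around `mean`) get the highest weight; near-clean
    and pure-noise extremes are down-weighted.
    """
    t = timesteps.float() / num_train_timesteps   # normalize to [0, 1]
    w = torch.exp(-0.5 * ((t - mean) / std) ** 2)
    return w / w.mean()  # keep the average loss scale roughly unchanged

# Per-frame timesteps for a batch of 2 videos with 4 latent frames each
ts = torch.tensor([[0, 250, 500, 999], [100, 400, 600, 900]])
weights = gaussian_timestep_weight(ts)  # shape [2, 4], peak near t = 500
```

The normalization by `w.mean()` keeps the expected loss magnitude stable, so the weighting reshapes the per-frame contribution without changing the effective learning rate.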
Purpose
Add diffusion forcing and ode init for MatrixGame 2.0
Changes
- `fastvideo/models/schedulers/scheduling_diffusion_forcing.py`: scheduler for diffusion forcing.
- `fastvideo/training/matrixgame_ar_diffusion_pipeline.py`: diffusion forcing pipeline.
- `fastvideo/dataset/dataloader/schema.py` and pipeline `fastvideo/training/matrixgame_ode_causal_pipeline.py`: for ODE init.
- `fastvideo/training/training_pipeline.py`: can now upload reference videos to WandB for easy comparison.

Others
This PR is based on the previous training pipeline; changes for the new trainer will be delivered in another PR.
Checklist
Ran `pre-commit run --all-files` and fixed all issues.