Conversation
Summary of Changes

This pull request significantly expands the framework's capabilities by integrating a full training pipeline for Stable Diffusion 3.5. It introduces all necessary components, from data preprocessing that accommodates SD3.5's complex text encoders and VAE outputs, to the core training logic and validation procedures. This enables users to fine-tune SD3.5 models within the existing system, enhancing its versatility for advanced diffusion model research and application.
Code Review
This pull request introduces a comprehensive training pipeline for SD3.5, including data preprocessing, a new training pipeline class, corresponding tests, and execution scripts. The changes are well-structured and cover the full workflow. I've identified a critical bug in the preprocessing script where image width and height are swapped during metadata creation. Additionally, I've provided several medium-severity suggestions to improve maintainability, flexibility, and correctness, such as fixing incorrect command-line arguments, making data types more configurable, and clarifying optimizer settings. Overall, this is a solid contribution that will be even better with these fixes.
```python
"width":
    int(pixel.shape[-2]),
"height":
    int(pixel.shape[-1]),
```
It appears that width and height are swapped here. pixel.shape[-1] corresponds to the width and pixel.shape[-2] to the height of the image tensor. This should be corrected to ensure the metadata is accurate.
Suggested change:

```diff
-"width":
-    int(pixel.shape[-2]),
-"height":
-    int(pixel.shape[-1]),
+"width":
+    int(pixel.shape[-1]),
+"height":
+    int(pixel.shape[-2]),
```
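For anyone double-checking the convention: PyTorch image tensors are laid out channels-first, so the last dimension is width and the second-to-last is height. A minimal sanity check, with a plain tuple standing in for a real tensor's `.shape`:

```python
# (C, H, W) ordering: last dim is width, second-to-last is height.
shape = (3, 480, 640)  # channels=3, height=480, width=640

height = int(shape[-2])  # 480
width = int(shape[-1])   # 640

metadata = {"width": width, "height": height}
print(metadata)  # {'width': 640, 'height': 480}
```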
```python
def initialize_training_pipeline(
        self, training_args: TrainingArgs) -> None:
    original_text_len = (training_args.pipeline_config.text_encoder_configs[
        0].arch_config.text_len)
    training_args.pipeline_config.text_encoder_configs[
        0].arch_config.text_len = SD35_TEXT_SEQ_LEN
    super().initialize_training_pipeline(training_args)
    training_args.pipeline_config.text_encoder_configs[
        0].arch_config.text_len = original_text_len
```
This approach of temporarily modifying the training_args object, calling the superclass method, and then restoring the original value can be brittle. If the base class implementation changes, this could lead to unexpected behavior. Consider if the base class's initialize_training_pipeline could be modified to accept text_len as an optional argument to make this more robust. If not, adding a comment explaining why this is necessary would be helpful for future maintenance.
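One shape that refactor could take, sketched here with stand-in classes — the base-class signature change, the dict-style `training_args`, and the constant's value are all illustrative, not the PR's actual API:

```python
# Illustrative sketch: let the base class accept an optional text_len
# override so subclasses need not mutate-and-restore training_args.
SD35_TEXT_SEQ_LEN = 512  # placeholder value for illustration


class BaseTrainingPipeline:
    def initialize_training_pipeline(self, training_args, text_len=None):
        # Fall back to the configured value when no override is given.
        configured = training_args["text_len"]
        self.text_len = text_len if text_len is not None else configured


class SD35TrainingPipeline(BaseTrainingPipeline):
    def initialize_training_pipeline(self, training_args):
        # Pass the override explicitly instead of patching training_args.
        super().initialize_training_pipeline(
            training_args, text_len=SD35_TEXT_SEQ_LEN)
```

The override is then visible in the call signature, and a change in the base class cannot leave `training_args` in a half-restored state if an exception fires between the mutate and the restore.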
```python
# SD3.5 latents are stored as (C, T, H, W) with T=1
latents = batch["vae_latent"].to(device, dtype=torch.bfloat16)

# Combined CLIP+T5 sequence embeddings
encoder_hidden_states = batch["text_embedding"].to(
    device, dtype=torch.bfloat16)

# Concatenated CLIP pooled projections
pooled_projections = batch["pooled_projection"].to(
    device, dtype=torch.bfloat16)
```
The data type for tensors is hardcoded to torch.bfloat16. This could be made more flexible by respecting the mixed_precision argument from training_args. For instance, you could determine the dtype based on self.training_args.mixed_precision.
Suggested change:

```diff
+dtype = torch.bfloat16 if self.training_args.mixed_precision == "bf16" else torch.float32
 # SD3.5 latents are stored as (C, T, H, W) with T=1
-latents = batch["vae_latent"].to(device, dtype=torch.bfloat16)
+latents = batch["vae_latent"].to(device, dtype=dtype)
 # Combined CLIP+T5 sequence embeddings
 encoder_hidden_states = batch["text_embedding"].to(
-    device, dtype=torch.bfloat16)
+    device, dtype=dtype)
 # Concatenated CLIP pooled projections
 pooled_projections = batch["pooled_projection"].to(
-    device, dtype=torch.bfloat16)
+    device, dtype=dtype)
```
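A slightly more defensive variant of the same idea: the one-liner above treats anything other than `"bf16"` as float32, which would silently ignore `"fp16"`. A small lookup makes the supported flags explicit and fails loudly on unknown ones (dtype names are kept as strings here so the sketch stands alone without torch; the helper name is illustrative):

```python
# Illustrative helper: resolve a mixed_precision flag to a dtype name
# instead of special-casing "bf16" inline.
_PRECISION_TO_DTYPE = {
    "bf16": "bfloat16",
    "fp16": "float16",
    "no": "float32",
}


def resolve_dtype_name(mixed_precision):
    if mixed_precision is None:
        return "float32"
    try:
        return _PRECISION_TO_DTYPE[mixed_precision]
    except KeyError:
        raise ValueError(f"unsupported mixed_precision: {mixed_precision!r}")


# In the training step this could become, e.g.:
#   dtype = getattr(torch, resolve_dtype_name(training_args.mixed_precision))
#   latents = batch["vae_latent"].to(device, dtype=dtype)
```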
```python
betas=betas,
weight_decay=training_args.weight_decay,
eps=1e-8,
foreach=False,
```
The foreach=False argument is set for the AdamW optimizer. While the default foreach=None often enables a faster implementation on CUDA, explicitly disabling it might be intentional for stability or to work around a known issue (e.g., with FSDP). It would be beneficial to add a comment explaining why this non-default option is used for future maintainability.
```bash
--weighting_scheme "logit_normal" \
--logit_mean 0.0 \
--logit_std 1.0 \
```
The command-line arguments --weighting_scheme, --logit_mean, and --logit_std use snake_case. However, they were recently changed to use kebab-case (--weighting-scheme, --logit-mean, --logit-std) in fastvideo_args.py. Please update the script to use the correct argument names to avoid them being ignored or causing errors.
Suggested change:

```diff
---weighting_scheme "logit_normal" \
---logit_mean 0.0 \
---logit_std 1.0 \
+--weighting-scheme "logit_normal" \
+--logit-mean 0.0 \
+--logit-std 1.0 \
```
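For anyone unsure why only the shell script needs updating: argparse normalizes dashes in option names to underscores in the resulting attribute, so the Python side reads `args.logit_mean` either way; only the flag spelling at the CLI changed. A minimal demonstration (the arguments mirror the PR's flags, but this parser is standalone):

```python
# argparse maps "--logit-mean" to the attribute "logit_mean" automatically,
# so renaming a flag changes only what the shell script must pass.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--weighting-scheme", type=str, default="uniform")
parser.add_argument("--logit-mean", type=float, default=0.0)
parser.add_argument("--logit-std", type=float, default=1.0)

args = parser.parse_args(
    ["--weighting-scheme", "logit_normal", "--logit-mean", "0.5"])
print(args.weighting_scheme, args.logit_mean)  # logit_normal 0.5
```

The old `--logit_mean` spelling is simply no longer a recognized option name, which is why the stale flags would error out or be ignored depending on how the parser is configured.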
Pull request overview
This PR adds a complete Stable Diffusion 3.5 fine-tuning pipeline to FastVideo, enabling text-to-image fine-tuning with the three-encoder (CLIP-L + CLIP-G + T5-XXL) conditioning architecture. It integrates into the existing framework alongside other training pipelines (Wan, MatrixGame, LTX2).
Changes:
- New SD3.5 preprocessing pipeline (`PreprocessPipeline_SD35`) and parquet schema (`pyarrow_schema_sd35`) for encoding VAE latents and tri-encoder text embeddings offline
- New `SD35TrainingPipeline` with custom flow-matching training loop, VAE normalization for SD3 (`normalize_dit_input("sd3", ...)`), and validation pipeline integration
- Supporting changes: LoRA support in `SD35Pipeline`, VAE CPU offload in `SD35DecodingStage`, prompt embedding padding for CFG, `foreach=False` in AdamW optimizers, and a DCP save fallback for single-rank training
Reviewed changes
Copilot reviewed 14 out of 15 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| `scripts/preprocess/v1_preprocess_sd35_data.sh` | Shell script to run SD3.5 VAE + text encoding preprocessing |
| `scripts/finetune/finetune_sd35.sh` | Shell script to launch SD3.5 fine-tuning via torchrun |
| `fastvideo/training/training_utils.py` | Adds SD3 latent normalization; adds DCP save error handling for single-rank mode |
| `fastvideo/training/training_pipeline.py` | Adds `foreach=False` to both AdamW optimizers |
| `fastvideo/training/sd35_training_pipeline.py` | New SD3.5 training pipeline class and `main()` entry point |
| `fastvideo/tests/training/sd35/test_sd35_training.py` | Synthetic end-to-end tests for the SD3.5 pipeline |
| `fastvideo/tests/training/sd35/__init__.py` | New test package init file |
| `fastvideo/pipelines/stages/sd35_conditioning.py` | Adds CFG sequence-length padding; casts inputs to target dtype; adds VAE CPU offload in decoding stage |
| `fastvideo/pipelines/preprocess/v1_preprocess.py` | Registers `sd35` task and `PreprocessPipeline_SD35` |
| `fastvideo/pipelines/preprocess/preprocess_pipeline_sd35.py` | New SD3.5 preprocessing pipeline with tri-encoder text encoding |
| `fastvideo/pipelines/pipeline_batch_info.py` | Adds `pooled_projections` field to `TrainingBatch` |
| `fastvideo/pipelines/basic/sd35/sd35_pipeline.py` | Adds `LoRAPipeline` to `SD35Pipeline` MRO for LoRA support |
| `fastvideo/fastvideo_args.py` | Renames `--weighting_scheme`, `--logit_mean`, `--logit_std`, `--mode_scale` args to use dashes |
| `fastvideo/dataset/dataloader/schema.py` | New `pyarrow_schema_sd35` parquet schema definition |
| `examples/training/finetune/sd35/validation.json` | Sample validation prompts for SD3.5 fine-tuning |
```python
if __name__ == "__main__":
    argv = sys.argv
    from fastvideo.fastvideo_args import TrainingArgs
```
The from fastvideo.fastvideo_args import TrainingArgs at line 249 is a redundant import, since TrainingArgs is already imported at the module level (line 11). This dead import inside the if __name__ == "__main__" block adds noise without any effect.
Suggested change:

```diff
-    from fastvideo.fastvideo_args import TrainingArgs
```
```python
if __name__ == "__main__":
    argv = sys.argv
```
The argv = sys.argv assignment at line 248 assigns the result to a local variable that is never read afterwards. The variable is unused, and sys.argv is read by parser.parse_args() automatically. This dead assignment exists consistently in other training pipelines, but it adds clutter here as well.
Suggested change:

```diff
-    argv = sys.argv
```
```python
self.dataset_writer.append_table(table)
logger.info("Collected batch with %s samples", len(table))

if num_processed >= args.flush_frequency:
```
On line 346, self.dataset_writer.flush() is called unconditionally if num_processed >= args.flush_frequency. However, self.dataset_writer is only created lazily on the first non-empty batch (line 337-341). If num_processed somehow reaches flush_frequency before any batch_data is non-empty (edge case), this will raise an AttributeError. The flush should be guarded with if hasattr(self, "dataset_writer"), consistent with the final flush on line 351.
Suggested change:

```diff
-if num_processed >= args.flush_frequency:
+if num_processed >= args.flush_frequency and hasattr(self, "dataset_writer"):
```
```python
except BaseException as e:
    if isinstance(e, (KeyboardInterrupt, SystemExit)):
        raise
    if dist.is_initialized() and dist.get_world_size() == 1:
        logger.warning(
            "rank: %s, distributed checkpoint save failed in single-rank "
            "mode; continuing with consolidated-only save: %s",
            rank,
            e,
        )
    else:
        raise
```
The new exception handling silently suppresses ALL non-keyboard/system exceptions when dist.get_world_size() == 1. This includes OSError (disk full), PermissionError, and other serious I/O failures. These critical errors during checkpoint saving should not be silently logged and swallowed — they indicate that the checkpoint was not actually written, but training continues. If the run later fails, there will be no checkpoint to resume from. Consider re-raising all but a narrow set of known distributed-checkpoint-specific errors (e.g., RuntimeError with a specific message) instead of suppressing all exceptions in single-rank mode.
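A sketch of the narrower handling suggested above, with stand-in save callables and an illustrative heuristic for recognizing DCP-specific failures (the real message patterns would have to come from the errors actually observed; none of these names are from the PR):

```python
# Sketch: only swallow known distributed-checkpoint failures in single-rank
# mode; let genuine I/O errors (disk full, permissions) propagate.
import logging

logger = logging.getLogger(__name__)


def _looks_like_dcp_failure(exc):
    # Illustrative heuristic; tighten to the actual DCP error messages.
    return isinstance(exc, RuntimeError) and "checkpoint" in str(exc).lower()


def save_checkpoint(dcp_save, consolidated_save, world_size, rank):
    try:
        dcp_save()
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception as e:
        if world_size == 1 and _looks_like_dcp_failure(e):
            logger.warning(
                "rank %s: distributed checkpoint save failed in single-rank "
                "mode; continuing with consolidated-only save: %s", rank, e)
        else:
            # OSError, PermissionError, etc. mean no checkpoint was written;
            # surface them rather than training past a silent failure.
            raise
    consolidated_save()
```

With this shape, a disk-full `OSError` still aborts the run, while the known single-rank DCP quirk degrades to the consolidated-only save as intended.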
```python
def _log_validation(self, transformer, training_args,
                    global_step) -> None:
    mp = getattr(training_args, "mixed_precision", None)
    original_precision = training_args.pipeline_config.dit_precision
    if mp and mp != "no":
        training_args.pipeline_config.dit_precision = mp
    try:
        super()._log_validation(transformer, training_args, global_step)
    finally:
        training_args.pipeline_config.dit_precision = original_precision
```
In _log_validation, the dit_precision is set to mixed_precision (e.g., "bf16") for validation but the SD3.5 transformer is a float32 model during training (as configured by --dit_precision fp32). Setting dit_precision to "bf16" causes the validation pipeline to cast inputs to bfloat16, but the shared transformer weights remain in fp32. This mismatch relies on autocast handling the precision correctly. While this may work in practice with autocast enabled, the comment/documentation for this is unclear. It would be safer to either explicitly document this design decision or use "fp32" for validation to be consistent with the training precision.
```python
start_idx = 0
for root, _, files in os.walk(combined_parquet_dir):
    for file in files:
        if file.endswith(".parquet"):
            table = pq.read_table(os.path.join(root, file))
            start_idx += table.num_rows
```
start_idx is computed by counting existing parquet rows (lines 219-224) but is never referenced afterwards in the function. This computation is wasted work on every preprocessing run. If the intent is to resume from a checkpoint, the start_idx should actually be used (e.g., to skip already-processed samples); otherwise this dead code should be removed.
Suggested change:

```diff
-start_idx = 0
-for root, _, files in os.walk(combined_parquet_dir):
-    for file in files:
-        if file.endswith(".parquet"):
-            table = pq.read_table(os.path.join(root, file))
-            start_idx += table.num_rows
```
```python
training_batch = self._get_next_batch(training_batch)
training_batch = self._normalize_dit_input(training_batch)
training_batch = self._prepare_dit_inputs(training_batch)
training_batch = self._build_attention_metadata(training_batch)
```
After _prepare_dit_inputs, training_batch.raw_latent_shape is a 4D tuple (B, C, H, W) (set at line 155). The inherited _build_attention_metadata in the parent class accesses latents_shape[2:5] when VIDEO_SPARSE_ATTN or VMOBA_ATTN is enabled — for a 4D shape this yields only 2 elements (H, W) instead of the expected 3 (T, H, W), which will cause errors when those attention backends are used. This should be documented, and ideally _build_attention_metadata should be overridden to handle the 4D case (or simply return attn_metadata = None for SD3.5 which does not use video-sparse attention).
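The slicing problem is easy to see with plain tuples (shape values here are illustrative):

```python
# Why latents_shape[2:5] assumes video latents: for a 5D (B, C, T, H, W)
# shape it yields (T, H, W), but for SD3.5's 4D (B, C, H, W) it yields
# only (H, W), so downstream code expecting three dims will break.
video_shape = (2, 16, 8, 64, 64)   # (B, C, T, H, W)
image_shape = (2, 16, 64, 64)      # (B, C, H, W)

print(video_shape[2:5])  # (8, 64, 64)
print(image_shape[2:5])  # (64, 64)
```

An override that simply sets `training_batch.attn_metadata = None` for SD3.5, which does not use video-sparse attention, would sidestep the 4D/5D mismatch entirely.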
Is this ready for review?

Yes

Made-with: Cursor
Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 PR merge requirements: this rule is failing.
No description provided.