
Added SD3.5 training pipeline #1150

Open
Ishxn20 wants to merge 2 commits into hao-ai-lab:main from Ishxn20:sd3.5-training-pipeline

Conversation

@Ishxn20
Contributor

@Ishxn20 Ishxn20 commented Mar 2, 2026

No description provided.

Copilot AI review requested due to automatic review settings March 2, 2026 09:27
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the framework's capabilities by integrating a full training pipeline for Stable Diffusion 3.5. It introduces all necessary components, from data preprocessing that accommodates SD3.5's complex text encoders and VAE outputs, to the core training logic and validation procedures. This enables users to fine-tune SD3.5 models within the existing system, enhancing its versatility for advanced diffusion model research and application.

Highlights

  • SD3.5 Training Pipeline: Introduced a complete training pipeline for Stable Diffusion 3.5, including dedicated preprocessing, training, and validation components, enabling fine-tuning of SD3.5 models.
  • Multi-Encoder Text Conditioning: Implemented specific preprocessing logic for SD3.5's unique multi-encoder text conditioning, handling CLIP-L, CLIP-G, and T5-XXL outputs, and combining them into a unified format for training.
  • Data Schema and Batch Handling: Defined a new PyArrow schema (pyarrow_schema_sd35) to store SD3.5 specific data, such as VAE latents, combined text embeddings, and pooled CLIP projections, and updated batch information to include these new fields.
  • CLI Argument Standardization: Standardized command-line argument names in fastvideo_args.py by replacing underscores with hyphens for consistency.
  • Distributed Checkpointing Robustness: Improved the robustness of distributed checkpoint saving by adding error handling for dcp.save operations, particularly for single-rank scenarios.


Changelog
  • examples/training/finetune/sd35/validation.json
    • Added a new validation JSON file with example captions and parameters for SD3.5.
  • fastvideo/dataset/dataloader/schema.py
    • Added pyarrow_schema_sd35 to define the data structure for SD3.5 training, including VAE latents, combined text embeddings, and pooled CLIP projections.
  • fastvideo/fastvideo_args.py
    • Updated CLI argument names for weighting_scheme, logit_mean, logit_std, and mode_scale to use hyphens instead of underscores for consistency.
  • fastvideo/pipelines/basic/sd35/sd35_pipeline.py
    • Imported LoRAPipeline to enable LoRA support.
    • Modified SD35Pipeline to inherit from LoRAPipeline, adding LoRA capabilities to the SD3.5 pipeline.
  • fastvideo/pipelines/pipeline_batch_info.py
    • Added pooled_projections field to the TrainingBatch dataclass to accommodate SD3.5's text conditioning.
  • fastvideo/pipelines/preprocess/preprocess_pipeline_sd35.py
    • Added a new preprocessing pipeline (PreprocessPipeline_SD35) for SD3.5, handling VAE encoding and three-encoder text conditioning (CLIP-L, CLIP-G, T5-XXL).
    • Implemented custom logic to store VAE latents with a temporal dimension of 1 for SD3.5 images.
  • fastvideo/pipelines/preprocess/v1_preprocess.py
    • Imported PreprocessPipeline_SD35.
    • Added 'sd35' as a valid option for the --preprocess_task argument.
    • Configured vae_precision to 'fp32' when preprocess_task is 'sd35'.
  • fastvideo/pipelines/stages/sd35_conditioning.py
    • Added padding logic to ensure prompt_embeds and neg_prompt_embeds have matching sequence lengths for CFG.
    • Ensured hidden_states, encoder_hidden_states, and pooled_projections are cast to the correct target_dtype before transformer forward pass.
    • Implemented VAE device management, moving it to the GPU for decoding and offloading to CPU if vae_cpu_offload is enabled.
  • fastvideo/tests/training/sd35/test_sd35_training.py
    • Added comprehensive synthetic end-to-end tests for the SD3.5 training pipeline, covering schema validation, parquet data roundtrip, input normalization, noisy input generation, input kwargs construction, transformer forward passes, and loss computation.
  • fastvideo/training/sd35_training_pipeline.py
    • Added SD35TrainingPipeline to manage the end-to-end training process for SD3.5 models.
    • Configured the training pipeline to use pyarrow_schema_sd35 for dataset loading.
    • Implemented specific logic for initializing the scheduler and handling text encoder sequence lengths for SD3.5.
    • Integrated a validation pipeline for SD3.5, leveraging the SD35Pipeline for inference.
    • Defined methods for preparing DIT inputs, including squeezing the temporal dimension of latents and building input keyword arguments with pooled projections.
  • fastvideo/training/training_pipeline.py
    • Added foreach=False to AdamW and AdamW8bit optimizer initializations, potentially addressing distributed training compatibility or performance.
  • fastvideo/training/training_utils.py
    • Wrapped dcp.save with a try-except block to gracefully handle distributed checkpoint saving failures, especially in single-rank setups.
    • Extended normalize_dit_input to support 'sd3' model type, applying VAE's shift_factor and scaling_factor.
  • scripts/finetune/finetune_sd35.sh
    • Added a new shell script to facilitate fine-tuning of Stable Diffusion 3.5 Medium models.
  • scripts/preprocess/v1_preprocess_sd35_data.sh
    • Added a new shell script for preprocessing images for SD3.5 fine-tuning, including VAE and multi-encoder text encoding.
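The changelog mentions that normalize_dit_input gained an 'sd3' branch that applies the VAE's shift_factor and scaling_factor. A minimal sketch of that normalization, assuming the standard SD3 VAE convention (the function name here is illustrative; the real code reads the factors from the loaded VAE config rather than hardcoding them):

```python
import torch

def normalize_sd3_latents(latents: torch.Tensor,
                          shift_factor: float = 0.0609,
                          scaling_factor: float = 1.5305) -> torch.Tensor:
    # SD3's VAE latents are shifted, then scaled, before entering the DiT.
    # The defaults are the published SD3 VAE config values.
    return (latents - shift_factor) * scaling_factor

# A zero latent maps to -shift_factor * scaling_factor everywhere.
latents = torch.zeros(2, 16, 1, 32, 32)  # (B, C, T, H, W) with T=1 for images
normalized = normalize_sd3_latents(latents)
```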
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a comprehensive training pipeline for SD3.5, including data preprocessing, a new training pipeline class, corresponding tests, and execution scripts. The changes are well-structured and cover the full workflow. I've identified a critical bug in the preprocessing script where image width and height are swapped during metadata creation. Additionally, I've provided several medium-severity suggestions to improve maintainability, flexibility, and correctness, such as fixing incorrect command-line arguments, making data types more configurable, and clarifying optimizer settings. Overall, this is a solid contribution that will be even better with these fixes.

Comment on lines +322 to +325
"width":
int(pixel.shape[-2]),
"height":
int(pixel.shape[-1]),

critical

It appears that width and height are swapped here. pixel.shape[-1] corresponds to the width and pixel.shape[-2] to the height of the image tensor. This should be corrected to ensure the metadata is accurate.

Suggested change
-    "width": int(pixel.shape[-2]),
-    "height": int(pixel.shape[-1]),
+    "width": int(pixel.shape[-1]),
+    "height": int(pixel.shape[-2]),

Comment on lines +39 to +47
def initialize_training_pipeline(
        self, training_args: TrainingArgs) -> None:
    original_text_len = (training_args.pipeline_config.text_encoder_configs[
        0].arch_config.text_len)
    training_args.pipeline_config.text_encoder_configs[
        0].arch_config.text_len = SD35_TEXT_SEQ_LEN
    super().initialize_training_pipeline(training_args)
    training_args.pipeline_config.text_encoder_configs[
        0].arch_config.text_len = original_text_len

medium

This approach of temporarily modifying the training_args object, calling the superclass method, and then restoring the original value can be brittle. If the base class implementation changes, this could lead to unexpected behavior. Consider if the base class's initialize_training_pipeline could be modified to accept text_len as an optional argument to make this more robust. If not, adding a comment explaining why this is necessary would be helpful for future maintenance.
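One way to make the patch/restore pattern exception-safe without changing the base class is a small context manager. This is a hypothetical helper, not code from the PR; the name override_text_len and the _DemoConfig class are illustrative:

```python
from contextlib import contextmanager

@contextmanager
def override_text_len(arch_config, text_len):
    """Hypothetical helper: temporarily override arch_config.text_len and
    restore the original value even if the wrapped call raises."""
    original = arch_config.text_len
    arch_config.text_len = text_len
    try:
        yield
    finally:
        arch_config.text_len = original

class _DemoConfig:
    # Illustrative stand-in for text_encoder_configs[0].arch_config.
    text_len = 77

with override_text_len(_DemoConfig, 154):
    assert _DemoConfig.text_len == 154  # patched inside the block
assert _DemoConfig.text_len == 77       # restored afterwards
```

The subclass could then wrap super().initialize_training_pipeline(training_args) in this context manager, so the original text_len survives even if initialization fails partway through.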

Comment on lines +92 to +101
# SD3.5 latents are stored as (C, T, H, W) with T=1
latents = batch["vae_latent"].to(device, dtype=torch.bfloat16)

# Combined CLIP+T5 sequence embeddings
encoder_hidden_states = batch["text_embedding"].to(
    device, dtype=torch.bfloat16)

# Concatenated CLIP pooled projections
pooled_projections = batch["pooled_projection"].to(
    device, dtype=torch.bfloat16)

medium

The data type for tensors is hardcoded to torch.bfloat16. This could be made more flexible by respecting the mixed_precision argument from training_args. For instance, you could determine the dtype based on self.training_args.mixed_precision.

Suggested change
-# SD3.5 latents are stored as (C, T, H, W) with T=1
-latents = batch["vae_latent"].to(device, dtype=torch.bfloat16)
-# Combined CLIP+T5 sequence embeddings
-encoder_hidden_states = batch["text_embedding"].to(
-    device, dtype=torch.bfloat16)
-# Concatenated CLIP pooled projections
-pooled_projections = batch["pooled_projection"].to(
-    device, dtype=torch.bfloat16)
+dtype = torch.bfloat16 if self.training_args.mixed_precision == "bf16" else torch.float32
+# SD3.5 latents are stored as (C, T, H, W) with T=1
+latents = batch["vae_latent"].to(device, dtype=dtype)
+# Combined CLIP+T5 sequence embeddings
+encoder_hidden_states = batch["text_embedding"].to(
+    device, dtype=dtype)
+# Concatenated CLIP pooled projections
+pooled_projections = batch["pooled_projection"].to(
+    device, dtype=dtype)

    betas=betas,
    weight_decay=training_args.weight_decay,
    eps=1e-8,
    foreach=False,

medium

The foreach=False argument is set for the AdamW optimizer. While the default foreach=None often enables a faster implementation on CUDA, explicitly disabling it might be intentional for stability or to work around a known issue (e.g., with FSDP). It would be beneficial to add a comment explaining why this non-default option is used for future maintainability.
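A minimal illustration of the non-default option in question. The comment in the sketch states the common motivation (sharded/FSDP parameters) as a plausible reason, not as the PR author's confirmed rationale:

```python
import torch

# foreach=None (the default) lets PyTorch pick the fused multi-tensor path on
# CUDA; foreach=False forces a per-parameter loop. Disabling it is a common
# workaround when parameters are sharded (e.g. FSDP/DTensor) and the
# multi-tensor path is unsupported -- worth a code comment either way.
param = torch.nn.Parameter(torch.ones(4))
optimizer = torch.optim.AdamW([param], lr=1e-2, betas=(0.9, 0.999),
                              weight_decay=0.01, eps=1e-8, foreach=False)

loss = (param ** 2).sum()
loss.backward()
optimizer.step()  # parameters move toward zero
```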

Comment on lines +55 to +57
--weighting_scheme "logit_normal" \
--logit_mean 0.0 \
--logit_std 1.0 \

medium

The command-line arguments --weighting_scheme, --logit_mean, and --logit_std use snake_case. However, they were recently changed to use kebab-case (--weighting-scheme, --logit-mean, --logit-std) in fastvideo_args.py. Please update the script to use the correct argument names to avoid them being ignored or causing errors.

Suggested change
---weighting_scheme "logit_normal" \
---logit_mean 0.0 \
---logit_std 1.0 \
+--weighting-scheme "logit_normal" \
+--logit-mean 0.0 \
+--logit-std 1.0 \

Copilot AI left a comment

Pull request overview

This PR adds a complete Stable Diffusion 3.5 fine-tuning pipeline to FastVideo, enabling text-to-image fine-tuning with the three-encoder (CLIP-L + CLIP-G + T5-XXL) conditioning architecture. It integrates into the existing framework alongside other training pipelines (Wan, MatrixGame, LTX2).

Changes:

  • New SD3.5 preprocessing pipeline (PreprocessPipeline_SD35) and parquet schema (pyarrow_schema_sd35) for encoding VAE latents and tri-encoder text embeddings offline
  • New SD35TrainingPipeline with custom flow-matching training loop, VAE normalization for SD3 (normalize_dit_input("sd3", ...)), and validation pipeline integration
  • Supporting changes: LoRA support in SD35Pipeline, VAE CPU offload in SD35DecodingStage, prompt embedding padding for CFG, foreach=False in AdamW optimizers, and a DCP save fallback for single-rank training

Reviewed changes

Copilot reviewed 14 out of 15 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
scripts/preprocess/v1_preprocess_sd35_data.sh Shell script to run SD3.5 VAE + text encoding preprocessing
scripts/finetune/finetune_sd35.sh Shell script to launch SD3.5 fine-tuning via torchrun
fastvideo/training/training_utils.py Adds SD3 latent normalization; adds DCP save error handling for single-rank mode
fastvideo/training/training_pipeline.py Adds foreach=False to both AdamW optimizers
fastvideo/training/sd35_training_pipeline.py New SD3.5 training pipeline class and main() entry point
fastvideo/tests/training/sd35/test_sd35_training.py Synthetic end-to-end tests for the SD3.5 pipeline
fastvideo/tests/training/sd35/__init__.py New test package init file
fastvideo/pipelines/stages/sd35_conditioning.py Adds CFG sequence-length padding; casts inputs to target dtype; adds VAE CPU offload in decoding stage
fastvideo/pipelines/preprocess/v1_preprocess.py Registers sd35 task and PreprocessPipeline_SD35
fastvideo/pipelines/preprocess/preprocess_pipeline_sd35.py New SD3.5 preprocessing pipeline with tri-encoder text encoding
fastvideo/pipelines/pipeline_batch_info.py Adds pooled_projections field to TrainingBatch
fastvideo/pipelines/basic/sd35/sd35_pipeline.py Adds LoRAPipeline to SD35Pipeline MRO for LoRA support
fastvideo/fastvideo_args.py Renames --weighting_scheme, --logit_mean, --logit_std, --mode_scale args to use dashes
fastvideo/dataset/dataloader/schema.py New pyarrow_schema_sd35 parquet schema definition
examples/training/finetune/sd35/validation.json Sample validation prompts for SD3.5 fine-tuning



if __name__ == "__main__":
    argv = sys.argv
    from fastvideo.fastvideo_args import TrainingArgs
Copilot AI Mar 2, 2026

The from fastvideo.fastvideo_args import TrainingArgs at line 249 is a redundant import, since TrainingArgs is already imported at the module level (line 11). This dead import inside the if __name__ == "__main__" block adds noise without any effect.

Suggested change
-    from fastvideo.fastvideo_args import TrainingArgs



if __name__ == "__main__":
    argv = sys.argv
Copilot AI Mar 2, 2026

The argv = sys.argv assignment at line 248 assigns the result to a local variable that is never read afterwards. The variable is unused, and sys.argv is read by parser.parse_args() automatically. This dead assignment exists consistently in other training pipelines, but it adds clutter here as well.

Suggested change
-    argv = sys.argv

self.dataset_writer.append_table(table)
logger.info("Collected batch with %s samples", len(table))

if num_processed >= args.flush_frequency:
Copilot AI Mar 2, 2026

On line 346, self.dataset_writer.flush() is called unconditionally if num_processed >= args.flush_frequency. However, self.dataset_writer is only created lazily on the first non-empty batch (line 337-341). If num_processed somehow reaches flush_frequency before any batch_data is non-empty (edge case), this will raise an AttributeError. The flush should be guarded with if hasattr(self, "dataset_writer"), consistent with the final flush on line 351.

Suggested change
-if num_processed >= args.flush_frequency:
+if num_processed >= args.flush_frequency and hasattr(self, "dataset_writer"):

Comment on lines +155 to +166
except BaseException as e:
    if isinstance(e, (KeyboardInterrupt, SystemExit)):
        raise
    if dist.is_initialized() and dist.get_world_size() == 1:
        logger.warning(
            "rank: %s, distributed checkpoint save failed in single-rank "
            "mode; continuing with consolidated-only save: %s",
            rank,
            e,
        )
    else:
        raise
Copilot AI Mar 2, 2026

The new exception handling silently suppresses ALL non-keyboard/system exceptions when dist.get_world_size() == 1. This includes OSError (disk full), PermissionError, and other serious I/O failures. These critical errors during checkpoint saving should not be silently logged and swallowed — they indicate that the checkpoint was not actually written, but training continues. If the run later fails, there will be no checkpoint to resume from. Consider re-raising all but a narrow set of known distributed-checkpoint-specific errors (e.g., RuntimeError with a specific message) instead of suppressing all exceptions in single-rank mode.
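A sketch of the narrower handling suggested above, under stated assumptions: save_fn stands in for the dcp.save call, and the helper name is illustrative rather than FastVideo API. I/O errors always propagate (no checkpoint was written), and only RuntimeErrors in single-rank mode are downgraded to a warning:

```python
import logging

import torch.distributed as dist

logger = logging.getLogger(__name__)

def save_checkpoint_with_fallback(save_fn, rank):
    """Illustrative narrower fallback for single-rank dcp.save failures."""
    try:
        save_fn()
    except OSError:
        # Disk-full / permission problems (PermissionError is an OSError)
        # mean the checkpoint was not written -- fatal even on one rank.
        raise
    except RuntimeError as e:
        if dist.is_initialized() and dist.get_world_size() == 1:
            logger.warning(
                "rank %s: distributed checkpoint save failed in single-rank "
                "mode; continuing with consolidated-only save: %s", rank, e)
        else:
            raise
```

Catching RuntimeError specifically (or an even narrower dcp-specific error type, if one exists) keeps genuine I/O failures loud while still tolerating the known single-rank quirk.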

Comment on lines +69 to +78
def _log_validation(self, transformer, training_args,
                    global_step) -> None:
    mp = getattr(training_args, "mixed_precision", None)
    original_precision = training_args.pipeline_config.dit_precision
    if mp and mp != "no":
        training_args.pipeline_config.dit_precision = mp
    try:
        super()._log_validation(transformer, training_args, global_step)
    finally:
        training_args.pipeline_config.dit_precision = original_precision
Copilot AI Mar 2, 2026

In _log_validation, the dit_precision is set to mixed_precision (e.g., "bf16") for validation but the SD3.5 transformer is a float32 model during training (as configured by --dit_precision fp32). Setting dit_precision to "bf16" causes the validation pipeline to cast inputs to bfloat16, but the shared transformer weights remain in fp32. This mismatch relies on autocast handling the precision correctly. While this may work in practice with autocast enabled, the comment/documentation for this is unclear. It would be safer to either explicitly document this design decision or use "fp32" for validation to be consistent with the training precision.

Comment on lines +219 to +225
start_idx = 0
for root, _, files in os.walk(combined_parquet_dir):
    for file in files:
        if file.endswith(".parquet"):
            table = pq.read_table(os.path.join(root, file))
            start_idx += table.num_rows

Copilot AI Mar 2, 2026

start_idx is computed by counting existing parquet rows (lines 219-224) but is never referenced afterwards in the function. This computation is wasted work on every preprocessing run. If the intent is to resume from a checkpoint, the start_idx should actually be used (e.g., to skip already-processed samples); otherwise this dead code should be removed.

Suggested change
-start_idx = 0
-for root, _, files in os.walk(combined_parquet_dir):
-    for file in files:
-        if file.endswith(".parquet"):
-            table = pq.read_table(os.path.join(root, file))
-            start_idx += table.num_rows

training_batch = self._get_next_batch(training_batch)
training_batch = self._normalize_dit_input(training_batch)
training_batch = self._prepare_dit_inputs(training_batch)
training_batch = self._build_attention_metadata(training_batch)
Copilot AI Mar 2, 2026

After _prepare_dit_inputs, training_batch.raw_latent_shape is a 4D tuple (B, C, H, W) (set at line 155). The inherited _build_attention_metadata in the parent class accesses latents_shape[2:5] when VIDEO_SPARSE_ATTN or VMOBA_ATTN is enabled — for a 4D shape this yields only 2 elements (H, W) instead of the expected 3 (T, H, W), which will cause errors when those attention backends are used. This should be documented, and ideally _build_attention_metadata should be overridden to handle the 4D case (or simply return attn_metadata = None for SD3.5 which does not use video-sparse attention).
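The override suggested above can be sketched as follows. The _BatchStub dataclass is a minimal stand-in for FastVideo's TrainingBatch, and the mixin class name is illustrative; only the method name _build_attention_metadata comes from the review comment:

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class _BatchStub:
    # Minimal stand-in for fastvideo's TrainingBatch, for illustration only.
    raw_latent_shape: tuple
    attn_metadata: Any = None

class SD35AttentionMetadataOverride:
    """Sketch: SD3.5 image latents are 4D (B, C, H, W), so the parent's
    video-sparse-attention path (which slices latents_shape[2:5] expecting
    (T, H, W)) does not apply. Returning no metadata keeps dense attention."""

    def _build_attention_metadata(self, training_batch):
        training_batch.attn_metadata = None  # SD3.5 never uses video-sparse attention
        return training_batch

batch = SD35AttentionMetadataOverride()._build_attention_metadata(
    _BatchStub(raw_latent_shape=(2, 16, 64, 64)))
```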

@jzhang38
Collaborator

jzhang38 commented Mar 2, 2026

Is this ready for review?

@Ishxn20
Contributor Author

Ishxn20 commented Mar 2, 2026

Is this ready for review?

Yes

@mergify mergify bot added the scope: training (Training pipeline, methods, configs), scope: inference (Inference pipeline, serving, CLI), and scope: data (Data preprocessing, datasets) labels on Apr 12, 2026
@mergify
Contributor

mergify bot commented Apr 12, 2026

⚠️ PR title format required

Your PR title must start with a type tag in brackets. Examples:

  • [feat] Add new model support
  • [bugfix] Fix VAE tiling corruption
  • [refactor] Restructure training pipeline
  • [perf] Optimize attention kernel
  • [ci] Update test infrastructure
  • [docs] Add inference guide
  • [misc] Clean up configs
  • [new-model] Port Flux2 to FastVideo

Valid tags: feat, feature, bugfix, fix, refactor, perf, ci, doc, docs, misc, chore, kernel, new-model

Please update your PR title and the merge protection check will pass automatically.

@mergify mergify bot added the scope: infra (CI, tests, Docker, build) label on Apr 12, 2026
@mergify
Contributor

mergify bot commented Apr 12, 2026

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 PR merge requirements

Waiting for:

  • #approved-reviews-by>=1
  • check-success=full-suite-passed
  • check-success~=pre-commit
  • title~=(?i)^\[(feat|feature|bugfix|fix|refactor|perf|ci|doc|docs|misc|chore|kernel|new.?model)\] — this rule is failing
  • check-success=fastcheck-passed


Labels

scope: data (Data preprocessing, datasets), scope: inference (Inference pipeline, serving, CLI), scope: infra (CI, tests, Docker, build), scope: training (Training pipeline, methods, configs)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants