[Bug] Failed to distill Wan2.1-Fun-1.3B-InP #1229

@InuyashaLee

Describe the bug

I tried to distill Wan2.1-Fun-1.3B-InP using your dataset Wan-Syn_77x448x832_600k, but the videos generated after distillation are nearly solid color. I tried two configurations on 8x8 A800 (64 GPUs): lr=2e-6 with max_train_steps=4000, and lr=1e-5 with max_train_steps=10000. The result was the same in both cases. Here are the validation videos:

lr=2e-6:
step=0: https://github.com/user-attachments/assets/fab61e88-af68-4bef-9285-aad45058bdbd
step=500: https://github.com/user-attachments/assets/43227023-2d65-4d0e-b052-10c75151451d
step=4000: https://github.com/user-attachments/assets/834f95ea-f5ff-4123-961a-aee3068f9306

lr=1e-5:
step=0: https://github.com/user-attachments/assets/00cc5005-f680-4fee-ae49-9e3d3b47e485
step=1000: https://github.com/user-attachments/assets/d032fc5a-daea-4fa0-85b5-6955acda48db
step=10000: https://github.com/user-attachments/assets/a2873e08-2f41-4bd8-8e79-46d857dd867f

Reproduction

export WANDB_BASE_URL="https://api.wandb.ai"
export WANDB_MODE=offline
export TOKENIZERS_PARALLELISM=false
export FASTVIDEO_ATTENTION_BACKEND=TORCH_SDPA

MODEL_PATH="weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers"
REAL_SCORE_MODEL_PATH="weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers"
FAKE_SCORE_MODEL_PATH="weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers"
DATA_DIR="FastVideo/Wan-Syn_77x448x832_600k"
VALIDATION_DATASET_FILE="examples/training/finetune/Wan2.1-Fun-1.3B-InP/crush_smol/validation.json"

# Training arguments

training_args=(
--tracker_project_name "wan_i2v_inp_dmd_600k"
--output_dir "checkpoints/wan_i2v_inp_dmd_600k"
--max_train_steps 4000
--train_batch_size 1
--train_sp_batch_size 1
--gradient_accumulation_steps 1
--num_latent_t 20
--num_height 480
--num_width 832
--num_frames 77
--enable_gradient_checkpointing_type "full"
)

# Parallel arguments

parallel_args=(
--num_gpus 64
--sp_size 1
--tp_size 1
--hsdp_replicate_dim 64
--hsdp_shard_dim 1
)

# Model arguments

model_args=(
--model_path $MODEL_PATH
--pretrained_model_name_or_path $MODEL_PATH
--real_score_model_path $REAL_SCORE_MODEL_PATH
--fake_score_model_path $FAKE_SCORE_MODEL_PATH
)

# Dataset arguments

dataset_args=(
--data_path "$DATA_DIR"
--dataloader_num_workers 4
)

# Validation arguments

validation_args=(
--log_validation
--validation_dataset_file "$VALIDATION_DATASET_FILE"
--validation_steps 500
--validation_sampling_steps "3"
--validation_guidance_scale "6.0"
)

# Optimizer arguments

optimizer_args=(
--learning_rate 2e-6
--mixed_precision "bf16"
--weight_only_checkpointing_steps 0
--training_state_checkpointing_steps 500
--weight_decay 0.01
--max_grad_norm 1.0
)

# Miscellaneous arguments

miscellaneous_args=(
--inference_mode False
--checkpoints_total_limit 10
--training_cfg_rate 0.0
--dit_precision "fp32"
--ema_start_step 0
--flow_shift 8
--seed 1000
)

# DMD arguments

dmd_args=(
--dmd_denoising_steps '1000,757,522'
--min_timestep_ratio 0.02
--max_timestep_ratio 0.98
--generator_update_interval 5
--real_score_guidance_scale 3.5
)

torchrun \
    --nnodes=$WORLD_SIZE \
    --nproc_per_node=$RESOURCE_GPU \
    --node_rank=$RANK \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
    fastvideo/training/wan_i2v_distillation_pipeline.py \
    "${parallel_args[@]}" \
    "${model_args[@]}" \
    "${dataset_args[@]}" \
    "${training_args[@]}" \
    "${optimizer_args[@]}" \
    "${validation_args[@]}" \
    "${miscellaneous_args[@]}" \
    "${dmd_args[@]}"
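
Several values in the script above can be cross-checked with simple arithmetic before launching. The following is a minimal sanity-check sketch, not part of FastVideo: all helper names are hypothetical, the 4x temporal compression of the VAE is an assumption, and so is the reading of the dataset name as frames x height x width. It also assumes the generator is updated once every `--generator_update_interval` training steps.

```shell
#!/usr/bin/env bash
# Hypothetical pre-launch checks; the numbers are copied from the script above.

# Assumption: 4x temporal VAE compression, so latent frames = (frames - 1) / 4 + 1.
latent_frames() {
  local num_frames=$1 temporal_stride=${2:-4}
  echo $(( (num_frames - 1) / temporal_stride + 1 ))
}

# Assumption: the dataset name encodes <frames>x<height>x<width> between underscores.
dataset_shape() {
  local shape=${1#*_}   # "Wan-Syn_77x448x832_600k" -> "77x448x832_600k"
  shape=${shape%%_*}    # -> "77x448x832"
  echo "${shape//x/ }"  # -> "77 448 832"
}

# Number of comma-separated entries in --dmd_denoising_steps.
count_denoising_steps() {
  local IFS=','
  local -a steps
  read -r -a steps <<< "$1"
  echo "${#steps[@]}"
}

# Assumption: one generator update every <interval> training steps.
generator_updates() {
  local max_train_steps=$1 interval=$2
  echo $(( max_train_steps / interval ))
}

read -r ds_frames ds_height ds_width <<< "$(dataset_shape "Wan-Syn_77x448x832_600k")"
echo "latent frames: $(latent_frames 77) (script passes --num_latent_t 20)"
echo "dataset height: ${ds_height} (script passes --num_height 480)"
echo "denoising steps: $(count_denoising_steps '1000,757,522') (script passes --validation_sampling_steps 3)"
echo "generator updates: $(generator_updates 4000 5) of 4000 training steps"
```

With the values from this issue, the sketch reports 20 latent frames, a dataset height of 448 against `--num_height 480`, 3 denoising steps, and 800 generator updates. Whether the 448/480 difference contributes to the solid-color output is not established here; it is only a discrepancy worth ruling out.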

Environment

8 nodes x 8 A800 GPUs (64 GPUs total)
Python 3.12.12
fastvideo 0.1.7
CUDA 12.1
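
The hardware listed here and the parallel arguments in the reproduction script describe the same 64-GPU layout, which can be verified with a short check. This is a hedged sketch with hypothetical helper names. It assumes that `--hsdp_replicate_dim` times `--hsdp_shard_dim` should equal `--num_gpus`, and that the global batch size is the number of data-parallel ranks (`num_gpus / sp_size`) times the per-rank batch size times the gradient accumulation steps.

```shell
#!/usr/bin/env bash
# Hypothetical layout checks; the numbers are copied from the reproduction script.

# Succeeds when nodes * GPUs-per-node matches the expected GPU count.
check_world() {
  local nnodes=$1 gpus_per_node=$2 expected=$3
  [ $(( nnodes * gpus_per_node )) -eq "$expected" ]
}

# Assumption: the HSDP mesh (replicate x shard) must cover every GPU exactly once.
check_hsdp() {
  local num_gpus=$1 replicate_dim=$2 shard_dim=$3
  [ $(( replicate_dim * shard_dim )) -eq "$num_gpus" ]
}

# Assumption: global batch = (num_gpus / sp_size) * per-rank batch * accumulation steps.
global_batch_size() {
  local num_gpus=$1 sp_size=$2 batch=$3 accum=$4
  echo $(( num_gpus / sp_size * batch * accum ))
}

check_world 8 8 64 && echo "8 nodes x 8 GPUs matches --num_gpus 64"
check_hsdp 64 64 1 && echo "HSDP mesh 64 x 1 covers 64 GPUs"
echo "global batch size: $(global_batch_size 64 1 1 1)"
```

Under these assumptions the run uses a global batch size of 64, and `--hsdp_shard_dim 1` means the model is replicated on every GPU rather than sharded.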
