# VASA Overfit Configuration - For fast convergence testing
defaults:
  - _self_
# Model architecture - REDUCED for faster convergence
model:
  hidden_dim: 512
  n_heads: 8
  n_layers: 8 # Reduce to 1 for ultra-fast training (2 or more gives better quality)
  dim_feedforward: 2048
  dropout: 0.05
  expression_dim: 128 # Expression embedding dimension from volumetric avatar
  use_derived_warps: true # Use identity-conditioned warp generation via volumetric_avatar pipeline
  # Add these:
  condition_embedding_dim: 512 # Dimension for condition embeddings
  use_prev_motion: true # Whether to use previous motion context
  max_motion_length: 1000 # Maximum motion sequence length
  use_relative_position: true # Use relative position encoding
  # Audio projection settings (moved from channel_config.yaml)
  use_talkvid_audio_projection: true # Use the TalkVid-style Perceiver architecture instead of the JoyVASA linear layer
# Diffusion parameters - SIMPLIFIED
diffusion:
  num_steps: 1000 # Increased to 1000 for better sampling (generation previously used 50)
  beta_start: 1e-4
  beta_end: 0.02
  cfg_start_epoch: 20 # When to start using CFG (reduced from 100 to prevent early over-conditioning)
  schedule_mode: 'linear' # Changed from 'cosine' to fix DDIMScheduler error
  schedule_s: 0.008 # Schedule parameter (for the cosine schedule)
  # Noise injection for motion variance (prevents collapse); see the sketch below
  inject_noise_early_steps: true # Add extra noise in early diffusion steps to encourage variance
  noise_injection_epochs: 30 # Inject noise for the first N epochs
  noise_injection_scale: 0.05 # Scale of additional noise (0.05 = 5% of signal)
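  # Illustrative sketch (not from this repo; q_sample and all names below are
  # hypothetical) of how the linear beta schedule plus early-epoch noise
  # injection above could look:
  #   import torch
  #   betas = torch.linspace(1e-4, 0.02, 1000)             # beta_start -> beta_end
  #   alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)   # cumulative alpha_bar_t
  #   def q_sample(x0, t, epoch):
  #       noise = torch.randn_like(x0)
  #       if epoch < 30:                                   # noise_injection_epochs
  #           noise = noise + 0.05 * torch.randn_like(x0)  # noise_injection_scale
  #       return (alphas_cumprod[t].sqrt() * x0
  #               + (1.0 - alphas_cumprod[t]).sqrt() * noise)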
# Motion generation parameters - SMALLER WINDOWS FOR TDD
motion:
  num_speed_buckets: 9
  window_size: 50 # Window length in frames (reduce to ~20 for faster training)
  stride: 25 # 50% window overlap (reduce to ~10 for more overlap)
  context_size: 10
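  # Sketch (hypothetical helper, assuming standard sliding-window slicing over
  # the per-frame motion features):
  #   def make_windows(seq, window_size=50, stride=25):
  #       return [seq[i:i + window_size]
  #               for i in range(0, len(seq) - window_size + 1, stride)]
  # With stride = window_size // 2 (50/25 above), consecutive windows overlap
  # by 50%.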
# Add these:
warmup_steps: 500 # Extended warmup for VASA-1 alignment
scheduler:
  type: "cosine" # Learning rate scheduler type
  min_lr: 1e-4 # Raised floor so the LR does not decay too low
  warmup_ratio: 0.1 # Extended warmup ratio for stability
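# Sketch (hypothetical helper; assumes linear warmup followed by cosine decay
# from the base LR of 5e-4 down to min_lr):
#   import math
#   def lr_at(step, total_steps, base_lr=5e-4, min_lr=1e-4, warmup_steps=500):
#       if step < warmup_steps:                       # linear warmup
#           return base_lr * step / warmup_steps
#       t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
#       return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * t))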
# Progressive training
progressive:
  enabled: false # Disabled for overfitting
  start_sequence_length: 15
  max_sequence_length: 50
  length_increase_freq: 5 # Epochs between increases
# Mixed precision training
amp: true # Whether to use automatic mixed precision
# Training parameters - OVERFITTING FOCUSED
train:
  resume_from: "" # Resume from best checkpoint
  # When to add conditional signals
  control_start_epoch: 0 # Start control immediately for overfitting
  # Ground-truth theta injection (for isolating expression learning)
  use_gt_theta: true # Use ground-truth theta from the dataset instead of predictions (disables theta loss)
  use_gt_scale: true # Use ground-truth scale (only if use_gt_theta=true)
  use_gt_rotation: true # Use ground-truth rotation (only if use_gt_theta=true)
  use_gt_translation: true # Use ground-truth translation (only if use_gt_theta=true)
  lr: 5e-4 # Base learning rate for overfitting
  motion_proj_lr: 0.1 # DEPRECATED - see get_layer_wise_learning_rates in vasa_trainer.py
  expression_lr: 0.1 # DEPRECATED - use expression_lr_mult instead
  expression_lr_mult: 2.0 # Expression learning-rate multiplier (2.0x base = 1e-3 for faster expression learning)
  turn_off_noise: false # If true, disable the noise signal and condition dropout
  learning_rates:
    condition_embedding: 1e-4 # Further reduced for stability
    motion_projections: 1e-4 # Further reduced for stability
    transformer_early: 1e-4 # Further reduced for stability
    transformer_late: 1e-4 # Further reduced for stability
    output_projections: 1e-4 # Further reduced for stability
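  # Sketch (hypothetical; the real grouping lives in get_layer_wise_learning_rates
  # in vasa_trainer.py, and the module attribute names below are assumptions):
  #   import torch
  #   groups = [
  #       {"params": model.condition_embedding.parameters(), "lr": 1e-4},
  #       {"params": model.motion_projections.parameters(),  "lr": 1e-4},
  #       {"params": model.output_projections.parameters(),  "lr": 1e-4},
  #   ]
  #   optimizer = torch.optim.AdamW(groups, betas=(0.9, 0.999), weight_decay=1e-5)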
  # To increase GPU utilization:
  # - Increase batch_size (e.g., from 28 to 56)
  # - This processes more windows per step while maintaining temporal continuity
  batch_size: 8 # Batch size for training
  windows_per_batch: 2 # Reduced from 14 for better sequence diversity
  gradient_accumulation_steps: 1 # Raise to increase the effective batch size and break plateaus
  num_epochs: 4000 # Epoch budget for the overfitting run
  beta1: 0.9
  beta2: 0.999
  weight_decay: 1e-5 # Reduced from 0.01 to allow more variance
  max_grad_norm: 0.1 # Very aggressive gradient clipping to prevent explosions (was 0.5)
  save_freq: 5 # Save frequently so runs can recover from OOM crashes
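  # Sketch (assumed standard accumulation + clipping loop matching the values
  # above; loader/model/optimizer names are hypothetical):
  #   import torch
  #   accum_steps, max_norm = 1, 0.1
  #   for i, batch in enumerate(loader):
  #       loss = model(batch) / accum_steps
  #       loss.backward()
  #       if (i + 1) % accum_steps == 0:
  #           torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
  #           optimizer.step()
  #           optimizer.zero_grad()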
  # Dropout probabilities for CFG - VASA-1 uses different rates for prev_ vs current
  dropout_probs:
    # Current conditions - AUDIO MUST NEVER BE DROPPED (it's the primary driving signal)
    audio: 0.0 # NEVER drop audio - it drives lip motion and dynamics
    gaze: 0.1 # Paper default: 0.1
    head_distance: 0.1 # Paper default: 0.1
    emotion: 0.1 # Paper default: 0.1
    # Previous-context conditions (0.5 dropout per the paper, for robustness)
    prev_audio: 0.0 # NEVER drop previous audio - critical for temporal coherence
    prev_theta: 0.5 # Paper: 0.5 for previous motion
    prev_rotation: 0.5 # Paper: 0.5 for previous motion
    prev_translation: 0.5 # Paper: 0.5 for previous motion
    prev_expression: 0.5 # Paper: 0.5 for previous dynamics
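  # Sketch (assumption: each condition is independently replaced by a learned
  # null embedding with the probability configured above):
  #   import torch
  #   def maybe_drop(cond, p, null_embed):          # cond: (B, D)
  #       keep = torch.rand(cond.shape[0], device=cond.device) >= p
  #       return torch.where(keep[:, None], cond, null_embed)
  # audio/prev_audio use p=0.0 and are therefore never dropped.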
  # CFG scale schedule (aligned with the VASA-1 paper)
  cfg_scales:
    audio: 0.5 # Paper default: 0.5 (was 3.0 - too high, causes over-conditioning)
    gaze: 1.0 # Paper default: 1.0
    head_distance: 0.8
    emotion: 0.5
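  # Sketch (assumption: standard per-condition CFG combination at sampling
  # time; cond_pred holds predictions with only the named condition active):
  #   scales = {"audio": 0.5, "gaze": 1.0, "head_distance": 0.8, "emotion": 0.5}
  #   out = uncond
  #   for name, scale in scales.items():
  #       out = out + scale * (cond_pred[name] - uncond)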
  # Gradient monitoring
  gradient_monitoring:
    enabled: true
    log_freq: 10 # More frequent logging
    save_grad_norms: true
    save_grad_flows: true
    num_layers_to_monitor: 2 # Monitor the first 2 transformer layers
# Loss weights - TDD PROGRESSIVE SYSTEM
loss:
  # Use the TDD progressive loss system
  use_tdd_progressive: false
  expression_weighting:
    enabled: true
    min_weight: 0.5 # Minimum weight for the least important dimensions
  # Adaptive loss weighting
  use_adaptive_weighting: true
  adaptive_update_freq: 100 # Update weights every N steps
  use_sync_loss: true # Sync loss enabled here, although VASA-1 uses SyncNet for evaluation only
  use_synchformer: true # Use Synchformer instead of SyncNet for the sync loss
  use_verification: false # Disabled for faster training
  lambda_verification: 0
  # =============================================================================
  # TRAINING STAGE PROGRESSION TABLE
  # =============================================================================
  # | Epoch | Stage | Unlocks           | Key Losses                    | Target Values |
  # |-------|-------|-------------------|-------------------------------|---------------|
  # | 0-4   | 1     | Foundation        | reconstruction: 2.0           | < 0.1         |
  # |       |       |                   | dynamics: 1.0, pose: 5.0      |               |
  # |       |       |                   | audio_lip: 10.0 (mouth motion)|               |
  # |-------|-------|-------------------|-------------------------------|---------------|
  # | 5-14  | 2     | Blinking          | blink: 0.1→2.0 (ramps)        | < 0.15        |
  # |       |       |                   | expression_l1: 1.5            |               |
  # |-------|-------|-------------------|-------------------------------|---------------|
  # | 15-24 | 3     | Eye Gaze          | gaze_direction: 0.1→1.5       | < 0.2         |
  # |       |       |                   | nonlip: 0.1                   |               |
  # |-------|-------|-------------------|-------------------------------|---------------|
  # | 25-34 | 4     | Head Pose         | head_distance: 0.1→1.0        | < 0.25        |
  # |       |       |                   | identity: 0.3                 |               |
  # |-------|-------|-------------------|-------------------------------|---------------|
  # | 35-44 | 5     | Emotion           | emotion: 0.0→0.8              | < 0.3         |
  # |       |       |                   | perceptual: 0.5               |               |
  # |-------|-------|-------------------|-------------------------------|---------------|
  # | 45+   | 6     | Lip Sync (Fine)   | lips: 2.0 (static, high)      | Refined       |
  # |       |       |                   | audio_lip: 10.0 (maintained)  |               |
  # |       |       |                   | temporal: 0.01 (low variance) |               |
  # =============================================================================
  # Foundation losses (always active)
  lambda_reconstruction: 2.0 # Increased for better reconstruction
  lambda_dynamics: 10.0 # INCREASED to 10.0 to force expression matching
  lambda_expression_cosine: 10.0 # NEW: direct GT expression matching via cosine similarity
  lambda_pose: 5.0 # Increased from 1.0 to enforce better head-pose learning
  # SRT losses - scale, rotation, translation components
  lambda_scale: 1.0 # Scale prediction loss weight
  lambda_rotation: 1.0 # Rotation prediction loss weight
  lambda_translation: 1.0 # Translation prediction loss weight
  # Progressive losses (unlock based on thresholds)
  # Stage 2: Blinking (unlocks at epoch 5 if reconstruction < 0.1)
  lambda_blink: 2.0 # Ramps to 2.0
  # Stage 3: Eye Gaze (unlocks at epoch 15 if blink < 0.15)
  lambda_gaze_direction: 2.0 # Raised from 0.1
  # Stage 4: Head Pose (unlocks at epoch 25 if gaze < 0.2)
  lambda_head_distance: 0.0 # DISABLED until TDD converges (was 0.1)
  # Stage 5: Emotion (unlocks at epoch 35 if head pose < 0.25)
  lambda_emotion: 2.0 # Stage 5 weight
  # Stage 6: Lip Sync (unlocks at epoch 45 if emotion < 0.3)
  lambda_lips: 2.0 # Set to 0.0 to test mouth_perceptual in isolation
  lambda_sync: 0.5 # REDUCED from 2.0 - Synchformer is designed for evaluation, not training gradients
  # Other losses
  lambda_nonlip: 0.0 # DISABLED until TDD converges (was 0.1)
  lambda_speed: 0.1 # Speed-bucket loss weight
  lambda_control: 1.0 # Control-signal loss weight
  lambda_cfg: 0.5
  lambda_perceptual: 1 # Increased perceptual loss for better quality
  lambda_mouth_perceptual: 100 # Mouth-focused LPIPS loss (TalkVid style) - 100x weight on the mouth region
  lambda_temporal: 0.0 # DISABLED - temporal smoothness prevents expression variation
  use_consistency: true # Enable VASA consistency loss
  use_identity: true # Enable face-identity preservation
  lambda_consistency: 0.0 # Weight for consistency loss (0.0 contributes nothing despite use_consistency)
  lambda_identity: 0.3 # Weight for identity loss
  # NEW VASA-1-inspired losses
  use_score_matching: true # Use VASA-1-style clean-signal prediction (not noise prediction)
  lambda_consist: 0.0 # DISABLED - expensive, not needed for single-identity overfitting (was 0.5)
  lambda_cross_id: 0.0 # Cross-identity similarity loss (not needed for overfitting)
  lambda_velocity: 0.0 # DISABLED to allow maximum expression variance
  lambda_smoothness: 0.0 # DISABLED to allow maximum expression variance
  lambda_audio_lip: 3.0 # REDUCED from 5.0 to balance with the lip landmark loss (2.0)
  lambda_mouth_openness: 10.0 # Direct mouth-openness-to-audio-energy supervision weight
  lambda_aux_phoneme: 1.0 # INCREASED from 0.2 - phoneme prediction (self-supervised) with class weighting to prevent mode collapse
  lambda_aux_au: 1.0 # Action Unit prediction (self-supervised) for fine-grained facial expression control
  lambda_aux_au_frame: 0.5 # Frame-based AU loss (50 frames per window) - extracts AUs from generated frames
  lambda_landmark: 1.0 # AU-to-landmark VAE loss weight (reconstruction + KL divergence)
  # Expression-specific losses (CRITICAL FOR PREVENTING COLLAPSE)
  lambda_expression_l1: 0.0 # DISABLED - conflicts with L2 and too weak (was 0.1)
  lambda_expression_variance: 0.1 # Reduced from 5.0; encourages variance matching
  lambda_expression_temporal: 0.1 # Reduced from 2.0; encourages temporal variation
  lambda_audio_expr_coupling: 0.1 # Reduced from 2.0; ties expression to audio changes
  # EMO (volumetric avatar) matching loss - ensures VASA matches high-quality EMO output
  use_emo_matching: false # DISABLED - EMO keyframes are static, causing expression collapse
  lambda_emo_match: 0.0 # DISABLED to prevent a static pull
  emo_match_type: 'combined' # 'l1', 'l2', 'lpips', or 'combined'
  emo_cache_dir: 'cache_emo' # Directory for cached EMO outputs
  emo_keyframes_per_window: 50 # Number of keyframes to compare per window
  emo_match_start_epoch: 0 # Start EMO matching from epoch 0 for immediate quality
  emo_min_valid_frames: 0.1 # LOWERED for overfitting: skip EMO matching when <10% of frames are valid (was 0.5)
  # Disentanglement loss configuration
  disentangle_num_pairs: 3 # Number of frame pairs to sample for l_consist
  disentangle_sampling: 'random' # Sampling strategy: 'random', 'uniform', or 'extremes'
  disentangle_compute_freq: 1 # Compute disentanglement losses every N steps (1 = always)
  disentangle_vis_freq: 100 # Visualize disentanglement every N steps
  cross_id_num_frames: 3 # Number of frames to use for identity comparison
  cross_id_normalize: true # Normalize frames before identity extraction
  # Warping-field loss weights
  lambda_warp: 2.0 # Main warp reconstruction loss weight
  lambda_warp_smooth: 0.0 # DISABLED - redundant with the TV loss (was 0.01)
  lambda_warp_temporal: 0.0 # DISABLED - redundant for overfitting; the target is already smooth (was 0.1)
  lambda_source_theta: 0.5 # Source-theta warp loss weight
  lambda_warp_magnitude: 5.0 # UV warp-magnitude matching (prevents collapse to zero)
  # Frame generation for disentanglement losses and visualization
  use_sparse_frames: false # If true, only generate the first and last frame (saves ~90% memory)
  enable_frame_generation: true # Required for frame-based losses and visualization
  save_debug_frames: false
  warmup_epochs: 0 # No warmup for overfitting
  # Loss scheduling; a sketch of the ramp follows this block
  schedule:
    enabled: false # No scheduling for overfitting
    start_epoch: 5
    rampup_epochs: 10
    blink:
      start_weight: 0.0
      final_weight: 0.0
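  # Sketch (hypothetical ramp consistent with start_epoch/rampup_epochs above):
  #   def ramped_weight(epoch, start_epoch=5, rampup_epochs=10,
  #                     start_weight=0.0, final_weight=2.0):
  #       if epoch < start_epoch:
  #           return start_weight
  #       t = min(1.0, (epoch - start_epoch) / rampup_epochs)  # linear 0 -> 1
  #       return start_weight + t * (final_weight - start_weight)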
  # Flow-DPO loss parameters (VideoReward framework - Liu et al., 2025)
  use_flow_dpo: false # DISABLED - Flow-DPO removed from the codebase
  lambda_flow_dpo: 0.0 # DISABLED - Flow-DPO removed from the codebase
  flow_dpo_start_epoch: 0 # When to start using Flow-DPO (after the dynamics stage)
  flow_dim: 140 # Dimension of velocity flows (12 theta + 128 expression = 140)
  flow_noise_level: 0.1 # Noise level for generating dispreferred samples (10% of signal)
  flow_beta_scale: 1.0 # Scale factor for beta_t in the regret computation
  flow_update_ref_freq: 100 # Update the (frozen) reference model every N epochs
# Inference parameters
inference:
  batch_size: 1
  cfg_scale: 0.5 # Global scale (legacy; use cfg_scales below for per-condition scales)
  # Per-condition CFG scales for inference (can override the train values)
  cfg_scales:
    audio: 0.5 # Paper default: 0.5
    gaze: 1.0 # Paper default: 1.0
    head_distance: 0.8
    emotion: 0.5
  eta: 0.5 # Increased from 0.1 for more stochasticity, to prevent static output
  output_size: [512, 512]
  num_inference_steps: 50 # Number of DDIM steps
  temperature: 1.0 # Sampling temperature
  top_k: 50 # Top-k sampling parameter
  top_p: 0.9 # Nucleus sampling parameter
  use_window_cache: false # Disabled to test whether caching causes static output
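# Sketch (assumption: the diffusion/inference values map onto diffusers'
# DDIMScheduler, which the schedule_mode comment above refers to; `model`,
# `x`, and `conditions` are hypothetical):
#   from diffusers import DDIMScheduler
#   sched = DDIMScheduler(num_train_timesteps=1000, beta_start=1e-4,
#                         beta_end=0.02, beta_schedule="linear")
#   sched.set_timesteps(50)                              # num_inference_steps
#   for t in sched.timesteps:
#       eps = model(x, t, conditions)                    # denoiser prediction
#       x = sched.step(eps, t, x, eta=0.5).prev_sample   # eta adds stochasticity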
# Dataset parameters - SINGLE VIDEO for overfitting
dataset:
  sequence_length: 50
  frame_size: [512, 512]
  preextract_audio: true
  cache_audio: true
  hop_length: 10
  max_videos: 100 # REDUCED from 1000 - use only the best-quality videos
  max_sequence_samples: 1000 # Maximum samples per epoch
  use_identity_image: true # UNSTABLE - use a high-quality identity image instead of a video frame
  identity_image_path: "nemo/data/IMG_1.png" # Path to the high-quality identity image
  use_identity_theta: false # Use theta from the identity image for all frames (no theta prediction/loss)
  # Frame-caching options (MD5-indexed disk storage); see the sketch below
  cache_frames_to_disk: true # Save original frames to disk (default: true)
  cache_emo_frames_to_disk: true # Save EMO frames to disk (default: true)
  frame_format: 'png' # Image format for cached frames ('png' or 'jpg')
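  # Sketch (assumption: cache keys are MD5 digests of the video path and frame
  # index; the exact key scheme in this repo may differ):
  #   import hashlib, os
  #   def frame_cache_path(cache_dir, video_path, frame_idx, fmt="png"):
  #       key = hashlib.md5(f"{video_path}:{frame_idx}".encode()).hexdigest()
  #       return os.path.join(cache_dir, f"{key}.{fmt}")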
  # Expression database settings
  auto_rebuild_expression_db: true # Auto-rebuild the expression DB after preprocessing
  expression_db_frame_stride: 1 # Sample every Nth frame for the DB (1 = all frames, 5 = every 5th)
  # EMO generation settings
  generate_emo_frames: true # Generate EMO reference frames during preprocessing
  emo_identity_path: "nemo/data/IMG_1.png" # Same identity image for EMO generation
  emo_keyframes_per_window: 50 # Raised from 5
  use_single_bucket: true # Use single-bucket caching for all windows
  # Quality filtering
  min_valid_face_ratio: 0.8 # Filter out videos with <80% valid face detections
  augmentation:
    enabled: true # ENABLED for better generalization
    flip_prob: 0.2 # 20% chance of a horizontal flip
    temporal_crop_prob: 0.0
    intensity_scale: 0.0
    temporal_mask_prob: 0.0
# Model checkpoint paths
paths:
  checkpoint_dir: "checkpoints_overfit"
  syncnet_path: "pretrained/syncnet.pth"
  emotion_model: "pretrained/emonet.pth"
  volumetric_model: "nemo/logs/Retrain_with_17_V1_New_rand_MM_SEC_4_drop_02_stm_10_CV_05_1_1/checkpoints/328_model.pth"
  volumetric_config: "nemo/models/stage_1/volumetric_avatar/va.yaml"
  data_dir: "data"
  video_folder: "s1"
  cache_dir: "cache_per_video" # Using the per-video cache structure
# Hardware/environment settings
device: "cuda"
seed: 42
num_workers: 8 # Worker processes for data loading
debug: false
wandb:
  project: "vasa-overfitting"
  enabled: true
# Additional components
face_analysis:
  face_detector: "retinaface"
  face_parser: "rtnet"
  emotion_recognizer: "hsemotion"
audio:
  sample_rate: 16000
  feature_type: "wav2vec2"
  feature_dim: 768
  normalize_audio: true
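# Sketch (assumption: 768-dim features from a base wav2vec2 checkpoint at
# 16 kHz; the checkpoint name is illustrative, and waveform is a 1-D float
# array sampled at 16 kHz):
#   from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
#   extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
#   model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
#   inputs = extractor(waveform, sampling_rate=16000, return_tensors="pt")
#   feats = model(**inputs).last_hidden_state             # shape (1, T, 768)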
# Visualization settings
vis:
  save_videos: true
  save_frames: false
  vis_freq: 50 # More frequent visualization
  num_vis_samples: 4
# Logging configuration
logging:
  level: "INFO"
  save_frequency: 10 # More frequent logging
metrics:
  save_predictions: true
  compute_fid: false # Disabled for speed
  compute_kid: false # Disabled for speed
# Loss visualization (ASCII graphs showing the trend toward the target)
loss_visualization:
  enabled: true # Show loss graphs in the summary
  history_size: 50 # Number of samples to track
  graph_width: 50 # Graph width in characters
  graph_height: 6 # Graph height in lines
  show_freq: 50 # Show the visualization every N steps (more frequent for overfitting)
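# Sketch (hypothetical renderer matching the width/height settings above;
# history is a non-empty list of floats):
#   def ascii_graph(history, width=50, height=6):
#       hist = history[-width:]
#       lo, hi = min(hist), max(hist)
#       rows = [[" "] * len(hist) for _ in range(height)]
#       for x, v in enumerate(hist):
#           y = 0 if hi == lo else int((v - lo) / (hi - lo) * (height - 1))
#           rows[height - 1 - y][x] = "*"
#       return "\n".join("".join(r) for r in rows)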
# Validation
validation:
  enabled: true # Enable validation to track lip-sync metrics
  frequency: 5 # Validate every 5 epochs
  num_samples: 4 # Fewer samples for speed (was 10)
  max_batches: 2 # Process at most N batches during validation
  metrics: ["reconstruction_loss", "lip_sync", "mel_spectrogram_sync", "audio_projection", "temporal_alignment"]
  # Full metric list (enable when needed): ["lip_sync", "id_similarity", "motion_naturalness", "reconstruction_loss"]
  skip_expensive_metrics: true # Skip compute-heavy metrics such as id_similarity
  # Lip-sync-specific validation settings
  lipsync:
    compute_mel_correlation: true # Compute mel-spectrogram correlation with lip motion
    compute_audio_projection: true # Measure audio-to-visual projection quality
    compute_temporal_alignment: true # Detect sync lag using cross-correlation (see the sketch below)
    mel_spec_n_mels: 128 # Number of mel frequency bins
    mel_spec_n_fft: 1024 # FFT window size
    temporal_lag_frames: 5 # Check +/- N frames for temporal alignment
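# Sketch (assumption: lag is estimated by cross-correlating per-frame audio
# energy with lip openness over +/- temporal_lag_frames; np.roll wraps around,
# which is a simplification, and both inputs are 1-D numpy arrays):
#   import numpy as np
#   def best_lag(audio_energy, lip_open, max_lag=5):
#       def ncc(a, b):                                   # normalized correlation
#           a, b = a - a.mean(), b - b.mean()
#           return float((a * b).sum()
#                        / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8))
#       lags = list(range(-max_lag, max_lag + 1))
#       scores = [ncc(np.roll(audio_energy, k), lip_open) for k in lags]
#       return lags[int(np.argmax(scores))]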
# Distributed training
distributed:
  enabled: false # Single GPU for overfitting
  backend: "nccl"
  find_unused_parameters: false
  gradient_as_bucket_view: true
# Optimization settings
optimization:
  compile_model: false # Whether to use torch.compile
  channels_last: true # Use the channels_last memory format
  gradient_checkpointing: false # Disabled for speed
  sync_batchnorm: false
  amp: true