
SSIMBaD: Sigma Scaling with SSIM-Guided Balanced Diffusion for AnimeFace Colorization

Official PyTorch implementation of "Sigma Scaling with SSIM-Guided Balanced Diffusion for AnimeFace Colorization"


1. Overview

SSIMBaD introduces a novel diffusion-based framework for automatic colorization of anime-style facial sketches. Unlike prior DDPM/EDM-based methods that rely on handcrafted or fixed noise schedules, SSIMBaD leverages a perceptual noise schedule grounded in SSIM-aligned sigma-space scaling. This design enforces uniform perceptual degradation throughout the diffusion process, improving both structural fidelity and stylistic accuracy in the generated outputs.

The following table compares baseline models and our proposed SSIMBaD framework under both same-reference and cross-reference settings. Metrics include PSNR (higher is better), MS-SSIM (higher is better), and FID (lower is better).

| Method | Training | PSNR ↑ (Same / Cross) | MS-SSIM ↑ (Same / Cross) | FID ↓ (Same / Cross) |
|---|---|---|---|---|
| SCFT [Lee2020] | 300 epochs | 17.17 / 15.47 | 0.7833 / 0.7627 | 43.98 / 45.18 |
| AnimeDiffusion (pretrained) [Cao2024] | 300 epochs | 11.39 / 11.39 | 0.6748 / 0.6721 | 46.96 / 46.72 |
| AnimeDiffusion (finetuned) [Cao2024] | 300 + 10 epochs | 13.32 / 12.52 | 0.7001 / 0.5683 | 135.12 / 139.13 |
| SSIMBaD (w/o trajectory refinement) | 300 epochs | 15.15 / 13.04 | 0.7115 / 0.6736 | 53.33 / 55.18 |
| SSIMBaD (w/ trajectory refinement) 🏆 | 300 + 10 epochs | 18.92 / 15.84 | 0.8512 / 0.8207 | 34.98 / 37.10 |

This repository includes:

  • 🧠 Pretraining with classifier-free guidance and structural reconstruction loss
  • 🎯 Finetuning with perceptual objectives and SSIM-guided trajectory refinement
  • 📈 Perceptual noise schedule design based on SSIM curve fitting
  • 🧪 Full evaluation pipeline for same-reference and cross-reference scenarios

2. Key Idea

The quality of diffusion-based generation is highly sensitive to how noise levels are scheduled over time.

Existing models such as DDPM and EDM use different schedules for training and inference (e.g., log(σ) vs. σ^(1/ρ)), often leading to perceptual mismatches that degrade visual consistency.

To resolve this, we introduce a shared transformation ϕ: ℝ+ → ℝ used consistently in both training and generation.
This transformation maps the raw noise scale σ to a perceptual difficulty axis, allowing uniform degradation in image quality over time.

We select the optimal transformation ϕ* by maximizing the linearity of SSIM degradation over a candidate set Φ:

ϕ* = arg max_{ϕ ∈ Φ} R²(ϕ), where R²(ϕ) is the coefficient of determination of a linear fit of the SSIM degradation curve plotted against ϕ(σ).

This function achieves the best perceptual alignment, leading to smooth and balanced degradation curves.

🧞‍♂️ Think of ϕ*(σ) as a perceptual "magic carpet ride" — smooth, stable, and optimally guided by structural similarity.
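As an illustration, the R²-maximization performed by `optimal_phi.py` can be sketched as follows. The candidate set, the σ range, and the synthetic SSIM curve below are illustrative assumptions only; the real pipeline measures SSIM on actual noised images.

```python
import numpy as np

def r_squared(x, y):
    """Coefficient of determination for a least-squares line fit of y on x."""
    slope, intercept = np.polyfit(x, y, 1)
    residuals = y - (slope * x + intercept)
    ss_res = np.sum(residuals ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

# Candidate transformations phi: R+ -> R (illustrative set, not the paper's exact Phi).
candidates = {
    "log":   lambda s: np.log(s),
    "rho_7": lambda s: s ** (1.0 / 7.0),  # EDM-style sigma^(1/rho) with rho = 7
    "sqrt":  lambda s: np.sqrt(s),
}

sigmas = np.geomspace(0.002, 80.0, 50)  # EDM-style sigma range
# Placeholder: in the real pipeline, ssim[i] is the measured SSIM between clean
# images and images noised at sigmas[i], averaged over a batch.
ssim = 1.0 / (1.0 + np.log1p(sigmas))   # synthetic, monotonically decreasing curve

# Pick the phi under which SSIM degradation looks most linear.
scores = {name: r_squared(phi(sigmas), ssim) for name, phi in candidates.items()}
phi_star = max(scores, key=scores.get)
print(phi_star, scores[phi_star])
```

The winning transformation is then reused, unchanged, for both the training-time sigma sampling and the inference-time step spacing.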


3. Folder Overview

├── pretrain.py                # SSIMBaD training (EDM + φ*(σ))
├── finetune.py                # Trajectory refinement stage
├── SSIMBaD_pretrain.py        # Baseline reproduction (vanilla EDM schedule)
├── SSIMBaD_finetune.py        # Baseline finetuning (MSE-based)
├── evaluate_*.py              # FID / PSNR / SSIM evaluation
├── optimal_phi.py             # φ*(σ) search via SSIM R² maximization
├── models/                    # Diffusion & U-Net architectures
├── utils/                     # XDoG, TPS warp, logger, path utils
└── requirements.txt

4. Noise Schedule Analysis

We visualize how SSIM degrades across diffusion timesteps for various noise schedules:

(Figure: SSIM degradation curves for DDPM, EDM, and SSIMBaD)

(Figure: corresponding noisy image grids for DDPM, EDM, and SSIMBaD)
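Such degradation curves can be reproduced in spirit with a few lines of NumPy. The single-window SSIM below is a simplified stand-in for the windowed (MS-)SSIM used elsewhere in this repo, and the image is random noise rather than an anime face:

```python
import numpy as np

def global_ssim(x, y, c1=0.01 ** 2, c2=0.03 ** 2):
    """Simplified SSIM computed over the whole image (no sliding window)."""
    mx, my = x.mean(), y.mean()
    vx, vy = x.var(), y.var()
    cov = ((x - mx) * (y - my)).mean()
    return ((2 * mx * my + c1) * (2 * cov + c2)) / \
           ((mx ** 2 + my ** 2 + c1) * (vx + vy + c2))

rng = np.random.default_rng(0)
img = rng.random((64, 64))            # stand-in for a clean image in [0, 1]
sigmas = np.geomspace(0.01, 2.0, 10)  # increasing noise levels

curve = []
for s in sigmas:
    noisy = np.clip(img + s * rng.standard_normal(img.shape), 0.0, 1.0)
    curve.append(global_ssim(img, noisy))

# SSIM decays as sigma grows; plotting curve vs. phi(sigma) for different phi
# reproduces the qualitative comparison shown in the figures above.
print([round(float(v), 3) for v in curve])
```

A schedule is "perceptually balanced" when this curve, plotted against the transformed axis ϕ(σ), is close to a straight line.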

5. Installation

conda create -n ssimbad python=3.9
conda activate ssimbad
pip install -r requirements.txt

6. Dataset

We evaluate our method on a benchmark dataset introduced by Cao et al. (2024), specifically curated for reference-guided anime face colorization.

  • Dataset: Danbooru Anime Face Dataset
  • Train/Test Split: 31,696 training pairs and 579 test samples
  • Resolution: All images are resized to 256×256 pixels
  • Each sample includes:
    • I_gt: Ground-truth RGB image
    • I_sketch: Corresponding edge-based sketch, generated using the XDoG filter [Winnemöller et al. 2012]
    • I_ref: A reference image providing color and style cues

The dataset is evaluated under two conditions:

  • Same-Reference Setting:
    The reference image is a spatially perturbed version of the ground-truth with the same underlying structure as I_sketch.

  • Cross-Reference Setting:
    The reference image is randomly sampled from other identities, introducing variation in color palette and facial attributes.

This dual evaluation setup allows us to measure both:

  1. Reconstruction fidelity under ideal alignment, and
  2. Generalization performance under cross-domain appearance shifts.

Note: During preprocessing, the reference image is warped from the ground truth using TPS (Thin Plate Spline) with random rotation and deformation, simulating natural variation.
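The control-point setup for such a TPS warp can be sketched as below. The helper `perturbed_control_points` is hypothetical (the actual warping code lives in `utils/` and may differ); the actual TPS interpolation between the two grids is delegated to a TPS implementation and omitted here:

```python
import numpy as np

def perturbed_control_points(n=4, max_rot_deg=10.0, jitter=0.05, seed=0):
    """Build source/target control grids for a TPS warp: the source is a
    regular n x n grid in [0, 1]^2; the target is the same grid after a small
    random rotation about the image center plus per-point jitter."""
    rng = np.random.default_rng(seed)
    xs = np.linspace(0.1, 0.9, n)
    src = np.stack(np.meshgrid(xs, xs), axis=-1).reshape(-1, 2)

    theta = np.deg2rad(rng.uniform(-max_rot_deg, max_rot_deg))
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta),  np.cos(theta)]])
    dst = (src - 0.5) @ rot.T + 0.5          # rotate about the center
    dst += rng.uniform(-jitter, jitter, size=dst.shape)  # local deformation
    return src, dst

src, dst = perturbed_control_points()
print(src.shape, dst.shape)
```

Warping the ground truth through such a grid yields a reference with the same content but perturbed geometry, which is exactly what the same-reference setting requires.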


7. Pretraining

The pretraining stage optimizes the base diffusion model using MSE loss between predicted and ground-truth RGB images, with sketch and reference inputs. It forms the foundation for perceptual finetuning.

python pretrain.py \
    --do_train True \
    --epochs 300 \
    --train_batch_size 32 \
    --inference_time_step 500 \
    --train_reference_path /data/Anime/train_data/reference/ \
    --train_condition_path /data/Anime/train_data/sketch/ \
    --gpus 0 1

Note: this is not the plain MSE baseline reproduced in SSIMBaD_pretrain.py. Here the reverse trajectory is trained under the perceptual noise scaling φ*(σ); MSE refers only to the reconstruction loss.
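Conceptually, pretraining pairs a noised ground-truth image with channel-concatenated sketch/reference conditioning and regresses the clean RGB image. In the sketch below, `toy_denoiser` is a placeholder for the conditional U-Net and all tensors are random stand-ins:

```python
import numpy as np

rng = np.random.default_rng(0)
B, C, H, W = 2, 3, 32, 32

gt = rng.random((B, C, H, W))       # I_gt: ground-truth RGB
sketch = rng.random((B, 1, H, W))   # I_sketch: XDoG edge map
ref = rng.random((B, C, H, W))      # I_ref: warped reference

def toy_denoiser(noisy, cond):
    """Identity placeholder: returns its noisy input unchanged. The real
    conditional U-Net consumes cond and predicts the clean RGB image."""
    return noisy

sigma = 1.2                                    # sampled noise level
noisy = gt + sigma * rng.standard_normal(gt.shape)
cond = np.concatenate([sketch, ref], axis=1)   # channel-wise conditioning (B, 4, H, W)

pred = toy_denoiser(noisy, cond)
mse = float(np.mean((pred - gt) ** 2))         # reconstruction objective
print(round(mse, 4))
```

In the real pipeline, σ is drawn so that its transformed value φ*(σ) is uniformly spread, which is what makes the degradation perceptually balanced across training.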

8. Finetuning (trajectory refinement)

This project supports perceptual finetuning using pre-trained diffusion weights.
You can specify the strategy (e.g., lpips-vgg, clip, or mse) and resume from a checkpoint.

python finetune.py \
    --do_train True \
    --resume_from_checkpoint /path/to/checkpoint.ckpt \
    --strategy mse \
    --epochs 10 \
    --finetuning_inference_time_step 50 \
    --train_reference_path /data/Anime/train_data/reference/ \
    --train_condition_path /data/Anime/train_data/sketch/ \
    --gpus 0 1
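The --strategy flag selects the perceptual objective for finetuning. A minimal dispatch sketch is below (the function names are hypothetical and only the MSE branch is implemented, since lpips-vgg and clip require pretrained perceptual networks):

```python
import numpy as np

def mse_loss(pred, target):
    """Plain pixel-space MSE, matching --strategy mse."""
    return float(np.mean((pred - target) ** 2))

def build_loss(strategy):
    """Map the --strategy flag to a loss callable."""
    if strategy == "mse":
        return mse_loss
    if strategy in ("lpips-vgg", "clip"):
        raise NotImplementedError(f"{strategy}: requires a pretrained perceptual network")
    raise ValueError(f"unknown strategy: {strategy}")

loss_fn = build_loss("mse")
print(loss_fn(np.ones((4, 4)), np.zeros((4, 4))))
```

Whichever loss is selected, finetuning backpropagates it through a short (50-step) sampled trajectory rather than a single denoising step, which is what "trajectory refinement" refers to.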

9. Inference

Model evaluation is supported through the existing training scripts: pretrain.py for evaluating pretrained models, and finetune.py for evaluating finetuned ones. While testing is currently embedded within these scripts, we plan to release a dedicated and streamlined test.py interface in the future for improved usability and clarity.

🔧 Inference Command Examples

▶ Using pretrain.py (for pretrained checkpoints)

python3 pretrain.py \
  --do_train False \
  --do_test True \
  --checkpoint_path ./checkpoints/best.ckpt \
  --test_output_dir ./result_inference \
  --test_reference_path /data/Anime/test_data/reference/ \
  --test_condition_path /data/Anime/test_data/sketch/ \
  --test_gt_path /data/Anime/test_data/reference/ \
  --gpus 0

▶ Using finetune.py (for fine-tuned checkpoints)

python3 finetune.py \
    --do_train False \
    --do_test True \
    --checkpoint_path /root/SSIMBaD/logs/lightning_logs/version_5/checkpoints/epoch=03-train_avg_loss=0.0667.ckpt \
    --test_output_dir ./result_inference \
    --test_reference_path /data/Anime/test_data/reference/ \
    --test_condition_path /data/Anime/test_data/sketch/ \
    --gpus 0 \
    --do_guiding False

python3 finetune.py \
  --do_train False \
  --do_test True \
  --checkpoint_path ./checkpoints/best.ckpt \
  --test_output_dir ./result_inference \
  --test_reference_path /data/Anime/test_data/reference/ \
  --test_condition_path /data/Anime/test_data/sketch/ \
  --test_gt_path /data/Anime/test_data/reference/ \
  --gpus 0 \
  --do_guiding True

🧠 Notes

  • Set --do_train False and --do_test True to run inference only.
  • Use --checkpoint_path to specify the model to evaluate.
  • --do_guiding toggles LPIPS-guided inference (True) vs. plain DDIM-style inference (False).
  • Results will be saved in the directory specified by --test_output_dir.

10. Implementation Details

We implement our training and evaluation pipeline using PyTorch Lightning, enabling modular, scalable, and reproducible experimentation.

  • Hardware: All experiments were conducted on a single node equipped with 2× NVIDIA H100 (80GB) GPUs.
  • Distributed Training: We use DDP (Distributed Data Parallel) via strategy="ddp_find_unused_parameters_true" for efficient multi-GPU training.
  • Mixed Precision: Enabled via precision="16-mixed" to accelerate training and reduce memory usage without sacrificing model quality.
  • Training Epochs: Pretraining was conducted for 300 epochs using MSE loss.
  • Inference Timesteps: Default forward process uses 500 steps unless specified otherwise.
  • Checkpointing: Top-3 checkpoints are saved based on training loss using:
    ModelCheckpoint(monitor="train_avg_loss", save_top_k=3, mode="min")
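Putting these settings together, the Lightning Trainer wiring looks roughly like the following config fragment (a sketch assuming standard PyTorch Lightning APIs; the repository's exact arguments may differ):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the three best checkpoints by average training loss, as described above.
checkpoint_cb = ModelCheckpoint(monitor="train_avg_loss", save_top_k=3, mode="min")

trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,                                   # 2x NVIDIA H100 in the paper's setup
    strategy="ddp_find_unused_parameters_true",  # DDP with unused-parameter search
    precision="16-mixed",                        # mixed-precision training
    max_epochs=300,                              # pretraining epoch budget
    callbacks=[checkpoint_cb],
)
```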

11. Evaluation

Quantitative evaluation of the model can be performed using the following scripts. All metrics support configurable directory paths and device settings via command-line arguments.


▶ FID

To compute FID between the reference and generated images:

python3 evaluate_FID.py \
  --real_dir /data/Anime/test_data/reference \
  --generated_dir ./result_inference \
  --device cuda \
  --batch_size 50

▶ PSNR

python3 evaluate_PSNR.py \
  --real_dir /data/Anime/test_data/reference \
  --generated_dir ./result_inference \
  --device cuda
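For reference, PSNR reduces to a one-line formula over the pixel MSE. A minimal NumPy version, evaluated on synthetic stand-in images rather than the repository's outputs:

```python
import numpy as np

def psnr(real, generated, data_range=1.0):
    """Peak signal-to-noise ratio in dB for images scaled to [0, data_range]."""
    mse = np.mean((real - generated) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(data_range ** 2 / mse)

rng = np.random.default_rng(0)
real = rng.random((256, 256, 3))
gen = np.clip(real + 0.01 * rng.standard_normal(real.shape), 0.0, 1.0)
print(round(psnr(real, gen), 2))
```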

▶ MS-SSIM

python3 evaluate_SSIM.py \
  --real_dir /data/Anime/test_data/reference \
  --generated_dir ./result_inference \
  --device cuda

12. Reference Settings: Visual Comparison

The dataset includes two test scenarios to evaluate reconstruction fidelity and style generalization:

(Figure: visual comparison. Left: Same-Reference / Right: Cross-Reference)