
SSIMBaD: Sigma Scaling with SSIM-Guided Balanced Diffusion for AnimeFace Colorization

Official PyTorch implementation of "Sigma Scaling with SSIM-Guided Balanced Diffusion for AnimeFace Colorization"


1. Overview

SSIMBaD introduces a novel diffusion-based framework for automatic colorization of anime-style facial sketches. Unlike prior DDPM/EDM-based methods that rely on handcrafted or fixed noise schedules, SSIMBaD leverages a perceptual noise schedule grounded in SSIM-aligned sigma-space scaling. This design enforces uniform perceptual degradation throughout the diffusion process, improving both structural fidelity and stylistic accuracy in the generated outputs.

The following table compares baseline models and our proposed SSIMBaD framework under both same-reference and cross-reference settings. Metrics include PSNR (higher is better), MS-SSIM (higher is better), and FID (lower is better).

| Method | Training | PSNR ↑ (Same / Cross) | MS-SSIM ↑ (Same / Cross) | FID ↓ (Same / Cross) |
|---|---|---|---|---|
| SCFT [Lee2020] | 300 epochs | 17.17 / 15.47 | 0.7833 / 0.7627 | 43.98 / 45.18 |
| AnimeDiffusion (pretrained) [Cao2024] | 300 epochs | 11.39 / 11.39 | 0.6748 / 0.6721 | 46.96 / 46.72 |
| AnimeDiffusion (finetuned) [Cao2024] | 300 + 10 epochs | 13.32 / 12.52 | 0.7001 / 0.5683 | 135.12 / 139.13 |
| SSIMBaD (w/o trajectory refinement) | 300 epochs | 15.15 / 13.04 | 0.7115 / 0.6736 | 53.33 / 55.18 |
| SSIMBaD (w/ trajectory refinement) 🏆 | 300 + 10 epochs | 18.92 / 15.84 | 0.8512 / 0.8207 | 34.98 / 37.10 |

This repository includes:

  • 🧠 Pretraining with classifier-free guidance and structural reconstruction loss
  • 🎯 Finetuning with perceptual objectives and SSIM-guided trajectory refinement
  • 📈 Perceptual noise schedule design based on SSIM curve fitting
  • 🧪 Full evaluation pipeline for same-reference and cross-reference scenarios

2. Key Idea

The quality of diffusion-based generation is highly sensitive to how noise levels are scheduled over time.

Existing models like DDPM and EDM use different schedules for training and inference (e.g., log(σ) vs. σ^(1/ρ)), often leading to perceptual mismatches that degrade visual consistency.

To resolve this, we introduce a shared transformation ϕ: ℝ⁺ → ℝ used consistently in both training and generation.
This transformation maps the raw noise scale σ to a perceptual difficulty axis, allowing uniform degradation in image quality over time.

We select the optimal transformation ϕ* by maximizing the linearity of SSIM degradation over a candidate set Φ:

์Šคํฌ๋ฆฐ์ƒท 2025-05-15 ์˜คํ›„ 5 23 10

This function achieves the best perceptual alignment, leading to smooth and balanced degradation curves.

🧞‍♂️ Think of ϕ*(σ) as a perceptual "magic carpet ride": smooth, stable, and optimally guided by structural similarity.
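The selection criterion above can be sketched numerically. Everything concrete here is a placeholder: the candidate set Φ, the σ grid, and the synthetic SSIM curve are illustrative (the real `optimal_phi.py` measures SSIM between clean and σ-noised images); only the R²-maximization criterion follows the text.

```python
import numpy as np

def linearity_r2(x, y):
    # R² of a straight-line fit y ≈ a·x + b (higher = more linear)
    a, b = np.polyfit(x, y, 1)
    ss_res = np.sum((y - (a * x + b)) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

# Hypothetical candidate set Φ of transforms ϕ: ℝ⁺ → ℝ
candidates = {
    "identity": lambda s: s,
    "log":      np.log,
    "sqrt":     np.sqrt,
}

sigmas = np.linspace(0.02, 80.0, 200)
# Placeholder for the measured SSIM degradation curve: in the real
# pipeline this is SSIM(x0, x0 + σ·ε) averaged over a batch of images.
ssim_curve = 1.0 / (1.0 + np.log1p(sigmas))

scores = {name: linearity_r2(phi(sigmas), ssim_curve)
          for name, phi in candidates.items()}
phi_star = max(scores, key=scores.get)
```

The winner is whichever ϕ makes SSIM fall most linearly along the transformed axis; with real measured curves the search runs over a richer candidate set.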


3. Folder Overview

├── pretrain.py                 # SSIMBaD training (EDM + φ*(σ))
├── finetune.py                 # Trajectory refinement stage
├── SSIMBaD_pretrain.py         # Baseline reproduction (vanilla EDM schedule)
├── SSIMBaD_finetune.py         # Baseline finetuning (MSE-based)
├── evaluate_*.py               # FID / PSNR / SSIM evaluation
├── optimal_phi.py              # φ*(σ) search via SSIM R² maximization
├── models/                     # Diffusion & U-Net architectures
├── utils/                      # XDoG, TPS warp, logger, path utils
└── requirements.txt

4. Noise Schedule Analysis

We visualize how SSIM degrades across diffusion timesteps for various noise schedules:

SSIM Degradation Curves

(Panels, left to right: DDPM, EDM, SSIMBaD)

Corresponding Noisy Image Grids

(Panels, left to right: DDPM, EDM, SSIMBaD)
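For reference, the two baseline schedules compared above can be reproduced numerically. The DDPM β range and the EDM constants (ρ = 7, σ ∈ [0.002, 80]) are the commonly used defaults from the respective papers, not values taken from this repository.

```python
import numpy as np

# DDPM: σ_t implied by a linear β schedule, σ_t = sqrt((1 − ᾱ_t) / ᾱ_t)
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)
sigma_ddpm = np.sqrt((1.0 - alpha_bar) / alpha_bar)

# EDM: σ_i = (σ_max^{1/ρ} + i/(N−1) · (σ_min^{1/ρ} − σ_max^{1/ρ}))^ρ
rho, sigma_min, sigma_max, N = 7.0, 0.002, 80.0, 50
i = np.arange(N)
sigma_edm = (sigma_max ** (1 / rho)
             + i / (N - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
```

Plotting SSIM of σ-noised images against either axis reveals the uneven perceptual degradation that the ϕ*(σ) scaling is designed to flatten.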

5. Installation

conda create -n ssimbad python=3.9
conda activate ssimbad
pip install -r requirements.txt

6. Dataset

We evaluate our method on a benchmark dataset introduced by Cao et al. (2024), specifically curated for reference-guided anime face colorization.

  • Dataset: Danbooru Anime Face Dataset
  • Train/Test Split: 31,696 training pairs and 579 test samples
  • Resolution: All images are resized to 256×256 pixels
  • Each sample includes:
    • I_gt: Ground-truth RGB image
    • I_sketch: Corresponding edge-based sketch, generated using the XDoG filter [Winnemöller et al. 2012]
    • I_ref: A reference image providing color and style cues

The dataset is evaluated under two conditions:

  • Same-Reference Setting:
    The reference image is a spatially perturbed version of the ground-truth with the same underlying structure as I_sketch.

  • Cross-Reference Setting:
    The reference image is randomly sampled from other identities, introducing variation in color palette and facial attributes.

This dual evaluation setup allows us to measure both:

  1. Reconstruction fidelity under ideal alignment, and
  2. Generalization performance under cross-domain appearance shifts.

Note: During preprocessing, the reference image is warped from the ground truth using TPS (Thin Plate Spline) with random rotation and deformation, simulating natural variation.
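A minimal stand-in for this preprocessing step might look like the sketch below. It applies only a random rotation (a full TPS deformation additionally solves for control-point displacements); the function name and defaults are illustrative, not the repo's `utils` API.

```python
import numpy as np

def warp_reference(gt, max_rot_deg=15.0, seed=0):
    # Hypothetical stand-in for the TPS warp: rotate the ground-truth
    # image about its center by a random angle, nearest-neighbor sampled.
    rng = np.random.default_rng(seed)
    theta = np.deg2rad(rng.uniform(-max_rot_deg, max_rot_deg))
    h, w = gt.shape[:2]
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    ys, xs = np.mgrid[0:h, 0:w]
    c, s = np.cos(theta), np.sin(theta)
    # Inverse-map each output pixel back into the source image
    src_y = c * (ys - cy) - s * (xs - cx) + cy
    src_x = s * (ys - cy) + c * (xs - cx) + cx
    src_y = np.clip(np.rint(src_y).astype(int), 0, h - 1)
    src_x = np.clip(np.rint(src_x).astype(int), 0, w - 1)
    return gt[src_y, src_x]
```

The real pipeline adds local non-rigid deformation on top of rotation, so the reference shares colors with the ground truth but not its exact geometry.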


7. Pretraining

The pretraining stage optimizes the base diffusion model using MSE loss between predicted and ground-truth RGB images, with sketch and reference inputs. It forms the foundation for perceptual finetuning.

python pretrain.py \
    --do_train True \
    --epochs 300 \
    --train_batch_size 32 \
    --inference_time_step 500 \
    --train_reference_path /data/Anime/train_data/reference/ \
    --train_condition_path /data/Anime/train_data/sketch/ \
    --gpus 0 1

Note: this is not a generic MSE-trained vanilla EDM model like the SSIMBaD_pretrain.py baseline; the trajectory is optimized using perceptual noise scaling via ϕ*(σ).

8. Finetuning (trajectory refinement)

This project supports perceptual finetuning using pre-trained diffusion weights.
You can specify the strategy (e.g., lpips-vgg, clip, or mse) and resume from a checkpoint.

python finetune.py \
    --do_train True \
    --resume_from_checkpoint /path/to/checkpoint.ckpt \
    --strategy mse \
    --epochs 10 \
    --finetuning_inference_time_step 50 \
    --train_reference_path /data/Anime/train_data/reference/ \
    --train_condition_path /data/Anime/train_data/sketch/ \
    --gpus 0 1

9. Inference

Model evaluation is supported through the existing training scripts: pretrain.py for evaluating pretrained models, and finetune.py for evaluating finetuned ones. While testing is currently embedded within these scripts, we plan to release a dedicated and streamlined test.py interface in the future for improved usability and clarity.

🔧 Inference Command Examples

▶ Using pretrain.py (for pretrained checkpoints)

python3 pretrain.py \
  --do_train False \
  --do_test True \
  --checkpoint_path ./checkpoints/best.ckpt \
  --test_output_dir ./result_inference \
  --test_reference_path /data/Anime/test_data/reference/ \
  --test_condition_path /data/Anime/test_data/sketch/ \
  --test_gt_path /data/Anime/test_data/reference/ \
  --gpus 0

▶ Using finetune.py (for fine-tuned checkpoints)

Without LPIPS guidance:

python3 finetune.py \
  --do_train False \
  --do_test True \
  --checkpoint_path /root/SSIMBaD/logs/lightning_logs/version_5/checkpoints/epoch=03-train_avg_loss=0.0667.ckpt \
  --test_output_dir ./result_inference \
  --test_reference_path /data/Anime/test_data/reference/ \
  --test_condition_path /data/Anime/test_data/sketch/ \
  --gpus 0 \
  --do_guiding False

With LPIPS guidance:

python3 finetune.py \
  --do_train False \
  --do_test True \
  --checkpoint_path ./checkpoints/best.ckpt \
  --test_output_dir ./result_inference \
  --test_reference_path /data/Anime/test_data/reference/ \
  --test_condition_path /data/Anime/test_data/sketch/ \
  --test_gt_path /data/Anime/test_data/reference/ \
  --gpus 0 \
  --do_guiding True

🧠 Notes

  • Set --do_train False and --do_test True to run inference only.
  • Use --checkpoint_path to specify the model to evaluate.
  • --do_guiding toggles LPIPS-guided inference (True) vs. plain DDIM-style inference (False).
  • Results will be saved in the directory specified by --test_output_dir.

10. Implementation Details

We implement our training and evaluation pipeline using PyTorch Lightning, enabling modular, scalable, and reproducible experimentation.

  • Hardware: All experiments were conducted on a single node equipped with 2× NVIDIA H100 (80GB) GPUs.
  • Distributed Training: We use DDP (Distributed Data Parallel) via strategy="ddp_find_unused_parameters_true" for efficient multi-GPU training.
  • Mixed Precision: Enabled via precision="16-mixed" to accelerate training and reduce memory usage without sacrificing model quality.
  • Training Epochs: Pretraining was conducted for 300 epochs using MSE loss.
  • Inference Timesteps: Default forward process uses 500 steps unless specified otherwise.
  • Checkpointing: Top-3 checkpoints are saved based on training loss using:
    ModelCheckpoint(monitor="train_avg_loss", save_top_k=3, mode="min")
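Putting these settings together, the Trainer wiring might look like the sketch below. Argument names follow the PyTorch Lightning API; the model and datamodule are placeholders defined elsewhere in the repo, so this is a configuration fragment rather than a runnable script.

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the three best checkpoints by average training loss, as noted above
checkpoint_cb = ModelCheckpoint(monitor="train_avg_loss", save_top_k=3, mode="min")

trainer = pl.Trainer(
    max_epochs=300,                              # pretraining budget
    accelerator="gpu",
    devices=2,                                   # 2× H100 in the setup above
    strategy="ddp_find_unused_parameters_true",  # DDP, tolerating unused params
    precision="16-mixed",                        # mixed-precision training
    callbacks=[checkpoint_cb],
)
# trainer.fit(model, datamodule)  # model/datamodule come from the repo's scripts
```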

11. Evaluation

Quantitative evaluation of the model can be performed using the following scripts. All metrics support configurable directory paths and device settings via command-line arguments.


▶ FID

To compute FID between the reference and generated images:

python3 evaluate_FID.py \
  --real_dir /data/Anime/test_data/reference \
  --generated_dir ./result_inference \
  --device cuda \
  --batch_size 50

▶ PSNR

python3 evaluate_PSNR.py \
  --real_dir /data/Anime/test_data/reference \
  --generated_dir ./result_inference \
  --device cuda
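For reference, PSNR reduces to a few lines. This standalone sketch assumes 8-bit images and is not the repo's evaluate_PSNR.py:

```python
import numpy as np

def psnr(real, fake, max_val=255.0):
    # Peak signal-to-noise ratio: 10 · log10(MAX² / MSE)
    mse = np.mean((real.astype(np.float64) - fake.astype(np.float64)) ** 2)
    if mse == 0.0:
        return float("inf")  # identical images
    return 10.0 * np.log10(max_val ** 2 / mse)
```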

▶ MS-SSIM

python3 evaluate_SSIM.py \
  --real_dir /data/Anime/test_data/reference \
  --generated_dir ./result_inference \
  --device cuda

12. Reference Settings: Visual Comparison

The dataset includes two test scenarios to evaluate reconstruction fidelity and style generalization:

(Left: Same-Reference, Right: Cross-Reference)

About

We introduce SSIMBaD (Sigma Scaling with SSIM-Guided Balanced Diffusion), a sigma-space transformation that ensures linear alignment of perceptual degradation, as measured by structural similarity.
