Official PyTorch implementation of "Sigma Scaling with SSIM-Guided Balanced Diffusion for AnimeFace Colorization"
SSIMBaD introduces a novel diffusion-based framework for automatic colorization of anime-style facial sketches. Unlike prior DDPM/EDM-based methods that rely on handcrafted or fixed noise schedules, SSIMBaD leverages a perceptual noise schedule grounded in SSIM-aligned sigma-space scaling. This design enforces uniform perceptual degradation throughout the diffusion process, improving both structural fidelity and stylistic accuracy in the generated outputs.
The following table compares baseline models and our proposed SSIMBaD framework under both same-reference and cross-reference settings. Metrics include PSNR (higher is better), MS-SSIM (higher is better), and FID (lower is better).
| Method | Training | PSNR ↑ (Same / Cross) | MS-SSIM ↑ (Same / Cross) | FID ↓ (Same / Cross) |
|---|---|---|---|---|
| SCFT [Lee2020] | 300 epochs | 17.17 / 15.47 | 0.7833 / 0.7627 | 43.98 / 45.18 |
| AnimeDiffusion (pretrained) [Cao2024] | 300 epochs | 11.39 / 11.39 | 0.6748 / 0.6721 | 46.96 / 46.72 |
| AnimeDiffusion (finetuned) [Cao2024] | 300 + 10 epochs | 13.32 / 12.52 | 0.7001 / 0.5683 | 135.12 / 139.13 |
| SSIMBaD (w/o trajectory refinement) | 300 epochs | 15.15 / 13.04 | 0.7115 / 0.6736 | 53.33 / 55.18 |
| SSIMBaD (w/ trajectory refinement) (ours) | 300 + 10 epochs | 18.92 / 15.84 | 0.8512 / 0.8207 | 34.98 / 37.10 |
This repository includes:
- Pretraining with classifier-free guidance and structural reconstruction loss
- Finetuning with perceptual objectives and SSIM-guided trajectory refinement
- Perceptual noise schedule design based on SSIM curve fitting
- Full evaluation pipeline for same-reference and cross-reference scenarios
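At sampling time, classifier-free guidance combines a conditional and an unconditional noise prediction into a single guided prediction. A minimal sketch of that combination step (the weight `w` and function name are illustrative, not this repo's exact API):

```python
import numpy as np

# Generic classifier-free guidance combination: interpolate/extrapolate
# between unconditional and conditional noise predictions.
def cfg_combine(eps_uncond, eps_cond, w):
    return eps_uncond + w * (eps_cond - eps_uncond)

# w = 0 -> fully unconditional; w = 1 -> fully conditional;
# w > 1 extrapolates beyond the conditional prediction.
```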
The quality of diffusion-based generation is highly sensitive to how noise levels are scheduled over time.
Existing models like DDPM and EDM use different schedules for training and inference (e.g., log(σ) vs. σ^(1/ρ)), often leading to perceptual mismatches that degrade visual consistency.
To resolve this, we introduce a shared transformation φ: ℝ⁺ → ℝ used consistently in both training and generation.
This transformation maps the raw noise scale σ to a perceptual difficulty axis, allowing uniform degradation in image quality over time.
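For intuition, spacing noise levels uniformly along a transformed axis φ(σ) and mapping back to raw σ can be sketched as follows (a hypothetical sketch: `np.log`/`np.exp` stand in for a generic invertible φ, not the paper's fitted φ*):

```python
import numpy as np

# Place n_steps + 1 noise levels uniformly in phi-space, then map them
# back to raw sigma. With phi = log this yields a geometric schedule.
def schedule_in_phi_space(sigma_min, sigma_max, n_steps, phi=np.log, phi_inv=np.exp):
    u = np.linspace(phi(sigma_min), phi(sigma_max), n_steps + 1)  # uniform in phi-space
    return phi_inv(u)                                             # back to sigma

sigmas = schedule_in_phi_space(0.002, 80.0, 10)
# Consecutive ratios sigmas[i+1] / sigmas[i] are constant for phi = log.
```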
We select the optimal transformation φ* by maximizing the linearity of SSIM degradation over a candidate set Φ:
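In symbols, this criterion plausibly takes the form below (a reconstruction inferred from the SSIM R² maximization performed by `optimal_phi.py`, not necessarily the paper's exact notation):

$$
\varphi^{*} \;=\; \arg\max_{\varphi \in \Phi}\; R^{2}\Big(t,\ \mathrm{SSIM}\big(x_0,\ x_{\sigma_\varphi(t)}\big)\Big)
$$

where $\sigma_\varphi(t)$ is the noise level reached at normalized time $t$ when steps are spaced uniformly in $\varphi$-space, and $R^2$ measures the linearity of the SSIM-versus-$t$ curve.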
This function achieves the best perceptual alignment, leading to smooth and balanced degradation curves.
Think of φ*(σ) as a perceptual "magic carpet ride": smooth, stable, and optimally guided by structural similarity.
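The mechanics of selecting the most linear degradation curve can be sketched in a few lines (the SSIM curves here are fabricated purely to demonstrate the R²-maximization step; the real curves come from measuring SSIM along each candidate schedule, as in `optimal_phi.py`):

```python
import numpy as np

# Coefficient of determination of a least-squares linear fit y ~ a*t + b.
def r_squared(y):
    t = np.linspace(0.0, 1.0, len(y))
    a, b = np.polyfit(t, y, 1)
    ss_res = np.sum((y - (a * t + b)) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

# Fabricated degradation curves: one convex (poorly linear), one near-linear.
candidates = {
    "linear-in-sigma": np.exp(-np.linspace(0, 4, 50)),  # convex decay
    "phi-aligned": np.linspace(1.0, 0.1, 50),           # near-linear decay
}
best = max(candidates, key=lambda k: r_squared(candidates[k]))
# best == "phi-aligned": the schedule whose SSIM decay is most linear in t
```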
```
├── pretrain.py          # SSIMBaD training (EDM + φ*(σ))
├── finetune.py          # Trajectory refinement stage
├── SSIMBaD_pretrain.py  # Baseline reproduction (vanilla EDM schedule)
├── SSIMBaD_finetune.py  # Baseline finetuning (MSE-based)
├── evaluate_*.py        # FID / PSNR / SSIM evaluation
├── optimal_phi.py       # φ*(σ) search via SSIM R² maximization
├── models/              # Diffusion & U-Net architectures
├── utils/               # XDoG, TPS warp, logger, path utils
└── requirements.txt
```
We visualize how SSIM degrades across diffusion timesteps for various noise schedules:
*(SSIM degradation curves under DDPM, EDM, and SSIMBaD noise schedules.)*
```bash
conda create -n ssimbad python=3.9
conda activate ssimbad
pip install -r requirements.txt
```

We evaluate our method on a benchmark dataset introduced by Cao et al. (2024), specifically curated for reference-guided anime face colorization.
- Dataset: Danbooru Anime Face Dataset
- Train/Test Split: 31,696 training pairs and 579 test samples
- Resolution: All images are resized to 256×256 pixels
- Each sample includes:
  - `I_gt`: Ground-truth RGB image
  - `I_sketch`: Corresponding edge-based sketch, generated using the XDoG filter [Winnemöller et al. 2012]
  - `I_ref`: A reference image providing color and style cues
The dataset is evaluated under two conditions:
- **Same-Reference Setting:** The reference image is a spatially perturbed version of the ground truth, with the same underlying structure as `I_sketch`.
- **Cross-Reference Setting:** The reference image is randomly sampled from other identities, introducing variation in color palette and facial attributes.
This dual evaluation setup allows us to measure both:
- Reconstruction fidelity under ideal alignment, and
- Generalization performance under cross-domain appearance shifts.
Note: During preprocessing, the reference image is warped from the ground truth using TPS (Thin Plate Spline) with random rotation and deformation, simulating natural variation.
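The thin-plate-spline warp at the heart of this preprocessing can be sketched at the level of control points (a minimal numpy sketch; the repository's `utils/` TPS code is assumed to warp full images, and the control points and displacements here are illustrative):

```python
import numpy as np

def _tps_kernel(r):
    # TPS radial basis U(r) = r^2 * log(r), with U(0) = 0.
    out = np.zeros_like(r)
    mask = r > 0
    out[mask] = r[mask] ** 2 * np.log(r[mask])
    return out

def fit_tps(src, dst):
    """Solve TPS weights mapping src control points (n,2) onto dst (n,2)."""
    n = src.shape[0]
    K = _tps_kernel(np.linalg.norm(src[:, None, :] - src[None, :, :], axis=-1))
    P = np.hstack([np.ones((n, 1)), src])       # affine part [1, x, y]
    L = np.zeros((n + 3, n + 3))
    L[:n, :n], L[:n, n:], L[n:, :n] = K, P, P.T
    V = np.vstack([dst, np.zeros((3, 2))])
    return np.linalg.solve(L, V)                # (n+3, 2): n kernel + 3 affine weights

def apply_tps(src, params, query):
    """Map query points (m,2) through the fitted TPS."""
    U = _tps_kernel(np.linalg.norm(query[:, None, :] - src[None, :, :], axis=-1))
    Pq = np.hstack([np.ones((query.shape[0], 1)), query])
    return U @ params[:-3] + Pq @ params[-3:]

# Randomly displace control points, mimicking the "random deformation" step.
src = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
dst = src + np.random.default_rng(0).normal(scale=0.05, size=src.shape)
params = fit_tps(src, dst)
# TPS interpolates its control points exactly: apply_tps maps src onto dst.
```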
The pretraining stage optimizes the base diffusion model using MSE loss between predicted and ground-truth RGB images, with sketch and reference inputs. It forms the foundation for perceptual finetuning.
```bash
python pretrain.py \
  --do_train True \
  --epochs 300 \
  --train_batch_size 32 \
  --inference_time_step 500 \
  --train_reference_path /data/Anime/train_data/reference/ \
  --train_condition_path /data/Anime/train_data/sketch/ \
  --gpus 0 1
```

Note: the finetuning stage below is not generic MSE finetuning like the `SSIMBaD_finetune.py` baseline; it optimizes the reverse trajectory using perceptual noise scaling.
This project supports perceptual finetuning using pre-trained diffusion weights.
You can specify the strategy (e.g., lpips-vgg, clip, or mse) and resume from a checkpoint.
```bash
python finetune.py \
  --do_train True \
  --resume_from_checkpoint /path/to/checkpoint.ckpt \
  --strategy mse \
  --epochs 10 \
  --finetuning_inference_time_step 50 \
  --train_reference_path /data/Anime/train_data/reference/ \
  --train_condition_path /data/Anime/train_data/sketch/ \
  --gpus 0 1
```

Model evaluation is supported through the existing training scripts: `pretrain.py` for evaluating pretrained models, and `finetune.py` for evaluating finetuned ones. Testing is currently embedded within these scripts; we plan to release a dedicated, streamlined `test.py` interface in the future for improved usability and clarity.
Evaluating a pretrained model:

```bash
python3 pretrain.py \
  --do_train False \
  --do_test True \
  --checkpoint_path ./checkpoints/best.ckpt \
  --test_output_dir ./result_inference \
  --test_reference_path /data/Anime/test_data/reference/ \
  --test_condition_path /data/Anime/test_data/sketch/ \
  --test_gt_path /data/Anime/test_data/reference/ \
  --gpus 0
```

Evaluating a finetuned model without LPIPS guidance:

```bash
python3 finetune.py \
  --do_train False \
  --do_test True \
  --checkpoint_path /root/SSIMBaD/logs/lightning_logs/version_5/checkpoints/epoch=03-train_avg_loss=0.0667.ckpt \
  --test_output_dir ./result_inference \
  --test_reference_path /data/Anime/test_data/reference/ \
  --test_condition_path /data/Anime/test_data/sketch/ \
  --gpus 0 \
  --do_guiding False
```

Evaluating a finetuned model with LPIPS guidance:

```bash
python3 finetune.py \
  --do_train False \
  --do_test True \
  --checkpoint_path ./checkpoints/best.ckpt \
  --test_output_dir ./result_inference \
  --test_reference_path /data/Anime/test_data/reference/ \
  --test_condition_path /data/Anime/test_data/sketch/ \
  --test_gt_path /data/Anime/test_data/reference/ \
  --gpus 0 \
  --do_guiding True
```
- Set `--do_train False` and `--do_test True` to run inference only.
- Use `--checkpoint_path` to specify the model to evaluate.
- `--do_guiding` toggles LPIPS-guided inference (`True`) vs. plain DDIM-style inference (`False`).
- Results are saved in the directory specified by `--test_output_dir`.
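The plain DDIM-style inference selected by `--do_guiding False` reduces, per step, to the deterministic (η = 0) DDIM update. A minimal numpy sketch with illustrative variable names:

```python
import numpy as np

# One deterministic DDIM update (eta = 0). alpha_t / alpha_prev are the
# cumulative signal levels at the current and previous timesteps, and eps
# is the network's noise prediction at the current step.
def ddim_step(x_t, eps, alpha_t, alpha_prev):
    # Recover the predicted clean image, then re-noise it to the previous level.
    x0_pred = (x_t - np.sqrt(1.0 - alpha_t) * eps) / np.sqrt(alpha_t)
    return np.sqrt(alpha_prev) * x0_pred + np.sqrt(1.0 - alpha_prev) * eps
```

When `alpha_prev == alpha_t` the update is the identity, which is a useful sanity check when wiring a sampler.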
We implement our training and evaluation pipeline using PyTorch Lightning, enabling modular, scalable, and reproducible experimentation.
- **Hardware:** All experiments were conducted on a single node equipped with 2× NVIDIA H100 (80 GB) GPUs.
- **Distributed Training:** DDP (Distributed Data Parallel) via `strategy="ddp_find_unused_parameters_true"` for efficient multi-GPU training.
- **Mixed Precision:** Enabled via `precision="16-mixed"` to accelerate training and reduce memory usage without sacrificing model quality.
- **Training Epochs:** Pretraining was conducted for 300 epochs using MSE loss.
- **Inference Timesteps:** The forward process uses 500 steps by default unless specified otherwise.
- **Checkpointing:** The top-3 checkpoints are saved based on training loss using `ModelCheckpoint(monitor="train_avg_loss", save_top_k=3, mode="min")`.
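Wired together, these settings correspond to a PyTorch Lightning `Trainer` configuration along the following lines (a configuration sketch; `model` and `datamodule` are placeholders, and argument names follow the Lightning 2.x API):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the top-3 checkpoints by training loss, as described above.
checkpoint_cb = ModelCheckpoint(monitor="train_avg_loss", save_top_k=3, mode="min")

trainer = pl.Trainer(
    max_epochs=300,
    accelerator="gpu",
    devices=2,
    strategy="ddp_find_unused_parameters_true",
    precision="16-mixed",
    callbacks=[checkpoint_cb],
)
# trainer.fit(model, datamodule=datamodule)
```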
Quantitative evaluation of the model can be performed using the following scripts. All metrics support configurable directory paths and device settings via command-line arguments.
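As a reference for what these scripts report, PSNR in particular reduces to a few lines (a simplified sketch; the actual `evaluate_PSNR.py` is assumed to also handle image loading and directory iteration):

```python
import numpy as np

# Peak signal-to-noise ratio between two images of equal shape,
# assuming an 8-bit dynamic range by default.
def psnr(a, b, max_val=255.0):
    diff = np.asarray(a, dtype=np.float64) - np.asarray(b, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(max_val ** 2 / mse)
```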
To compute FID between the reference and generated images:
```bash
python3 evaluate_FID.py \
  --real_dir /data/Anime/test_data/reference \
  --generated_dir ./result_inference \
  --device cuda \
  --batch_size 50
```

To compute PSNR:

```bash
python3 evaluate_PSNR.py \
  --real_dir /data/Anime/test_data/reference \
  --generated_dir ./result_inference \
  --device cuda
```

To compute SSIM:

```bash
python3 evaluate_SSIM.py \
  --real_dir /data/Anime/test_data/reference \
  --generated_dir ./result_inference \
  --device cuda
```

The dataset includes two test scenarios to evaluate reconstruction fidelity and style generalization:
Left: Same-Reference &nbsp;&nbsp;&nbsp;&nbsp; Right: Cross-Reference

*(Qualitative result images.)*