-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate_SSIM.py
More file actions
60 lines (45 loc) · 2.42 KB
/
evaluate_SSIM.py
File metadata and controls
60 lines (45 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import cv2
import torch
import piq
import argparse
from torchvision import transforms
from tqdm import tqdm
def parse_args(argv=None):
    """Parse command-line options for the MS-SSIM evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so existing no-argument calls behave as before;
            passing a list lets the script be driven programmatically or from
            tests.

    Returns:
        argparse.Namespace with ``real_dir``, ``generated_dir`` and ``device``.

    Raises:
        SystemExit: raised by argparse on invalid options or ``--help``.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate MS-SSIM between real and generated images."
    )
    parser.add_argument(
        "--real_dir",
        type=str,
        default="/data/Anime/test_data/reference",
        help="Directory containing reference (real) images.",
    )
    parser.add_argument(
        "--generated_dir",
        type=str,
        default="./result_same_finetuned",
        help="Directory containing generated images.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help="Device to use for computation.",
    )
    return parser.parse_args(argv)
# File extensions (lower-case) treated as images when listing a directory.
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def _list_images(directory):
    """Return the sorted names of files in *directory* with an image extension."""
    return sorted(
        name
        for name in os.listdir(directory)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    )


def _load_rgb_pair(real_path, gen_path):
    """Load a reference/generated image pair as RGB arrays of equal size.

    The generated image is resized to the reference image's width and height
    so the two can be compared pixel-for-pixel. Returns ``None`` if OpenCV
    cannot read either file.
    """
    img_real = cv2.imread(real_path)
    img_gen = cv2.imread(gen_path)
    if img_real is None or img_gen is None:
        return None
    # cv2.resize takes (width, height) while ndarray.shape is (height, width, ...).
    img_gen = cv2.resize(img_gen, (img_real.shape[1], img_real.shape[0]))
    # OpenCV decodes images in BGR channel order; convert both to RGB.
    img_real = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB)
    img_gen = cv2.cvtColor(img_gen, cv2.COLOR_BGR2RGB)
    return img_real, img_gen


def main():
    """Compute MS-SSIM between paired reference and generated images.

    Image files in ``--real_dir`` and ``--generated_dir`` are paired by
    position after sorting each directory's file names, so both directories
    must hold the same number of images whose names sort into matching order.
    The run is aborted if the counts differ. Pairs that fail to load, or that
    piq rejects (e.g. images too small for the multi-scale pyramid), are
    reported and skipped. Each pair's score and the mean over all scored
    pairs are printed.
    """
    args = parse_args()
    if args.device == "cuda" and not torch.cuda.is_available():
        print("⚠️ CUDA is not available; falling back to CPU")
        device = torch.device("cpu")
    else:
        device = torch.device(args.device)

    real_files = _list_images(args.real_dir)
    generated_files = _list_images(args.generated_dir)
    if len(real_files) != len(generated_files):
        print(
            f"⚠️ File count mismatch: {len(real_files)} real vs "
            f"{len(generated_files)} generated"
        )
        return

    # ToTensor turns an 8-bit HWC array into a float CHW tensor in [0, 1],
    # which is why data_range=1.0 is passed to piq below.
    to_tensor = transforms.ToTensor()

    ms_ssim_scores = []
    for real_file, gen_file in tqdm(
        zip(real_files, generated_files), total=len(real_files)
    ):
        pair = _load_rgb_pair(
            os.path.join(args.real_dir, real_file),
            os.path.join(args.generated_dir, gen_file),
        )
        # tqdm.write keeps messages from corrupting the progress bar.
        if pair is None:
            tqdm.write(f"⚠️ Failed to load: {real_file} or {gen_file}")
            continue
        img_real, img_gen = pair

        real_tensor = to_tensor(img_real).unsqueeze(0).to(device)
        gen_tensor = to_tensor(img_gen).unsqueeze(0).to(device)
        try:
            ms_ssim = piq.multi_scale_ssim(
                gen_tensor, real_tensor, data_range=1.0
            ).item()
        except ValueError as err:
            # piq raises ValueError for images below its minimum size;
            # skip this pair rather than losing every score computed so far.
            tqdm.write(f"⚠️ Skipping {real_file} vs {gen_file}: {err}")
            continue
        ms_ssim_scores.append(ms_ssim)
        tqdm.write(f"{real_file} vs {gen_file} ➔ MS-SSIM: {ms_ssim:.4f}")

    if not ms_ssim_scores:
        # A mean of 0.0 would look like a real (terrible) score; say so instead.
        print("\n⚠️ No image pairs were scored; mean MS-SSIM is undefined.")
        return
    avg_ms_ssim = sum(ms_ssim_scores) / len(ms_ssim_scores)
    print(f"\n📊 Mean MS-SSIM Score: {avg_ms_ssim:.4f}")
    print(f"Scored {len(ms_ssim_scores)} of {len(real_files)} pairs")
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()