-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathmain.py
More file actions
51 lines (43 loc) · 2.24 KB
/
main.py
File metadata and controls
51 lines (43 loc) · 2.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from config import images_paths, ims_config, combined_folder, overwrite, mask_blur_size
from segmentation.segment_image import transfer_styles
from style_transfer.style_transfer import StyleTransferModel
import cv2
import numpy as np
from pathlib import Path
def load_image(path):
    """Read the image at *path* with OpenCV (BGR channel order).

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
            reports failure by returning ``None`` instead of raising, which
            would otherwise surface later as an obscure ``cv2.resize`` error.
    """
    image = cv2.imread(Path(path).as_posix())
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return image


def soft_class_mask(label_map, seg_class, size, blur_size):
    """Build a feathered 3-channel mask selecting one segmentation class.

    Args:
        label_map: per-pixel class labels returned by the segmentation step
            (presumably a 2-D integer array — TODO confirm).
        seg_class: label value whose pixels should receive the style.
        size: target ``(width, height)``, the order ``cv2.resize`` expects.
        blur_size: side length of the box-blur kernel used to soften the mask
            edge; values <= 1 leave the mask hard-edged.

    Returns:
        Float array of shape ``(height, width, 3)`` with values in [0, 1].
    """
    # Compare before casting to uint8 so labels above 255 cannot wrap onto seg_class.
    binary = (np.asarray(label_map) == seg_class).astype("uint8")
    # Nearest-neighbour keeps the mask strictly 0/1. The default bilinear
    # interpolation invents in-between values at edges, which on a label map
    # could select pixels of an unrelated class.
    binary = cv2.resize(binary, size, interpolation=cv2.INTER_NEAREST)
    soft = binary.astype(float)
    if blur_size > 1:
        # Blur on the 0/255 scale (uint8) and bring it back to [0, 1].
        soft = cv2.blur(binary * 255, (blur_size, blur_size)) / 255
    # Same weight for every colour channel.
    return np.repeat(soft[:, :, np.newaxis], 3, axis=2)


def blend_with_mask(original, stylized, mask):
    """Alpha-composite *stylized* over *original* using *mask* as per-pixel weight.

    Where ``mask`` is 1 the stylized pixel is used, where it is 0 the original
    pixel is kept, and fractional values (the blurred edge) mix the two.
    The three arrays must broadcast together; the result is ``uint8``.
    """
    blended = original.astype(float) * (1 - mask) + stylized.astype(float) * mask
    # Round instead of truncating so the blend is not biased one level darker.
    return np.clip(np.rint(blended), 0, 255).astype("uint8")


def main():
    """Segment every configured image, stylize it, and save masked composites.

    For each image, each configured segmentation model and each configured
    style, the stylized pixels of the model's configured class are blended
    into the original image, saved as PNG into ``combined_folder``, and shown
    in OpenCV windows until a key is pressed.
    """
    # Segmentation output, keyed by image path then by segmentation model.
    masks = {
        im_path: {
            seg_model: transfer_styles(im_path, seg_model)
            for seg_model in ims_config[im_path.name]["seg_models"]
        }
        for im_path in images_paths
    }

    # Stylized images, keyed by image path then by style.
    style_model = StyleTransferModel(images_paths, ims_config, overwrite=overwrite)
    styles = style_model.run()

    # cv2.imwrite fails silently when the directory is missing.
    Path(combined_folder).mkdir(parents=True, exist_ok=True)

    for im_path in images_paths:
        im_config = ims_config[im_path.name]
        # The source image is identical for every (model, style) pair: read it once.
        source = load_image(im_path)
        for i, seg_model in enumerate(im_config["seg_models"]):
            # "class" is a list parallel to "seg_models".
            seg_class = im_config["class"][i]
            label_map = masks[im_path][seg_model]
            for style in im_config["styles"]:
                # The style model output is RGB; OpenCV works in BGR.
                stylized = cv2.cvtColor(styles[im_path][style], cv2.COLOR_RGB2BGR)
                # numpy shape is (height, width, ...); cv2.resize wants (width, height).
                size = stylized.shape[:2][::-1]
                original = cv2.resize(source, size)
                mask = soft_class_mask(label_map, seg_class, size, mask_blur_size)
                output = blend_with_mask(original, stylized, mask)

                impath = combined_folder / (im_path.stem + "_" + seg_model + "_" + Path(style).stem + ".png")
                # cv2.imwrite signals failure by returning False rather than raising.
                if not cv2.imwrite(impath.as_posix(), output):
                    raise OSError(f"Could not write image to {impath}")
                print(f"\nSaved final image to {impath}")

                # Show outputs; blocks until a key is pressed.
                cv2.imshow("Original image", original)
                cv2.imshow("Stylized image", stylized)
                cv2.imshow("Mask", (mask * 255).astype("uint8")[:, :, 0])
                cv2.imshow("Final image", output)
                cv2.waitKey()
    print("\n***** DONE *****")


if __name__ == "__main__":
    main()