-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfinetune.py
More file actions
119 lines (100 loc) · 4.33 KB
/
finetune.py
File metadata and controls
119 lines (100 loc) · 4.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import yaml
import torch
from datetime import datetime
from torch.amp import GradScaler
from torch.optim import lr_scheduler
from src import FineTuner, PretrainedStableDiffusion
from src import count_parameters, plot_losses, get_data_loader, set_seed
def _build_data_loaders(dataset_config):
    """Build the train and validation DataLoaders from the ``dataset`` config.

    Returns ``(train_loader, val_loader, resolution)``; the resolution is
    also needed later to size the diffusion model.
    """
    batch_size = dataset_config["batch_size"]
    resolution = dataset_config["resolution"]
    train_loader = get_data_loader(
        csv_files=dataset_config["train_csv_file"],
        image_folder=dataset_config["train_image_folder"],
        batch_size=batch_size,
        resolution=resolution,
        train=True)
    val_loader = get_data_loader(
        csv_files=dataset_config["val_csv_file"],
        image_folder=dataset_config["val_image_folder"],
        batch_size=batch_size,
        resolution=resolution,
        train=False)
    return train_loader, val_loader, resolution


def _print_parameter_counts(model):
    """Print per-component and total parameter counts of the SD model."""
    vae_params = count_parameters(model.vae)
    diffusion_params = count_parameters(model.diffusion)
    text_encoder_params = count_parameters(model.text_encoder)
    print(f"VAE parameters: {vae_params}")
    print(f"Diffusion parameters: {diffusion_params}")
    print(f"Text Encoder parameters: {text_encoder_params}")
    print(
        f"Total parameters: {vae_params + diffusion_params + text_encoder_params}")


def main():
    """Fine-tune a pretrained Stable Diffusion model per config/finetune.yaml.

    Trains the diffusion (U-Net) component only, then saves the best
    weights, the loss history/plot, a copy of the config, and generated
    evaluation images under timestamped run folders.
    """
    # Timestamp used to name this run's weight and experiment folders.
    dt_string = datetime.now().strftime("%d%m%Y_%H%M")
    with open("config/finetune.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # OS / reproducibility
    set_seed(config["os"]["seed"])

    # Dataset
    train_loader, val_loader, resolution = _build_data_loaders(
        config["dataset"])

    # Model — only emoji_model.diffusion receives gradient updates below.
    model_config = config["model"]
    emoji_model = PretrainedStableDiffusion(
        sd_models_id=model_config["sd_models_id"],
        step_dim=model_config["step_dim"],
        resolution=resolution,
        visualize=False)

    # Training setup
    training_config = config["training"]
    lr = float(training_config["learning_rate"])
    eta_min = float(training_config["eta_min"])
    num_epochs = int(training_config["num_epochs"])
    optimizer = torch.optim.AdamW(emoji_model.diffusion.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()
    scaler = GradScaler()
    lrate_scheduler = lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=eta_min)

    _print_parameter_counts(emoji_model)

    # Weight-checkpoint destination (one sub-folder per run).
    # os.makedirs creates intermediate directories, so one call suffices.
    save_weight_folder = os.path.join(
        training_config["weights_folder"], dt_string)
    os.makedirs(save_weight_folder, exist_ok=True)
    save_best_path = os.path.join(
        save_weight_folder, training_config["weight_name"])

    finetuner = FineTuner(emoji_model=emoji_model,
                          optimizer=optimizer,
                          lrate_scheduler=lrate_scheduler,
                          scaler=scaler,
                          criterion=criterion,
                          train_loader=train_loader,
                          val_loader=val_loader,
                          num_epochs=num_epochs,
                          save_best=training_config["save_best"],
                          save_best_path=save_best_path)
    finetuner.train()

    # Persist experiment artifacts: loss plot, raw losses, and the config.
    save_experiment_folder = os.path.join(
        training_config["experiments_folder"], dt_string)
    os.makedirs(save_experiment_folder, exist_ok=True)
    losses = finetuner.losses
    plot_losses(losses, save_experiment_folder)
    # os.path.join (not string concatenation) keeps paths portable.
    with open(os.path.join(save_experiment_folder, "losses.txt"), "w",
              encoding="utf-8") as f:
        f.writelines(f"{loss}\n" for loss in losses)
    with open(os.path.join(save_experiment_folder, "config.yaml"), "w",
              encoding="utf-8") as f:
        yaml.dump(config, f)

    # Generate images for qualitative evaluation of the fine-tuned model.
    save_gen_images_folder = os.path.join(
        save_experiment_folder, training_config["gen_images_folder_name"])
    os.makedirs(save_gen_images_folder, exist_ok=True)
    finetuner.gen4eval(num_inference_steps=50,
                       gen_folder=save_gen_images_folder)
    print("End of training")


if __name__ == "__main__":
    main()