-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
115 lines (84 loc) · 2.87 KB
/
main.py
File metadata and controls
115 lines (84 loc) · 2.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import torchvision
import torchvision.transforms as T
from dataclasses import dataclass
from cgan.model import Generator, Critic
from cgan.utils import calculate_grad_penalty
@dataclass
class Config:
    """Hyperparameters for conditional WGAN-GP training on MNIST.

    Fields carry type annotations because ``@dataclass`` only registers
    ANNOTATED names as fields — the original un-annotated names were plain
    class attributes, so the decorator generated an empty ``__init__`` and
    ``Config(batch_size=...)`` raised TypeError. ``Config()`` with no
    arguments behaves exactly as before.
    """

    epochs: int = 1            # full passes over the training set
    batch_size: int = 64       # samples per DataLoader batch
    lr_gen: float = 2e-4       # generator Adam learning rate
    lr_critic: float = 1e-4    # critic Adam learning rate
    critic_iter: int = 2       # critic updates per generator update
    lambda_gp: float = 10      # gradient-penalty weight (WGAN-GP)
    img_size: tuple = (32, 32) # (H, W) images are resized to
    num_classes: int = 10      # MNIST digit classes, used for conditioning
    z_dim: int = 100           # latent noise dimensionality
    base_features: int = 128   # base channel width for Generator/Critic
    num_blocks: int = 3        # number of up/down blocks in each network
    img_channels: int = 1      # MNIST is grayscale
# Single shared hyperparameter bundle used by the rest of the script.
config = Config()

# Resize the 28x28 MNIST digits to config.img_size, convert to tensors in
# [0, 1], then normalize the single channel to [-1, 1] (mean 0.5, std 0.5) —
# presumably to match a tanh-output generator; confirm against cgan.model.
transform = T.Compose(
    [T.Resize(config.img_size), T.ToTensor(), T.Normalize([0.5], [0.5])]
)
# Training split of MNIST; downloaded to ./data on first run.
train_dataset = torchvision.datasets.MNIST(
    root="./data", train=True, transform=transform, download=True
)
# Shuffled batches; the final batch may be smaller than batch_size.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=config.batch_size, shuffle=True
)
# Prefer GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Conditional generator: maps (noise z, class label) -> fake image.
# Exact architecture lives in cgan.model.
gen = Generator(
    z_dim=config.z_dim,
    g_channels=config.base_features,
    out_channels=config.img_channels,
    num_blocks=config.num_blocks,
    num_classes=config.num_classes,
).to(device)
# Conditional critic (WGAN-style discriminator): scores (image, label) pairs.
critic = Critic(
    in_channels=config.img_channels,
    d_channels=config.base_features,
    num_blocks=config.num_blocks,
    num_classes=config.num_classes,
    img_size=config.img_size,
).to(device)
# Adam with betas=(0.0, 0.9) — the settings used by the WGAN-GP paper.
optim_gen = torch.optim.Adam(
    params=gen.parameters(), lr=config.lr_gen, betas=(0.0, 0.9)
)
optim_critic = torch.optim.Adam(
    params=critic.parameters(), lr=config.lr_critic, betas=(0.0, 0.9)
)
# Per-batch loss history for later inspection/plotting.
gen_losses = []
critic_losses = []
# WGAN-GP training loop: per batch, update the critic config.critic_iter
# times, then update the generator once on freshly sampled noise.
for epoch in range(config.epochs):
    for batch_idx, (real, labels) in enumerate(train_loader):
        labels = labels.to(device)
        real = real.to(device)
        # Use the actual batch size: the last batch may be smaller.
        batch_size = real.size(0)
        critic_loss_batch = 0
        # ---- critic updates ----
        for _ in range(config.critic_iter):
            noise = torch.randn((batch_size, config.z_dim, 1, 1)).to(device)
            # detach() blocks gradients from flowing back into the
            # generator during the critic step.
            fake = gen(noise, labels).detach()
            C_real = critic(real, labels)
            C_fake = critic(fake, labels)
            # Gradient penalty term; interpolation between real and fake is
            # presumably handled inside the helper — see cgan.utils.
            gp = calculate_grad_penalty(critic, real, fake, labels, device)
            # Wasserstein critic loss: E[C(fake)] - E[C(real)] + lambda * GP.
            C_loss = torch.mean(C_fake) - torch.mean(C_real) + config.lambda_gp * gp
            critic_loss_batch += C_loss.item()
            optim_critic.zero_grad()
            C_loss.backward()
            optim_critic.step()
        # ---- generator update (fresh noise; no detach so gradients
        # reach the generator through the critic) ----
        noise = torch.randn((batch_size, config.z_dim, 1, 1)).to(device)
        fake = gen(noise, labels)
        # Generator maximizes the critic's score on fakes,
        # i.e. minimizes -E[C(fake)].
        G_loss = -torch.mean(critic(fake, labels))
        optim_gen.zero_grad()
        G_loss.backward()
        optim_gen.step()
        gen_losses.append(G_loss.item())
        # Record the critic loss averaged over the inner critic steps.
        critic_losses.append(critic_loss_batch / config.critic_iter)
        if batch_idx % 100 == 0:
            print(
                f"[epoch {epoch+1}/{config.epochs}, batch {batch_idx}] "
                f"gen loss: {G_loss.item():.4f}, critic loss: {critic_loss_batch/config.critic_iter:.4f}"
            )
    print(f"epoch {epoch+1} completed")