Skip to content

Commit 0ae8c6b

Browse files
committed
Update AltDiffusion
Signed-off-by: ldwang <ldwang@baai.ac.cn>
1 parent a15df1d commit 0ae8c6b

File tree

19 files changed

+14
-8
lines changed

19 files changed

+14
-8
lines changed

examples/AltDiffusion/dreambooth.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
from pathlib import Path
1212

1313
import torch
14+
from torch.cuda.amp import autocast as autocast
1415
from torch.utils.data import Dataset
1516
from PIL import Image
1617
from torchvision import transforms
@@ -34,7 +35,7 @@
3435
train_text_encoder = False
3536
train_only_unet = True
3637

37-
num_train_epochs = 500
38+
num_train_epochs = 10
3839
batch_size = 4
3940
learning_rate = 5e-6
4041
adam_beta1 = 0.9
@@ -197,20 +198,23 @@ def collate_fn(examples):
197198
if with_prior_preservation:
198199
x, x_prior = torch.chunk(x, 2, dim=0)
199200
c, c_prior = torch.chunk(c, 2, dim=0)
200-
loss, _ = model(x, c)
201+
with autocast():
202+
loss, _ = model(x, c)
201203
prior_loss, _ = model(x_prior, c_prior)
202204
loss = loss + prior_loss_weight * prior_loss
203205
else:
204-
loss, _ = model(x, c)
206+
with autocast():
207+
loss, _ = model(x, c)
205208

206209
print('*'*20, "loss=", str(loss.detach().item()))
207210

208-
loss.backward()
209-
optimizer.step()
210-
optimizer.zero_grad()
211+
with autocast():
212+
loss.backward()
213+
optimizer.step()
214+
optimizer.zero_grad()
211215

212216
## mkdir ./checkpoints/DreamBooth and copy ./checkpoints/AltDiffusion to ./checkpoints/DreamBooth/AltDiffusion
213217
## overwrite model.ckpt for latter usage
214-
chekpoint_path = './checkpoints/DreamBooth/AltDiffusion/model.ckpt'
218+
chekpoint_path = './checkpoints/AltDiffusion/dreambooth_model.ckpt'
215219
torch.save(model.state_dict(), chekpoint_path)
216220

273 KB
Loading
82.9 KB
Loading
346 KB
Loading
40.1 KB
Loading
324 KB
Loading
58.2 KB
Loading
50.2 KB
Loading
146 KB
Loading
134 KB
Loading

0 commit comments

Comments (0)