Skip to content

Commit 90a4e81

Browse files
author
kevinkhwu
committed
update training script
1 parent 8421cea commit 90a4e81

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

train.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,6 @@ def sample(self, batch_size: int, device: torch.device = None) -> torch.Tensor:
239239
t = 1.0 / (1.0 + torch.exp(-u)) * (t1 - t0) + t0
240240

241241
elif self.snr_type == SNRType.MIX:
242-
# Mix sampling: 30% lognorm + 70% clipped uniform
243242
u = torch.normal(mean=0.0, std=1.0, size=(batch_size,), device=device)
244243
t_lognorm = 1.0 / (1.0 + torch.exp(-u)) * (t1 - t0) + t0
245244

@@ -249,7 +248,6 @@ def sample(self, batch_size: int, device: torch.device = None) -> torch.Tensor:
249248
t1_clip = t1 - delta
250249
t_clip_uniform = torch.rand((batch_size,), device=device) * (t1_clip - t0_clip) + t0_clip
251250

252-
# Mix with 30% lognorm, 70% uniform
253251
mask = (torch.rand((batch_size,), device=device) > 0.3).float()
254252
t = mask * t_lognorm + (1 - mask) * t_clip_uniform
255253

@@ -584,8 +582,8 @@ def _build_optimizer(self):
584582
self.lr_scheduler = get_scheduler(
585583
"constant",
586584
optimizer=self.optimizer,
587-
num_warmup_steps=self.config.warmup_steps * self.world_size,
588-
num_training_steps=self.config.max_steps * self.world_size,
585+
num_warmup_steps=self.config.warmup_steps,
586+
num_training_steps=self.config.max_steps,
589587
)
590588

591589
if self.is_main_process:
@@ -990,14 +988,11 @@ def train(self, dataloader):
990988

991989
if (self.global_step + 1) % self.config.save_interval == 0:
992990
self.save_checkpoint(self.global_step + 1)
993-
if self.world_size > 1:
994-
dist.barrier()
995991

996992
self.global_step += 1
997993

998-
if self.is_main_process:
999-
self.save_checkpoint(self.global_step)
1000-
logger.info("Training completed!")
994+
self.save_checkpoint(self.global_step)
995+
logger.info("Training completed!")
1001996

1002997
if self.world_size > 1:
1003998
dist.barrier()

0 commit comments

Comments (0)