@@ -239,7 +239,6 @@ def sample(self, batch_size: int, device: torch.device = None) -> torch.Tensor:
239239 t = 1.0 / (1.0 + torch .exp (- u )) * (t1 - t0 ) + t0
240240
241241 elif self .snr_type == SNRType .MIX :
242- # Mix sampling: 30% lognorm + 70% clipped uniform
243242 u = torch .normal (mean = 0.0 , std = 1.0 , size = (batch_size ,), device = device )
244243 t_lognorm = 1.0 / (1.0 + torch .exp (- u )) * (t1 - t0 ) + t0
245244
@@ -249,7 +248,6 @@ def sample(self, batch_size: int, device: torch.device = None) -> torch.Tensor:
249248 t1_clip = t1 - delta
250249 t_clip_uniform = torch .rand ((batch_size ,), device = device ) * (t1_clip - t0_clip ) + t0_clip
251250
252- # Mix with 30% lognorm, 70% uniform
253251 mask = (torch .rand ((batch_size ,), device = device ) > 0.3 ).float ()
254252 t = mask * t_lognorm + (1 - mask ) * t_clip_uniform
255253
@@ -584,8 +582,8 @@ def _build_optimizer(self):
584582 self .lr_scheduler = get_scheduler (
585583 "constant" ,
586584 optimizer = self .optimizer ,
587- num_warmup_steps = self .config .warmup_steps * self . world_size ,
588- num_training_steps = self .config .max_steps * self . world_size ,
585+ num_warmup_steps = self .config .warmup_steps ,
586+ num_training_steps = self .config .max_steps ,
589587 )
590588
591589 if self .is_main_process :
@@ -990,14 +988,11 @@ def train(self, dataloader):
990988
991989 if (self .global_step + 1 ) % self .config .save_interval == 0 :
992990 self .save_checkpoint (self .global_step + 1 )
993- if self .world_size > 1 :
994- dist .barrier ()
995991
996992 self .global_step += 1
997993
998- if self .is_main_process :
999- self .save_checkpoint (self .global_step )
1000- logger .info ("Training completed!" )
994+ self .save_checkpoint (self .global_step )
995+ logger .info ("Training completed!" )
1001996
1002997 if self .world_size > 1 :
1003998 dist .barrier ()
0 commit comments