-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathscheduler.py
More file actions
executable file
·33 lines (29 loc) · 1.02 KB
/
scheduler.py
File metadata and controls
executable file
·33 lines (29 loc) · 1.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from typing import Any
import torch
import yacs.config
from warmup_scheduler import GradualWarmupScheduler
def create_scheduler(config: yacs.config.CfgNode, optimizer: Any) -> Any:
    """Build a learning-rate scheduler for ``optimizer`` from ``config``.

    The base scheduler is selected by ``config.scheduler_type``:

    * ``'multistep'`` -> ``torch.optim.lr_scheduler.MultiStepLR`` using
      ``config.lr_milestones`` and ``config.gamma``.
    * ``'step'`` -> ``torch.optim.lr_scheduler.StepLR`` using
      ``config.lr_stepsize`` and ``config.gamma``.
    * ``'exp'`` -> ``torch.optim.lr_scheduler.ExponentialLR`` using
      ``config.gamma``.

    If ``config.warmup`` is greater than 0, the base scheduler is passed as
    ``after_scheduler`` to a ``GradualWarmupScheduler`` built with
    ``multiplier=1`` and ``total_epoch=config.warmup``, and that wrapper is
    returned instead.

    Args:
        config: Config node providing ``scheduler_type``, ``gamma``,
            ``warmup`` and, depending on the type, ``lr_milestones`` or
            ``lr_stepsize``.
        optimizer: Optimizer whose learning rate is scheduled (presumably a
            ``torch.optim.Optimizer``; it is passed straight through to the
            torch scheduler constructors).

    Returns:
        The constructed scheduler, or the warmup wrapper around it when
        ``config.warmup > 0``.

    Raises:
        ValueError: If ``config.scheduler_type`` is not ``'multistep'``,
            ``'step'`` or ``'exp'``.
    """
    scheduler_type = config.scheduler_type
    if scheduler_type == 'multistep':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=config.lr_milestones,
            gamma=config.gamma)
    elif scheduler_type == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config.lr_stepsize,
            gamma=config.gamma)
    elif scheduler_type == 'exp':
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer,
            gamma=config.gamma)
    else:
        # Name the offending value and the accepted ones. A bare
        # ValueError() gives no hint about which config field is wrong.
        raise ValueError(
            f"Unknown scheduler_type {scheduler_type!r}; "
            f"expected one of 'multistep', 'step', 'exp'")

    if config.warmup > 0:
        # NOTE(review): with multiplier=1 the warmup_scheduler package
        # presumably ramps the LR up to the base LR over `config.warmup`
        # epochs, then defers to `after_scheduler`. Confirm against the
        # package docs.
        scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=1,
            total_epoch=config.warmup,
            after_scheduler=scheduler,
        )
    return scheduler