train.py
import os
import sys
import logging

import hydra
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from omegaconf import DictConfig, OmegaConf

from cvap.util import seed_all_rng, setup_logger
from cvap.monitor import *  # exports DDPMonitor and the monitor classes named by cfg.monitor
def _distributed_worker(local_rank, main_func, cfg, ddp):
    assert torch.cuda.is_available(), "CUDA is not available"
    # Single-machine setup: the global rank equals the local rank.
    global_rank = local_rank
    try:
        dist.init_process_group(
            backend="nccl",
            init_method=cfg.dist_url,
            world_size=cfg.num_gpus,
            rank=global_rank,
        )
    except Exception as e:
        logger = logging.getLogger(__name__)
        logger.error("Process group URL: {}".format(cfg.dist_url))
        raise e
    # Synchronize all workers before binding each one to its GPU.
    dist.barrier()
    torch.cuda.set_device(local_rank)
    pg = dist.new_group(range(cfg.num_gpus))
    device = torch.device("cuda", local_rank)
    main_func(cfg, local_rank, ddp, pg, device, DDPMonitor)
def main(cfg, rank, ddp, pg, device, manager):
    cfg.rank = rank
    seed_all_rng(cfg.seed)  # use `cfg.seed + rank` for per-process seeds
    output_dir = f"{cfg.alias_root}/{cfg.model_name}"
    logger = setup_logger(
        output_dir=output_dir, rank=rank, output=output_dir,
    )
    if cfg.verbose or not cfg.eval:
        cfg_str = OmegaConf.to_yaml(cfg)
        logger.info(f"\n\n{cfg_str}")
    if cfg.blockprint:
        # Silence stdout entirely: https://stackoverflow.com/a/8391735
        sys.stdout = open(os.devnull, "w")
    ngpu = torch.cuda.device_count()
    logger.info("World size: {}; rank: {}".format(ngpu, rank))
    torch.backends.cudnn.benchmark = True
    # `manager` is either a monitor class or its name as a string (from cfg.monitor).
    if isinstance(manager, str):
        monitor = eval(manager)(cfg, logger.info, device)
    else:
        monitor = manager(cfg, logger.info, device)
    monitor.learn()
@hydra.main(config_path="configs", config_name="default")
def train(cfg: DictConfig) -> None:
    if cfg.mode == "dp":
        # Single-process DataParallel-style run on GPU 0.
        cfg.rank = 0
        torch.cuda.set_device(0)
        main(cfg, 0, False, False, torch.device("cuda", 0), cfg.monitor)
    elif cfg.mode == "ddp":
        # Spawn one worker per GPU; each calls `main` via `_distributed_worker`.
        try:
            mp.spawn(
                _distributed_worker,
                nprocs=cfg.num_gpus,
                args=(main, cfg, False),
                daemon=False,
            )
        except KeyboardInterrupt:
            dist.destroy_process_group()
    else:
        # Plain single-GPU run; resolve the monitor from the config as in the "dp" branch.
        cfg.rank = 0
        torch.cuda.set_device(0)
        main(cfg, 0, False, False, torch.device("cuda", 0), cfg.monitor)


if __name__ == "__main__":
    train()
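
Under this Hydra entry point, runs are configured from `configs/default.yaml` plus command-line overrides. As a hypothetical sketch (the override values below are illustrative, not documented defaults), a single-GPU run and a multi-GPU DDP run might be launched as:

    # single-GPU run on device 0
    python train.py mode=dp

    # one worker per GPU via torch.multiprocessing.spawn
    python train.py mode=ddp num_gpus=4 dist_url=tcp://127.0.0.1:29500

Here `num_gpus` and `dist_url` are read in `_distributed_worker`, and `monitor` names the monitor class resolved in `main`.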