-
Notifications
You must be signed in to change notification settings - Fork 98
Expand file tree
/
Copy pathmain.py
More file actions
103 lines (88 loc) · 3.47 KB
/
main.py
File metadata and controls
103 lines (88 loc) · 3.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import asyncio
import sys
from apps.grpo.main import main as grpo_main
from forge.controller.launcher import (
JOB_NAME_KEY,
LAUNCHER_KEY,
MastLauncher,
mount_mnt_directory,
)
from forge.controller.provisioner import init_provisioner
from forge.types import (
Launcher,
LauncherConfig,
ProcessConfig,
ProvisionerConfig,
ServiceConfig,
)
from forge.util.config import parse
from omegaconf import DictConfig
# Checkpoint-location defaults: presumably a config key and its default value
# (a path under /mnt/wsfuse, the mount that main() sets up in remote mode).
# NOTE(review): neither name is referenced anywhere else in this file; confirm
# whether another module imports them or whether they are leftovers.
DEFAULT_CHECKPOINT_FOLDER_KEY = "checkpoint_folder"
DEFAULT_CHECKPOINT_FOLDER = "/mnt/wsfuse/teamforge/forge_runs/"
async def main(cfg: DictConfig, mode: str = "detached", extra_args: list = None):
    """Launch, or run from inside, a MAST job for GRPO training.

    Args:
        cfg: Parsed job configuration. It must provide ``services`` and
            ``actors`` mappings (each value is expanded into a
            ``ServiceConfig`` / ``ProcessConfig``) and a job name under
            ``JOB_NAME_KEY``. If a launcher is set under ``LAUNCHER_KEY``,
            it must equal ``Launcher.MAST.value``.
        mode: ``"detached"`` (default) submits the MAST job with the client
            role included and returns once it has launched. ``"remote"``
            assumes the caller is already running inside MAST. It mounts
            ``/mnt/wsfuse``, initialises the provisioner and runs GRPO
            training in this process.
        extra_args: Additional CLI arguments forwarded to the client in
            detached mode. ``None`` is treated as an empty list. The
            argument is not used in remote mode.

    Raises:
        ValueError: If ``mode`` is not ``"detached"`` or ``"remote"``, if
            the configured launcher is not MAST, or if no job name is set.
    """
    # Reject unknown modes up front. Otherwise any value other than
    # "detached" (e.g. a typo) would silently fall into the remote branch,
    # which mounts a filesystem and starts training in this process.
    valid_modes = ("detached", "remote")
    if mode not in valid_modes:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {valid_modes}.")

    if cfg.get(LAUNCHER_KEY, Launcher.MAST.value) != Launcher.MAST.value:
        raise ValueError("Launcher must be MAST.")

    # The job name is normally injected from --job-name in the __main__
    # section, or supplied by the config itself. It is not modified here.
    if cfg.get(JOB_NAME_KEY, None) is None:
        raise ValueError("Job name is required but not provided")

    launcher_config = LauncherConfig(
        launcher=Launcher(cfg.get(LAUNCHER_KEY, Launcher.MAST.value)),
        job_name=cfg.get(JOB_NAME_KEY, None),
        services={k: ServiceConfig(**v) for k, v in cfg.services.items()},
        actors={k: ProcessConfig(**v) for k, v in cfg.actors.items()},
    )

    if mode == "detached":
        # Submit the MAST job with the client role included. Training then
        # runs inside MAST, not in this process.
        launcher = MastLauncher(
            launcher_config,
            detached=True,
            extra_args=extra_args or [],
        )
        await launcher.launch_mast_job()
        print(f"MAST job {launcher.job_name} launched successfully with client role.")
        print("The client is running inside MAST and will execute the training.")
    else:  # mode == "remote"
        # We are already running inside MAST: mount the shared filesystem,
        # initialise the provisioner, then run training in this process.
        mount_mnt_directory("/mnt/wsfuse")
        await init_provisioner(ProvisionerConfig(launcher_config=launcher_config))
        await grpo_main(cfg)
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--mode",
        type=str,
        default="detached",
        choices=["detached", "remote"],
        help="Run mode: 'detached' for launching MAST job with client in MAST, 'remote' for running training directly",
    )
    cli.add_argument(
        "--job-name",
        type=str,
        default=None,
        help="MAST job name (required - generated by launch.sh)",
    )

    # Separate this script's own flags from everything else. The leftovers
    # are config arguments for @parse and are also forwarded to main() as
    # extra_args.
    cli_args, passthrough = cli.parse_known_args()

    # @parse reads its arguments from sys.argv, so show it only the leftovers
    # (keeping the program name in position 0).
    sys.argv = [sys.argv[0], *passthrough]

    @parse
    def _main(cfg):
        # A --job-name given on the command line takes precedence over
        # whatever the parsed config contains.
        job_name_override = cli_args.job_name
        if job_name_override:
            cfg[JOB_NAME_KEY] = job_name_override
            print(f"Using job name: {job_name_override}")
        asyncio.run(main(cfg, mode=cli_args.mode, extra_args=passthrough))

    _main()  # @parse builds cfg from the command line and passes it in