|
6 | 6 |
|
7 | 7 | """Launcher specific logic (i.e. SLURM, k8s when supported, etc.)""" |
8 | 8 |
|
9 | | -import copy |
10 | | -import getpass |
11 | | -import os |
12 | | -import subprocess |
13 | 9 | import tempfile |
14 | | -import uuid |
15 | 10 | from typing import Any |
16 | 11 |
|
17 | 12 | import monarch |
18 | | -import torchx.specs as specs |
19 | | - |
| 13 | +from forge.controller.base import BaseLauncher |
20 | 14 | from forge.types import Launcher, LauncherConfig |
21 | | -from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints |
22 | 15 | from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport |
23 | 16 | from monarch._rust_bindings.monarch_hyperactor.config import configure |
24 | 17 | from monarch._src.actor.allocator import RemoteAllocator, TorchXRemoteAllocInitializer |
25 | | -from monarch.actor import Actor, endpoint, ProcMesh |
| 18 | +from monarch.actor import ProcMesh |
26 | 19 | from monarch.tools import commands |
27 | | -from monarch.tools.commands import create, info |
28 | 20 | from monarch.tools.components import hyperactor |
29 | | -from monarch.tools.config import Config, Workspace |
30 | | - |
31 | | -_MAST_AVAILABLE = False |
32 | | - |
33 | | -try: |
34 | | - from monarch._src.actor.actor_mesh import current_rank |
35 | | - from monarch._src.actor.meta.allocator import MastAllocator, MastAllocatorConfig |
36 | | - from monarch.tools.components.meta import hyperactor as meta_hyperactor |
37 | | - from torchx.specs import AppState |
38 | | - from torchx.specs.fb.component_helpers import Packages |
| 21 | +from monarch.tools.config import Config |
39 | 22 |
|
40 | | - _MAST_AVAILABLE = True |
41 | | -except ImportError as e: |
42 | | - # This means there is an error with MAST |
43 | | - pass |
44 | 23 |
|
45 | 24 | JOB_NAME_KEY = "job_name" |
46 | 25 | LAUNCHER_KEY = "launcher" |
47 | 26 |
|
48 | 27 |
|
49 | | -def mount_mnt_directory(mount_dst: str) -> None: |
50 | | - """Mounts the MAST remote directory to the specified destination. |
51 | | -
|
52 | | - This function mounts a remote workspace directory that contains huggingface models |
53 | | - and other shared resources needed for training. |
54 | | -
|
55 | | - Args: |
56 | | - mount_dst: Destination path where the directory should be mounted (e.g., "/mnt/wsfuse") |
57 | | - """ |
58 | | - # Sanity check of the mounted directory |
59 | | - sanity_path = os.path.join(mount_dst, "huggingface_models/") |
60 | | - if os.path.exists(sanity_path): |
61 | | - return |
62 | | - |
63 | | - # Otherwise, mount the directory |
64 | | - if not os.path.exists(mount_dst): |
65 | | - os.makedirs(mount_dst, exist_ok=True) |
66 | | - |
67 | | - # Store original LD_LIBRARY_PATH to restore after mounting |
68 | | - original_ld_library_path = os.environ.get("LD_LIBRARY_PATH", "") |
69 | | - |
70 | | - try: |
71 | | - clean_env = os.environ.copy() |
72 | | - if "LD_LIBRARY_PATH" in clean_env: |
73 | | - del clean_env["LD_LIBRARY_PATH"] |
74 | | - |
75 | | - subprocess.run( |
76 | | - [ |
77 | | - "/packages/oil.oilfs/oilfs-wrapper", |
78 | | - "ws://ws.ai.pci0ai/genai_fair_llm", |
79 | | - mount_dst, |
80 | | - ], |
81 | | - capture_output=True, |
82 | | - text=True, |
83 | | - check=True, |
84 | | - env=clean_env, |
85 | | - ) |
86 | | - print("Done mounting") |
87 | | - except subprocess.CalledProcessError as e: |
88 | | - print(f"Get error during mounting {e}, Stderr: {e.stderr}, Stdout: {e.stdout}") |
89 | | - finally: |
90 | | - # Restore original LD_LIBRARY_PATH |
91 | | - if original_ld_library_path: |
92 | | - os.environ["LD_LIBRARY_PATH"] = original_ld_library_path |
93 | | - elif "LD_LIBRARY_PATH" in os.environ: |
94 | | - del os.environ["LD_LIBRARY_PATH"] |
95 | | - |
96 | | - assert os.path.exists( |
97 | | - sanity_path |
98 | | - ), f"Did not find directory {sanity_path}; something wrong with mounting." |
99 | | - |
100 | | - |
101 | | -class MastSetupActor(Actor): |
102 | | - @endpoint |
103 | | - def mount(self, mount_dst: str): |
104 | | - point = current_rank() |
105 | | - # The last dimension is the local proc count. |
106 | | - last_label = point.extent.labels[-1] |
107 | | - proc_count = point.size(last_label) |
108 | | - if current_rank().rank % proc_count != 0: |
109 | | - # Only use one rank per host to mount the directory |
110 | | - return |
111 | | - mount_mnt_directory(mount_dst) |
112 | | - |
113 | | - |
114 | | -class BaseLauncher: |
115 | | - async def initialize(self) -> None: |
116 | | - pass |
117 | | - |
118 | | - async def get_allocator(self, name: str, num_hosts: int) -> tuple[Any, Any, str]: |
119 | | - pass |
120 | | - |
121 | | - async def remote_setup(self, procs: ProcMesh) -> None: |
122 | | - pass |
123 | | - |
124 | | - |
125 | 28 | class Slurmlauncher(BaseLauncher): |
126 | 29 | def __init__( |
127 | 30 | self, |
@@ -172,240 +75,18 @@ async def remote_setup(self, procs: ProcMesh) -> None: |
172 | 75 | return |
173 | 76 |
|
174 | 77 |
|
175 | | -class MastLauncher(BaseLauncher): |
176 | | - """Launcher for MAST (Meta's internal cluster scheduler). |
177 | | -
|
178 | | - This launcher supports two modes of operation: |
179 | | -
|
180 | | - 1. Non-detached mode (detached=False): |
181 | | - - Client runs on your local machine/devserver |
182 | | - - Only worker roles (GPU hosts) are launched in MAST |
183 | | - - Client connects to workers remotely via provisioner |
184 | | -
|
185 | | - 2. Detached mode (detached=True): |
186 | | - - Client runs entirely inside MAST as a separate role |
187 | | - - Both client role (CPU-only) and worker roles (GPU) are launched in MAST |
188 | | - - Client role executes the training script with --mode=remote |
189 | | - - Everything runs in the cluster, no client needed on local machine |
190 | | -
|
191 | | - Args: |
192 | | - cfg: Launcher configuration including job name, services, and actors |
193 | | - detached: If True, adds a client role to the MAST job appdef that runs |
194 | | - the training script inside MAST. If False, only launches worker |
195 | | - roles and expects the client to run on local machine. |
196 | | - extra_args: Additional CLI arguments to pass through to the client role. |
197 | | -
|
198 | | - """ |
199 | | - |
200 | | - def __init__( |
201 | | - self, |
202 | | - cfg: LauncherConfig | None = None, |
203 | | - detached: bool = False, |
204 | | - extra_args: list = None, |
205 | | - ): |
206 | | - assert cfg is not None |
207 | | - self.cfg = cfg |
208 | | - self.detached = detached |
209 | | - self.default_monarch_port = 26600 |
210 | | - self.extra_args = extra_args or [] |
211 | | - self.scheduler_name = "mast_conda" |
212 | | - |
213 | | - # TODO: enable taking this from config |
214 | | - self.sku = "gtt_any" |
215 | | - self.timeout_sec = 1 * 60 * 60 # Kill the job if idle for 1 hour |
216 | | - self.user = getpass.getuser() |
217 | | - self.remote_work_dir = "/packages/monarch_default_workspace/workspace/" |
218 | | - self.editable_workspace_paths = [ |
219 | | - f"/data/users/{self.user}/fbsource/fbcode/pytorch/torchforge" |
220 | | - ] |
221 | | - self.job_name = self.cfg.job_name or self.create_job_name() |
222 | | - |
223 | | - async def initialize(self) -> None: |
224 | | - # HostMesh currently requires explicit configuration |
225 | | - # of the underlying transport from client to mesh. |
226 | | - # This can be removed in the future once this has been removed. |
227 | | - configure(default_transport=ChannelTransport.MetaTlsWithHostname) |
228 | | - |
229 | | - async def get_allocator(self, name: str, num_hosts: int) -> tuple[Any, Any, str]: |
230 | | - allocator = MastAllocator( |
231 | | - MastAllocatorConfig( |
232 | | - job_name=self.job_name, |
233 | | - remote_allocator_port=self.default_monarch_port, |
234 | | - ), |
235 | | - ) |
236 | | - alloc_constraints = AllocConstraints( |
237 | | - {MastAllocator.ALLOC_LABEL_TASK_GROUP: name} |
238 | | - ) |
239 | | - |
240 | | - return allocator, alloc_constraints, self.create_server_handle() |
241 | | - |
242 | | - async def remote_setup(self, procs: ProcMesh) -> None: |
243 | | - setup = procs.spawn("mast_setup", MastSetupActor) |
244 | | - await setup.mount.call(mount_dst="/mnt/wsfuse") |
245 | | - |
246 | | - async def launch_mast_job(self): |
247 | | - handle = self.create_server_handle() |
248 | | - server_spec = info(handle) |
249 | | - if server_spec and server_spec.state == AppState.RUNNING: |
250 | | - print(f"Job {self.job_name} is already running. Skipping launch.") |
251 | | - return server_spec |
252 | | - |
253 | | - config = Config( |
254 | | - scheduler="mast_conda", |
255 | | - scheduler_args={ |
256 | | - "hpcIdentity": "hyper_monarch", |
257 | | - "hpcJobOncall": "monarch", |
258 | | - "hpcClusterUuid": "MastGenAICluster", |
259 | | - "rmAttribution": "msl_infra_hw_enab_agentrl", |
260 | | - }, |
261 | | - appdef=self.build_appdef(), |
262 | | - workspace=Workspace( |
263 | | - dirs=[workspace_dir for workspace_dir in self.editable_workspace_paths], |
264 | | - ), |
265 | | - ) |
266 | | - |
267 | | - job_handle = create(config, name=self.job_name) |
268 | | - print( |
269 | | - f"MAST job launched successfully:\n" |
270 | | - f"\033[92mhttps://www.internalfb.com/mlhub/pipelines/runs/mast/{self.job_name}\033[0m" |
271 | | - ) |
272 | | - return job_handle |
273 | | - |
274 | | - def add_additional_packages(self, packages: "Packages") -> "Packages": |
275 | | - packages.add_package("oil.oilfs:stable") |
276 | | - packages.add_package("manifold.manifoldfs:prod") |
277 | | - return packages |
278 | | - |
279 | | - def build_appdef(self) -> specs.AppDef: |
280 | | - # create the app definition for the worker |
281 | | - additional_python_paths = [ |
282 | | - f"{self.remote_work_dir}torchforge", |
283 | | - self.remote_work_dir, |
284 | | - ] |
285 | | - |
286 | | - default_envs = { |
287 | | - **meta_hyperactor.DEFAULT_NVRT_ENVS, |
288 | | - **meta_hyperactor.DEFAULT_NCCL_ENVS, |
289 | | - **meta_hyperactor.DEFAULT_TORCH_ENVS, |
290 | | - **{"TORCHX_RUN_PYTHONPATH": ":".join(additional_python_paths)}, |
291 | | - **{ |
292 | | - "HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS": "600", |
293 | | - "HYPERACTOR_CODE_MAX_FRAME_LENGTH": "1073741824", |
294 | | - "TORCHINDUCTOR_COMPILE_THREADS": "1", |
295 | | - "TORCH_COMPILE_DISABLE": "1", |
296 | | - "TORCHDYNAMO_VERBOSE": "1", |
297 | | - "VLLM_TORCH_COMPILE_LEVEL": "0", |
298 | | - "VLLM_USE_TRITON_FLASH_ATTN": "0", |
299 | | - "HF_HUB_OFFLINE": "1", |
300 | | - "TORCHSTORE_RDMA_ENABLED": "1", |
301 | | - "HF_HOME": "/mnt/wsfuse/teamforge/hf", |
302 | | - "TRANSFORMERS_OFFLINE": "1", |
303 | | - "FUSE_SRC": "ws://ws.ai.pci0ai/genai_fair_llm", |
304 | | - "FUSE_DST": "/mnt/wsfuse", |
305 | | - }, |
306 | | - } |
307 | | - |
308 | | - packages = Packages() |
309 | | - meshes = [] |
310 | | - # Process both services and actors configurations |
311 | | - for mesh_name, service in self.cfg.services.items(): |
312 | | - num_replicas = service.num_replicas |
313 | | - with_gpus = bool(service.with_gpus) |
314 | | - num_hosts = int(service.hosts or 0) |
315 | | - # Create list of mesh names with indices and num_hosts |
316 | | - if with_gpus and num_hosts > 0: |
317 | | - mesh_list = [ |
318 | | - f"{mesh_name}_{i}:{num_hosts}:{self.sku}" |
319 | | - for i in range(num_replicas) |
320 | | - ] |
321 | | - meshes.extend(mesh_list) |
322 | | - |
323 | | - for mesh_name, actor in self.cfg.actors.items(): |
324 | | - num_replicas = 1 |
325 | | - with_gpus = bool(actor.with_gpus) |
326 | | - num_hosts = int(actor.hosts or 0) |
327 | | - # single actors with GPUs |
328 | | - if with_gpus: |
329 | | - meshes.append(f"{mesh_name}:{num_replicas}:{self.sku}") |
330 | | - |
331 | | - appdef = meta_hyperactor.host_mesh_conda( |
332 | | - meshes=meshes, |
333 | | - additional_packages=self.add_additional_packages(packages), |
334 | | - timeout_sec=self.timeout_sec, |
335 | | - env=default_envs, |
336 | | - ) |
337 | | - appdef.metadata["mast"] = { |
338 | | - "HpcJobDefinition": { |
339 | | - "networkAffinity": { |
340 | | - # Ensure colocation |
341 | | - "preferredScope": 3, # DC |
342 | | - "fallbackScope": 3, # REGION |
343 | | - }, |
344 | | - }, |
345 | | - } |
346 | | - |
347 | | - for role in appdef.roles: |
348 | | - role.resource.capabilities["server_sub_types"] = [ |
349 | | - # role.resource.capabilities["server_sub_types"][2] # hardcoded to ROCE |
350 | | - role.resource.capabilities["server_sub_types"][1] # GTT |
351 | | - ] |
352 | | - |
353 | | - # Add client role to run in MAST if in detached mode |
354 | | - if self.detached: |
355 | | - client_role = self._create_client_role(appdef) |
356 | | - appdef.roles.insert(0, client_role) |
357 | | - |
358 | | - return appdef |
359 | | - |
360 | | - def _create_client_role(self, appdef: specs.AppDef) -> specs.Role: |
361 | | - # Clone an existing worker role to inherit workspace configuration |
362 | | - if not appdef.roles: |
363 | | - raise ValueError( |
364 | | - "Cannot create client role: no worker roles exist to clone from" |
365 | | - ) |
366 | | - |
367 | | - # Clone the first worker role |
368 | | - client_role = copy.deepcopy(appdef.roles[0]) |
369 | | - |
370 | | - # Override with client-specific configuration |
371 | | - client_role.name = "client" |
372 | | - # Use the bootstrap script as entrypoint |
373 | | - client_role.entrypoint = "workspace/torchforge/fb/mast/client_bootstrap.sh" |
374 | | - |
375 | | - # Build args for the client role (passed to the bootstrap script) |
376 | | - # These args will be passed to client_bootstrap.sh which forwards them to main.py |
377 | | - args = [ |
378 | | - "--mode=remote", |
379 | | - "--job-name", |
380 | | - self.job_name, |
381 | | - ] |
382 | | - |
383 | | - # Add any extra args passed from the CLI (includes --config and other args) |
384 | | - if self.extra_args: |
385 | | - args.extend(self.extra_args) |
386 | | - |
387 | | - client_role.args = args |
388 | | - client_role.num_replicas = 1 |
389 | | - |
390 | | - return client_role |
391 | | - |
392 | | - def create_job_name(self): |
393 | | - return f"{self.user}-forge-{uuid.uuid4().hex[:6]}" |
394 | | - |
395 | | - def create_server_handle(self) -> str: |
396 | | - return f"{self.scheduler_name}:///{self.job_name}" |
397 | | - |
398 | | - |
399 | 78 | def get_launcher(cfg: LauncherConfig | None = None) -> BaseLauncher | None: |
400 | 79 | if not cfg: |
401 | 80 | return None |
402 | 81 | if cfg.launcher == Launcher.SLURM: |
403 | 82 | return Slurmlauncher(cfg) |
404 | 83 | elif cfg.launcher == Launcher.MAST: |
405 | | - if not _MAST_AVAILABLE: |
406 | | - raise ValueError( |
407 | | - "MAST imports did not succeed, cannot launch MAST jobs. Please verify your installation" |
408 | | - ) |
409 | | - return MastLauncher(cfg, detached=False) |
| 84 | + try: |
| 85 | + from forge.fb.mast_launcher import MastLauncher |
| 86 | + |
| 87 | + return MastLauncher(cfg, detached=False) |
| 88 | + except ImportError as err: |
| 89 | + raise ValueError("MAST is not available, cannot launch MAST jobs.") from err |
| 90 | + |
410 | 91 | else: |
411 | 92 | raise ValueError(f"Unsupported config provided, got {cfg}") |