Skip to content

Commit f1bedc1

Browse files
Create a HostMesh per replica
Differential Revision: D91890958 Pull Request resolved: #746
1 parent cd9e295 commit f1bedc1

File tree

2 files changed

+9
-13
lines changed

2 files changed

+9
-13
lines changed

src/forge/controller/launcher.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,15 @@ def get_meshes_from_config(cfg: LauncherConfig) -> dict[str, int]:
3434
meshes: dict[str, int] = {}
3535

3636
# Add services that need remote hosts
37+
# Expand services with multiple replicas into per-replica meshes
3738
for service_name, service_cfg in cfg.services.items():
3839
hosts = getattr(service_cfg, "hosts", None)
3940
if hosts and hosts > 0:
40-
mesh_name = service_cfg.mesh_name or service_name
41-
meshes[mesh_name] = hosts
41+
base_mesh_name = service_cfg.mesh_name or service_name
42+
num_replicas = service_cfg.num_replicas
43+
for replica_idx in range(num_replicas):
44+
mesh_name = f"{base_mesh_name}_{replica_idx}"
45+
meshes[mesh_name] = hosts
4246

4347
# Add actors that need remote hosts
4448
for actor_name, actor_cfg in cfg.actors.items():
@@ -78,7 +82,7 @@ async def initialize(self) -> tuple[JobTrait, JobState]:
7882
# Create a single SlurmJob with all meshes
7983
logger.info(f"Creating SlurmJob with meshes: {meshes}")
8084
job = SlurmJob(
81-
meshes=meshes, # e.g., {"generator": 1, "trainer": 2, "ref_model": 1}
85+
meshes=meshes, # e.g., {"generator_0": 1, "generator_1": 1, "trainer": 2}
8286
slurm_args=slurm_args,
8387
job_name=self.cfg.job_name + "_workers" or "forge_job",
8488
time_limit="72:00:00", # Default to 72 hours

src/forge/controller/provisioner.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -269,20 +269,12 @@ async def get_host_mesh(self, name: str) -> HostMesh:
269269
# for local jobs, return the local host
270270
return this_host()
271271

272-
# Strip replica suffix (e.g., "generator_0" -> "generator")
273-
# Services append _{replica_idx} to mesh names
274-
base_name = name
275-
if "_" in name:
276-
parts = name.rsplit("_", 1)
277-
if len(parts) == 2 and parts[1].isdigit():
278-
base_name = parts[0]
279-
280272
# _job_state contains all the HostMeshes that were allocated as attributes, accessible by their name
281273
try:
282-
host_mesh = getattr(self._job_state, base_name)
274+
host_mesh = getattr(self._job_state, name)
283275
except AttributeError as e:
284276
raise RuntimeError(
285-
f"Mesh '{name}' (base: '{base_name}') was not pre-allocated. "
277+
f"Mesh '{name}' was not pre-allocated. "
286278
"Make sure the mesh is defined in the launcher config."
287279
) from e
288280

0 commit comments

Comments
 (0)