@@ -34,11 +34,15 @@ def get_meshes_from_config(cfg: LauncherConfig) -> dict[str, int]:
3434 meshes : dict [str , int ] = {}
3535
3636 # Add services that need remote hosts
37+ # Expand services with multiple replicas into per-replica meshes
3738 for service_name , service_cfg in cfg .services .items ():
3839 hosts = getattr (service_cfg , "hosts" , None )
3940 if hosts and hosts > 0 :
40- mesh_name = service_cfg .mesh_name or service_name
41- meshes [mesh_name ] = hosts
41+ base_mesh_name = service_cfg .mesh_name or service_name
42+ num_replicas = service_cfg .num_replicas
43+ for replica_idx in range (num_replicas ):
44+ mesh_name = f"{ base_mesh_name } _{ replica_idx } "
45+ meshes [mesh_name ] = hosts
4246
4347 # Add actors that need remote hosts
4448 for actor_name , actor_cfg in cfg .actors .items ():
@@ -78,7 +82,7 @@ async def initialize(self) -> tuple[JobTrait, JobState]:
7882 # Create a single SlurmJob with all meshes
7983 logger .info (f"Creating SlurmJob with meshes: { meshes } " )
8084 job = SlurmJob (
81- meshes = meshes , # e.g., {"generator ": 1, "trainer ": 2 , "ref_model ": 1 }
85+ meshes = meshes , # e.g., {"generator_0 ": 1, "generator_1 ": 1 , "trainer ": 2 }
8286 slurm_args = slurm_args ,
8387 job_name = self .cfg .job_name + "_workers" or "forge_job" ,
8488 time_limit = "72:00:00" , # Default to 72 hours
0 commit comments