|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +"""Weight sync benchmark for torchforge generators. |
| 8 | +
|
| 9 | +Measures the time for weight synchronization between trainer and generator, |
| 10 | +with and without shared memory prefetching enabled. |
| 11 | +
|
| 12 | +Example usage: |
| 13 | + # Basic benchmark (no prefetch) |
| 14 | + python -m benchmarks.generator.weight_sync --config apps/grpo/qwen3_8b.yaml |
| 15 | +
|
| 16 | + # With prefetch enabled |
| 17 | + python -m benchmarks.generator.weight_sync \ |
| 18 | + --config apps/grpo/qwen3_8b.yaml \ |
| 19 | + benchmark.prefetch_enabled=true \ |
| 20 | + benchmark.n_fetcher_procs=4 \ |
| 21 | + benchmark.iterations=5 |
| 22 | +""" |
| 23 | + |
| 24 | +import asyncio |
| 25 | +import logging |
| 26 | +import os |
| 27 | +import time |
| 28 | +from dataclasses import dataclass, field |
| 29 | + |
| 30 | +import torch |
| 31 | +import torchstore as ts |
| 32 | +from forge.actors.generator import Generator |
| 33 | +from forge.actors.trainer import TitanTrainer |
| 34 | +from forge.controller.provisioner import init_provisioner, shutdown |
| 35 | +from forge.controller.service.service import uuid |
| 36 | +from forge.types import LauncherConfig, ProvisionerConfig |
| 37 | +from forge.util.config import parse, resolve_hf_hub_paths |
| 38 | +from monarch.actor import endpoint |
| 39 | +from omegaconf import DictConfig |
| 40 | + |
# Raise hyperactor limits before any Monarch actors spawn: weight-sync
# messages are large, so the default delivery timeout and frame size are
# too small. `setdefault` keeps any user-supplied overrides intact.
os.environ.setdefault("HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS", "600")
os.environ.setdefault("HYPERACTOR_CODE_MAX_FRAME_LENGTH", "1073741824")  # 1 GiB

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
| 46 | + |
| 47 | + |
class BenchmarkTitanTrainer(TitanTrainer):
    """TitanTrainer with weight-modification endpoints for benchmarking.

    Adds two RPC endpoints on top of the stock trainer:
      * ``modify_weights``: perturbs the weights so each sync pushes
        genuinely new data.
      * ``get_model_size_bytes``: reports the parameter footprint used for
        throughput calculations.
    """

    @endpoint
    async def modify_weights(self):
        """Scale all floating-point weights by a factor (simulates a train step).

        Runs under ``torch.no_grad()``: ``state_dict()`` returns references to
        the live parameters, which are leaf tensors that typically require
        grad, and an in-place op on such a tensor outside no_grad raises a
        RuntimeError.
        """
        scale: float = 1.001
        with torch.no_grad():
            for model_part in self.engine.model_parts:
                for tensor in model_part.state_dict().values():
                    # Skip integer/bool buffers (e.g. step counters).
                    if torch.is_floating_point(tensor):
                        tensor *= scale

    @endpoint
    async def get_model_size_bytes(self) -> int:
        """Return the total parameter size in bytes across all model parts.

        NOTE(review): with DTensor/TP, ``param.numel()`` may report the global
        rather than per-shard size — callers read a single rank's value (see
        the benchmark loop) rather than summing across ranks.
        """
        total_bytes = 0
        for model_part in self.engine.model_parts:
            for param in model_part.parameters():
                total_bytes += param.numel() * param.element_size()
        return total_bytes
| 69 | + |
| 70 | + |
@dataclass
class WeightSyncMetrics:
    """Metrics from a single weight sync operation."""

    version: int            # policy version pushed/applied in this sync
    total_time_s: float     # push_time_s + update_time_s
    push_time_s: float      # wall time of trainer.push_weights
    update_time_s: float    # wall time of generator.update_weights fanout
    prefetch_enabled: bool  # whether shared-memory prefetch was active
| 81 | + |
@dataclass
class BenchmarkResults:
    """Aggregated benchmark results.

    Averages are computed over ``metrics``; all ``avg_*`` properties return
    0.0 when no metrics were recorded, and the throughput properties return
    0.0 when either the time or the model size is unknown/zero.
    """

    model: str               # model name from the config
    iterations: int          # number of timed iterations requested
    prefetch_enabled: bool   # whether shared-memory prefetch was on
    n_fetcher_procs: int     # fetcher process count (0 when prefetch is off)
    model_size_bytes: int = 0
    # Forward-ref string keeps this class importable independent of
    # definition order.
    metrics: list["WeightSyncMetrics"] = field(default_factory=list)

    @staticmethod
    def _mean(values) -> float:
        """Arithmetic mean of an iterable of floats; 0.0 if empty."""
        values = list(values)
        return sum(values) / len(values) if values else 0.0

    @property
    def model_size_gb(self) -> float:
        return self.model_size_bytes / (1024**3)

    @property
    def avg_total_time_s(self) -> float:
        return self._mean(m.total_time_s for m in self.metrics)

    @property
    def avg_push_time_s(self) -> float:
        return self._mean(m.push_time_s for m in self.metrics)

    @property
    def avg_update_time_s(self) -> float:
        return self._mean(m.update_time_s for m in self.metrics)

    @property
    def push_throughput_gb_s(self) -> float:
        if self.avg_push_time_s <= 0 or self.model_size_bytes <= 0:
            return 0.0
        return self.model_size_gb / self.avg_push_time_s

    @property
    def update_throughput_gb_s(self) -> float:
        if self.avg_update_time_s <= 0 or self.model_size_bytes <= 0:
            return 0.0
        return self.model_size_gb / self.avg_update_time_s
| 126 | + |
| 127 | + |
def print_results(results: BenchmarkResults):
    """Print a formatted summary of the benchmark results to stdout."""
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    report = [
        "",  # leading blank line to separate from preceding log output
        heavy_rule,
        "WEIGHT SYNC BENCHMARK RESULTS",
        heavy_rule,
        f"Model: {results.model}",
        f"Model size: {results.model_size_gb:.2f} GB",
        f"Iterations: {results.iterations}",
        f"Prefetch enabled: {results.prefetch_enabled}",
    ]
    # Fetcher count is only meaningful when prefetching was active.
    if results.prefetch_enabled:
        report.append(f"Fetcher procs: {results.n_fetcher_procs}")
    report += [
        light_rule,
        f"{'Metric':<30} {'Time (s)':<15} {'Throughput (GB/s)':<20}",
        light_rule,
        f"{'Avg push_weights':<30} {results.avg_push_time_s:>12.3f} s "
        f"{results.push_throughput_gb_s:>12.2f} GB/s",
        f"{'Avg update_weights':<30} {results.avg_update_time_s:>12.3f} s "
        f"{results.update_throughput_gb_s:>12.2f} GB/s",
        f"{'Avg total (push + update)':<30} {results.avg_total_time_s:>12.3f} s",
        heavy_rule + "\n",  # trailing blank line after the closing rule
    ]
    print("\n".join(report))
| 152 | + |
| 153 | + |
async def run_weight_sync_benchmark(
    cfg: DictConfig,
    iterations: int,
    prefetch_enabled: bool,
    n_fetcher_procs: int,
    warmup_iterations: int,
) -> BenchmarkResults:
    """Run weight sync benchmark with knobs to enable prefetch, fetcher procs, etc.

    Spawns a single Generator replica and a trainer actor, then alternates
    "modify weights -> push_weights -> update_weights" for ``iterations``
    rounds, timing the push and update phases separately. A handful of
    in-flight generation requests are issued before each update so the
    measured update time includes pausing generation.

    Args:
        cfg: TorchForge config from YAML
        iterations: Number of weight sync iterations to benchmark
        prefetch_enabled: Whether to enable shared memory prefetching
        n_fetcher_procs: Number of fetcher processes (when prefetch_enabled=True)
        warmup_iterations: Number of warmup iterations before timing

    Returns:
        BenchmarkResults with timing metrics
    """
    model_name = cfg.generator.engine_args.get("model", "unknown")

    # Copy so benchmark knobs don't mutate the caller's config object.
    generator_cfg = cfg.generator.copy()
    if prefetch_enabled:
        generator_cfg.prefetch_weights_to_shm = True
        generator_cfg.n_fetcher_procs = n_fetcher_procs
    else:
        generator_cfg.prefetch_weights_to_shm = False
        generator_cfg.n_fetcher_procs = 0

    # Provisioner: launcher-backed when a provisioner section exists,
    # otherwise the default local provisioner.
    if cfg.get("provisioner", None) is not None:
        provisioner = await init_provisioner(
            ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
        )
    else:
        provisioner = await init_provisioner()

    # Force a single replica: the benchmark times one generator's sync path.
    services_generator_cfg = cfg.services.generator.copy()
    services_generator_cfg.num_replicas = 1

    logger.info("Spawning Generator and Trainer...")
    generator, trainer = await asyncio.gather(
        Generator.options(**services_generator_cfg).as_service(**generator_cfg),
        BenchmarkTitanTrainer.options(**cfg.actors.trainer).as_actor(**cfg.trainer),
    )
    logger.info("Generator and Trainer spawned.")

    trainer_num_procs = cfg.actors.trainer["procs"]
    trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
    trainer_hosts = await provisioner.get_host_mesh(trainer_host_mesh_name)
    # same as the main grpo app.
    await ts.initialize(
        mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
        strategy=ts.LocalRankStrategy(),
    )
    logger.info("Torchstore initialized with LocalRankStrategy")

    # Warmup: exercise the full push/update path once (or more) so one-time
    # costs (allocation, connection setup, compilation) don't skew timings.
    if warmup_iterations > 0:
        logger.info(f"Running {warmup_iterations} warmup iteration(s)...")
        for i in range(warmup_iterations):
            # Random int version keeps every sync distinct.
            v = uuid.uuid4().int
            await trainer.push_weights.call(policy_version=v)
            await generator.update_weights.fanout(version=v)
            await trainer.modify_weights.call()
        logger.info("Warmup complete.")

    # Get model size for throughput calculation
    # With DTensor/TP, each rank's param.numel() returns global size, not shard size
    # So just take one rank's value
    model_size_result = await trainer.get_model_size_bytes.call()
    _, model_size_bytes = next(iter(model_size_result.items()))
    model_size_gb = model_size_bytes / (1024**3)
    logger.info(f"Model size: {model_size_gb:.2f} GB")

    logger.info(f"Running {iterations} timed iteration(s)...")
    metrics: list[WeightSyncMetrics] = []

    # Generate a test prompt for in-flight requests
    test_prompt = "What is the capital of France? Please explain in detail."

    for i in range(iterations):
        v = uuid.uuid4().int

        # Modify weights to simulate training
        await trainer.modify_weights.call()

        # Time push_weights
        push_start = time.perf_counter()
        await trainer.push_weights.call(policy_version=v)
        push_end = time.perf_counter()
        push_time_s = push_end - push_start

        # Simulate in-flight requests that pause_generation must wait for
        num_inflight = 4
        generation_tasks = [
            asyncio.create_task(generator.generate.route(test_prompt))
            for _ in range(num_inflight)
        ]
        # Give generation a moment to start
        await asyncio.sleep(0.1)

        # Time update_weights (includes pause_generation waiting for in-flight)
        update_start = time.perf_counter()
        await generator.update_weights.fanout(version=v)
        update_end = time.perf_counter()
        update_time_s = update_end - update_start

        # Wait for generation to complete (after weight update)
        await asyncio.gather(*generation_tasks)

        total_time_s = push_time_s + update_time_s

        metrics.append(
            WeightSyncMetrics(
                version=v,
                total_time_s=total_time_s,
                push_time_s=push_time_s,
                update_time_s=update_time_s,
                prefetch_enabled=prefetch_enabled,
            )
        )

        logger.info(
            f"Iteration {i + 1}/{iterations}: push={push_time_s:.3f}s, "
            f"update={update_time_s:.3f}s, total={total_time_s:.3f}s"
        )

    # Teardown order: trainer cleanup endpoint, then actor/service shutdown,
    # then the torchstore itself.
    logger.info("Cleaning up...")
    await trainer.cleanup.call()
    await generator.shutdown()
    await BenchmarkTitanTrainer.shutdown(trainer)
    await ts.shutdown()

    return BenchmarkResults(
        model=model_name,
        iterations=iterations,
        prefetch_enabled=prefetch_enabled,
        n_fetcher_procs=n_fetcher_procs if prefetch_enabled else 0,
        model_size_bytes=model_size_bytes,
        metrics=metrics,
    )
| 294 | + |
| 295 | + |
@parse
def recipe_main(cfg: DictConfig = None) -> None:  # type: ignore[assignment]
    """Main entry point for weight sync benchmark.

    Args:
        cfg: Config loaded from YAML file via @parse decorator.
            Benchmark parameters can be specified via key=value overrides:
                benchmark.iterations=5
                benchmark.prefetch_enabled=true
                benchmark.n_fetcher_procs=4
                benchmark.warmup_iterations=1
    """
    cfg = resolve_hf_hub_paths(cfg)

    # Pull benchmark knobs (with defaults) from the optional `benchmark`
    # section of the config.
    bench = cfg.get("benchmark", {})
    bench_params = {
        "iterations": bench.get("iterations", 3),
        "prefetch_enabled": bench.get("prefetch_enabled", False),
        "n_fetcher_procs": bench.get("n_fetcher_procs", 8),
        "warmup_iterations": bench.get("warmup_iterations", 1),
    }

    results = asyncio.run(run_weight_sync_benchmark(cfg=cfg, **bench_params))
    print_results(results)

    # Tear down provisioned resources after results have been printed.
    asyncio.run(shutdown())
| 328 | + |
| 329 | + |
if __name__ == "__main__":
    # Config path and key=value overrides are handled by the @parse decorator.
    recipe_main()