Skip to content

Commit 2729bdc

Browse files
authored
prefetch weights while waiting for pending requests to complete (#728)
1 parent d232f31 commit 2729bdc

File tree

5 files changed

+590
-18
lines changed

5 files changed

+590
-18
lines changed
Lines changed: 331 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,331 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Weight sync benchmark for torchforge generators.
8+
9+
Measures the time for weight synchronization between trainer and generator,
10+
with and without shared memory prefetching enabled.
11+
12+
Example usage:
13+
# Basic benchmark (no prefetch)
14+
python -m benchmarks.generator.weight_sync --config apps/grpo/qwen3_8b.yaml
15+
16+
# With prefetch enabled
17+
python -m benchmarks.generator.weight_sync \
18+
--config apps/grpo/qwen3_8b.yaml \
19+
benchmark.prefetch_enabled=true \
20+
benchmark.n_fetcher_procs=4 \
21+
benchmark.iterations=5
22+
"""
23+
24+
import asyncio
25+
import logging
26+
import os
27+
import time
28+
from dataclasses import dataclass, field
29+
30+
import torch
31+
import torchstore as ts
32+
from forge.actors.generator import Generator
33+
from forge.actors.trainer import TitanTrainer
34+
from forge.controller.provisioner import init_provisioner, shutdown
35+
from forge.controller.service.service import uuid
36+
from forge.types import LauncherConfig, ProvisionerConfig
37+
from forge.util.config import parse, resolve_hf_hub_paths
38+
from monarch.actor import endpoint
39+
from omegaconf import DictConfig
40+
41+
# Raise hyperactor messaging limits before any actors are spawned. Weight sync
# moves multi-GB state dicts, so (names suggest — confirm against hyperactor
# docs) allow up to 600s for message delivery and 1 GiB (2**30 bytes) frames.
# setdefault() keeps any values the caller already exported in the environment.
os.environ.setdefault("HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS", "600")
os.environ.setdefault("HYPERACTOR_CODE_MAX_FRAME_LENGTH", "1073741824")

# Module-level logger; INFO so per-iteration timing lines are visible by default.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
46+
47+
48+
class BenchmarkTitanTrainer(TitanTrainer):
    """TitanTrainer with weight modification capabilities for benchmarking."""

    @endpoint
    async def modify_weights(self):
        """Scale all model weights by a factor (simulates training step).

        Mutates every floating-point tensor of every model part in place so
        that each benchmark iteration pushes genuinely different weights.
        """
        scale: float = 1.001
        for part in self.engine.model_parts:
            for _name, tensor in part.state_dict().items():
                if torch.is_floating_point(tensor):
                    # In-place multiply: equivalent to `sd[k] *= scale`.
                    tensor.mul_(scale)

    @endpoint
    async def get_model_size_bytes(self) -> int:
        """Get total model size in bytes across all model parts."""
        return sum(
            p.numel() * p.element_size()
            for part in self.engine.model_parts
            for p in part.parameters()
        )
69+
70+
71+
@dataclass
class WeightSyncMetrics:
    """Metrics from a single weight sync operation."""

    # Policy version used for this sync (a random uuid4 int in this benchmark).
    version: int
    # push_time_s + update_time_s for this iteration.
    total_time_s: float
    # Wall time of trainer.push_weights (trainer -> torchstore).
    push_time_s: float
    # Wall time of generator.update_weights fanout (torchstore -> generator).
    update_time_s: float
    # Whether shared-memory prefetching was enabled for this run.
    prefetch_enabled: bool
80+
81+
82+
@dataclass
class BenchmarkResults:
    """Aggregated benchmark results.

    Holds the per-iteration samples plus run configuration, and derives
    average times and throughputs as properties.
    """

    model: str
    iterations: int
    prefetch_enabled: bool
    n_fetcher_procs: int
    model_size_bytes: int = 0
    metrics: list[WeightSyncMetrics] = field(default_factory=list)

    def _mean_of(self, attr: str) -> float:
        """Mean of one WeightSyncMetrics field across samples; 0.0 if empty."""
        if not self.metrics:
            return 0.0
        return sum(getattr(m, attr) for m in self.metrics) / len(self.metrics)

    @property
    def model_size_gb(self) -> float:
        return self.model_size_bytes / (1024**3)

    @property
    def avg_total_time_s(self) -> float:
        return self._mean_of("total_time_s")

    @property
    def avg_push_time_s(self) -> float:
        return self._mean_of("push_time_s")

    @property
    def avg_update_time_s(self) -> float:
        return self._mean_of("update_time_s")

    @property
    def push_throughput_gb_s(self) -> float:
        # Guard against division by zero when there are no samples or no size.
        avg = self.avg_push_time_s
        if avg <= 0 or self.model_size_bytes <= 0:
            return 0.0
        return self.model_size_gb / avg

    @property
    def update_throughput_gb_s(self) -> float:
        avg = self.avg_update_time_s
        if avg <= 0 or self.model_size_bytes <= 0:
            return 0.0
        return self.model_size_gb / avg
126+
127+
128+
def print_results(results: BenchmarkResults):
    """Print benchmark results."""
    bar = "=" * 80
    rule = "-" * 80
    # Assemble the report as lines, then emit once; output is identical to
    # printing line by line.
    report = [
        "",
        bar,
        "WEIGHT SYNC BENCHMARK RESULTS",
        bar,
        f"Model: {results.model}",
        f"Model size: {results.model_size_gb:.2f} GB",
        f"Iterations: {results.iterations}",
        f"Prefetch enabled: {results.prefetch_enabled}",
    ]
    if results.prefetch_enabled:
        report.append(f"Fetcher procs: {results.n_fetcher_procs}")
    report += [
        rule,
        f"{'Metric':<30} {'Time (s)':<15} {'Throughput (GB/s)':<20}",
        rule,
        f"{'Avg push_weights':<30} {results.avg_push_time_s:>12.3f} s "
        f"{results.push_throughput_gb_s:>12.2f} GB/s",
        f"{'Avg update_weights':<30} {results.avg_update_time_s:>12.3f} s "
        f"{results.update_throughput_gb_s:>12.2f} GB/s",
        f"{'Avg total (push + update)':<30} {results.avg_total_time_s:>12.3f} s",
        bar + "\n",
    ]
    print("\n".join(report))
152+
153+
154+
async def run_weight_sync_benchmark(
    cfg: DictConfig,
    iterations: int,
    prefetch_enabled: bool,
    n_fetcher_procs: int,
    warmup_iterations: int,
) -> BenchmarkResults:
    """Run weight sync benchmark with knobs to enable prefetch, fetcher procs, etc.

    Spawns one Generator replica and a BenchmarkTitanTrainer, initializes
    torchstore, then times push_weights / update_weights per iteration while
    in-flight generate requests are outstanding.

    Args:
        cfg: TorchForge config from YAML
        iterations: Number of weight sync iterations to benchmark
        prefetch_enabled: Whether to enable shared memory prefetching
        n_fetcher_procs: Number of fetcher processes (when prefetch_enabled=True)
        warmup_iterations: Number of warmup iterations before timing

    Returns:
        BenchmarkResults with timing metrics
    """
    model_name = cfg.generator.engine_args.get("model", "unknown")

    # Copy so the benchmark's prefetch knobs don't mutate the caller's config.
    generator_cfg = cfg.generator.copy()
    if prefetch_enabled:
        generator_cfg.prefetch_weights_to_shm = True
        generator_cfg.n_fetcher_procs = n_fetcher_procs
    else:
        generator_cfg.prefetch_weights_to_shm = False
        generator_cfg.n_fetcher_procs = 0

    # Optional provisioner section in the YAML configures the launcher;
    # otherwise fall back to the default provisioner.
    if cfg.get("provisioner", None) is not None:
        provisioner = await init_provisioner(
            ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
        )
    else:
        provisioner = await init_provisioner()

    # Force a single replica so timings measure one generator, not a fleet.
    services_generator_cfg = cfg.services.generator.copy()
    services_generator_cfg.num_replicas = 1

    logger.info("Spawning Generator and Trainer...")
    generator, trainer = await asyncio.gather(
        Generator.options(**services_generator_cfg).as_service(**generator_cfg),
        BenchmarkTitanTrainer.options(**cfg.actors.trainer).as_actor(**cfg.trainer),
    )
    logger.info("Generator and Trainer spawned.")

    trainer_num_procs = cfg.actors.trainer["procs"]
    trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
    trainer_hosts = await provisioner.get_host_mesh(trainer_host_mesh_name)
    # same as the main grpo app.
    await ts.initialize(
        mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
        strategy=ts.LocalRankStrategy(),
    )
    logger.info("Torchstore initialized with LocalRankStrategy")

    # Untimed warmup syncs: first syncs typically pay one-time setup costs
    # (allocations, connections), which would skew the averages.
    if warmup_iterations > 0:
        logger.info(f"Running {warmup_iterations} warmup iteration(s)...")
        for i in range(warmup_iterations):
            # Random version ids so each sync is treated as a new version.
            v = uuid.uuid4().int
            await trainer.push_weights.call(policy_version=v)
            await generator.update_weights.fanout(version=v)
            await trainer.modify_weights.call()
        logger.info("Warmup complete.")

    # Get model size for throughput calculation
    # With DTensor/TP, each rank's param.numel() returns global size, not shard size
    # So just take one rank's value
    model_size_result = await trainer.get_model_size_bytes.call()
    _, model_size_bytes = next(iter(model_size_result.items()))
    model_size_gb = model_size_bytes / (1024**3)
    logger.info(f"Model size: {model_size_gb:.2f} GB")

    logger.info(f"Running {iterations} timed iteration(s)...")
    metrics: list[WeightSyncMetrics] = []

    # Generate a test prompt for in-flight requests
    test_prompt = "What is the capital of France? Please explain in detail."

    for i in range(iterations):
        v = uuid.uuid4().int

        # Modify weights to simulate training
        await trainer.modify_weights.call()

        # Time push_weights
        push_start = time.perf_counter()
        await trainer.push_weights.call(policy_version=v)
        push_end = time.perf_counter()
        push_time_s = push_end - push_start

        # Simulate in-flight requests that pause_generation must wait for
        num_inflight = 4
        generation_tasks = [
            asyncio.create_task(generator.generate.route(test_prompt))
            for _ in range(num_inflight)
        ]
        # Give generation a moment to start
        await asyncio.sleep(0.1)

        # Time update_weights (includes pause_generation waiting for in-flight)
        update_start = time.perf_counter()
        await generator.update_weights.fanout(version=v)
        update_end = time.perf_counter()
        update_time_s = update_end - update_start

        # Wait for generation to complete (after weight update)
        await asyncio.gather(*generation_tasks)

        total_time_s = push_time_s + update_time_s

        metrics.append(
            WeightSyncMetrics(
                version=v,
                total_time_s=total_time_s,
                push_time_s=push_time_s,
                update_time_s=update_time_s,
                prefetch_enabled=prefetch_enabled,
            )
        )

        logger.info(
            f"Iteration {i + 1}/{iterations}: push={push_time_s:.3f}s, "
            f"update={update_time_s:.3f}s, total={total_time_s:.3f}s"
        )

    # Teardown in dependency order: trainer cleanup, then both actor meshes,
    # then torchstore.
    logger.info("Cleaning up...")
    await trainer.cleanup.call()
    await generator.shutdown()
    await BenchmarkTitanTrainer.shutdown(trainer)
    await ts.shutdown()

    return BenchmarkResults(
        model=model_name,
        iterations=iterations,
        prefetch_enabled=prefetch_enabled,
        n_fetcher_procs=n_fetcher_procs if prefetch_enabled else 0,
        model_size_bytes=model_size_bytes,
        metrics=metrics,
    )
294+
295+
296+
@parse
def recipe_main(cfg: DictConfig = None) -> None:  # type: ignore[assignment]
    """Main entry point for weight sync benchmark.

    Args:
        cfg: Config loaded from YAML file via @parse decorator.
            Benchmark parameters can be specified via key=value overrides:
                benchmark.iterations=5
                benchmark.prefetch_enabled=true
                benchmark.n_fetcher_procs=4
                benchmark.warmup_iterations=1
    """
    cfg = resolve_hf_hub_paths(cfg)

    # Pull benchmark knobs from the optional `benchmark` config section,
    # falling back to defaults when absent.
    bench = cfg.get("benchmark", {})
    knobs = {
        "iterations": bench.get("iterations", 3),
        "prefetch_enabled": bench.get("prefetch_enabled", False),
        "n_fetcher_procs": bench.get("n_fetcher_procs", 8),
        "warmup_iterations": bench.get("warmup_iterations", 1),
    }

    results = asyncio.run(run_weight_sync_benchmark(cfg=cfg, **knobs))
    print_results(results)

    # Tear down the provisioner after the benchmark loop has finished.
    asyncio.run(shutdown())


if __name__ == "__main__":
    recipe_main()

src/forge/actors/vllm/v0/generator.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from collections.abc import Mapping
1515
from copy import copy
1616
from dataclasses import dataclass, field
17+
from multiprocessing import resource_tracker
1718
from typing import Optional
1819

1920
import torch
@@ -727,10 +728,15 @@ async def fetch(
727728
sd = {}
728729
for name in param_names:
729730
param_key = get_param_key(version, name)
731+
# Use explicit resource handling instead of context manager because
732+
# ownership is transferred to the Generator (which calls handle.drop()
733+
# to clean up). We must unregister from resource_tracker here, otherwise
734+
# the fetcher process will try to clean up the shared memory on exit.
730735
param = await ts.get(param_key)
731-
# Use context manager to ensure cleanup after getting handle
732-
with SharedTensor(tensor=param) as shared_tensor:
733-
handle = shared_tensor.get_handle()
734-
sd[name] = handle
736+
shared_tensor = SharedTensor(tensor=param)
737+
handle = shared_tensor.get_handle()
738+
resource_tracker.unregister(f"/{handle.shm_name}", "shared_memory")
739+
sd[name] = handle
740+
shared_tensor.close()
735741
del param # Explicitly free the tensor after copying to shared memory
736742
return sd

0 commit comments

Comments
 (0)