Skip to content

Commit bd40c5e

Browse files
authored
Add TorchStore weight sync to Generator v1 (#710)
1 parent 1351212 commit bd40c5e

File tree

3 files changed

+243
-2
lines changed

3 files changed

+243
-2
lines changed

src/forge/actors/vllm/v1/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,21 @@ def __getattr__(name):
2525
from forge.actors.vllm.v1.monarch_executor import WorkerWrapper
2626

2727
return WorkerWrapper
28+
if name == "ForgeMonarchExecutor":
29+
from forge.actors.vllm.v1.forge_executor import ForgeMonarchExecutor
30+
31+
return ForgeMonarchExecutor
32+
if name == "ForgeWorkerWrapper":
33+
from forge.actors.vllm.v1.forge_executor import ForgeWorkerWrapper
34+
35+
return ForgeWorkerWrapper
2836
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
2937

3038

3139
# Public API of this package. The Forge* names resolve lazily via the
# module-level __getattr__ to avoid importing heavy executor modules eagerly.
__all__ = [
    "Generator",
    "MonarchExecutor",
    "WorkerWrapper",
    "ForgeMonarchExecutor",
    "ForgeWorkerWrapper",
]
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Forge-specific MonarchExecutor with TorchStore weight sync.
9+
10+
This module extends the upstream-compatible MonarchExecutor with TorchStore
11+
integration for weight synchronization in RL training loops. It provides:
12+
13+
- ForgeWorkerWrapper: Extends WorkerWrapper with TorchStore weight loading
14+
- ForgeMonarchExecutor: Extends MonarchExecutor with TorchStore Controller handling
15+
16+
Use this executor when you need weight updates from TorchStore (e.g., GRPO training).
17+
For inference-only workloads, use the base MonarchExecutor directly.
18+
19+
TODO: Add shared memory weight prefetch support (prefetch_weights_to_shm, n_fetcher_procs)
20+
as in v0 Generator for faster weight loading.
21+
"""
22+
23+
from __future__ import annotations
24+
25+
import asyncio
26+
import base64
27+
import logging
28+
import os
29+
from typing import Optional
30+
31+
import cloudpickle
32+
from forge.actors._torchstore_utils import extract_param_name, get_param_prefix
33+
from forge.actors.vllm.v1.monarch_executor import MonarchExecutor, WorkerWrapper
34+
from monarch.actor import endpoint
35+
from torchstore.client import LocalClient
36+
37+
logger = logging.getLogger(__name__)
38+
39+
40+
class ForgeWorkerWrapper(WorkerWrapper):
    """Worker wrapper with TorchStore weight sync capabilities.

    Extends the base WorkerWrapper with endpoints for loading model weights
    directly from torchstore, plus test-only helpers for snapshotting and
    validating parameters around a weight update.
    """

    def __init__(self, vllm_config):
        super().__init__(vllm_config)
        # Controller reference is pushed in by ForgeMonarchExecutor; workers
        # cannot discover it themselves (see set_torchstore_controller).
        self._torchstore_controller = None
        self._torchstore_client: Optional[LocalClient] = None

    @endpoint
    def set_torchstore_controller(self, controller) -> None:
        """Store TorchStore Controller reference for weight updates.

        Workers run in a subprocess with a different _controller_controller,
        so they can't find the Controller via get_or_spawn_controller.
        The Controller reference is passed explicitly from ForgeMonarchExecutor.
        """
        self._torchstore_controller = controller
        self._torchstore_client = None  # Reset cached client

    @endpoint
    def update_weights(self, version: int) -> int:
        """Load weights directly from torchstore.

        Args:
            version: Policy version to load from torchstore

        Returns:
            Number of parameters loaded
        """
        # Worker endpoints are synchronous; run the async torchstore I/O on a
        # fresh event loop.
        return asyncio.run(self._load_from_torchstore(version))

    async def _get_torchstore_client(self) -> LocalClient:
        """Get or create a LocalClient using the passed Controller reference.

        Workers can't use ts.client() directly because they're in a subprocess
        with a different _controller_controller. Instead, we create a LocalClient
        using the Controller reference passed from ForgeMonarchExecutor.

        Raises:
            RuntimeError: If no Controller has been set yet.
        """
        if self._torchstore_client is not None:
            return self._torchstore_client

        if self._torchstore_controller is None:
            raise RuntimeError(
                "TorchStore Controller not set. "
                "ForgeMonarchExecutor must call set_torchstore_controller before weight updates."
            )

        strategy = await self._torchstore_controller.get_controller_strategy.call_one()
        self._torchstore_client = LocalClient(
            controller=self._torchstore_controller,
            strategy=strategy,
        )
        return self._torchstore_client

    async def _load_from_torchstore(self, version: int) -> int:
        """Async helper to load from torchstore using the passed Controller.

        Fetches every key under the version's prefix one at a time and feeds
        it to the model; returns the number of parameters loaded.
        """
        client = await self._get_torchstore_client()
        prefix = get_param_prefix(version)
        matching_keys = await client.keys(prefix)
        model = self.worker.model_runner.model
        loaded_count = 0
        for key in matching_keys:
            name = extract_param_name(key)
            param = await client.get(key)
            model.load_weights([(name, param.cuda())])
            # Drop the fetched copy promptly to bound peak memory while
            # iterating over many parameters.
            del param
            loaded_count += 1
        return loaded_count

    @endpoint
    def save_model_params(self):
        """Save model parameters before weight update, used for testing purposes only."""
        logger.info("[WorkerWrapper] save model parameters for testing.")
        if not hasattr(self, "_test_prev_params"):
            self._test_prev_params = {}
        for name, param in self.worker.model_runner.model.named_parameters():
            # .clone() is essential: for a parameter already on CPU, .cpu()
            # returns the SAME tensor (no copy), so without cloning the
            # "snapshot" would alias the live weights and track subsequent
            # in-place updates, defeating the before/after comparison.
            self._test_prev_params[name] = param.detach().cpu().clone()
        logger.info(
            "[WorkerWrapper] finished saving model parameters, len = %d",
            len(self._test_prev_params),
        )

    @endpoint
    def validate_model_params(self, validate_fn):
        """Validate updated model params using validate_fn.

        Args:
            validate_fn: Callable of (prev_params, model, logger) supplied by
                the test driver; its return value is forwarded to the caller.
        """
        logger.info("[WorkerWrapper] start validating model parameters.")
        if not hasattr(self, "_test_prev_params"):
            self._test_prev_params = {}
        return validate_fn(
            self._test_prev_params, self.worker.model_runner.model, logger
        )
133+
class ForgeMonarchExecutor(MonarchExecutor):
    """MonarchExecutor with TorchStore integration for weight sync.

    Extends the base MonarchExecutor to:
    - Deserialize TorchStore Controller from environment
    - Pass Controller to workers for direct weight loading
    - Use ForgeWorkerWrapper instead of base WorkerWrapper
    """

    worker_class = ForgeWorkerWrapper

    def _init_executor(self) -> None:
        """Initialize executor and deserialize TorchStore Controller."""
        super()._init_executor()

        encoded_controller = os.environ.get("VLLM_TORCHSTORE_CONTROLLER")
        if not encoded_controller:
            # Without a Controller the executor still serves inference, but
            # torchstore-backed weight sync is unavailable.
            self.torchstore_controller = None
            logger.warning(
                "[ForgeMonarchExecutor] No TorchStore Controller found in environment. "
                "Weight updates via torchstore will not work."
            )
            return

        logger.info(
            "[ForgeMonarchExecutor] Deserializing TorchStore Controller from environment..."
        )
        self.torchstore_controller = cloudpickle.loads(
            base64.b64decode(encoded_controller)
        )
        logger.info(
            f"[ForgeMonarchExecutor] TorchStore Controller deserialized: {self.torchstore_controller}"
        )
        # Hand the Controller to every worker so each can build a LocalClient.
        self.workers.set_torchstore_controller.call(self.torchstore_controller).get()

src/forge/actors/vllm/v1/generator.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from forge.data_models.completion import Completion
2222
from forge.data_models.prompt import to_prompt
2323
from monarch.actor import endpoint, this_host
24+
from torchstore.api import _controller as get_torchstore_controller
2425
from vllm.engine.arg_utils import EngineArgs
2526
from vllm.entrypoints.llm import UsageContext
2627
from vllm.outputs import RequestOutput
@@ -197,14 +198,21 @@ async def setup(self, host_mesh, worker_registry, gpu_ids: list[str]):
197198
).decode("utf-8")
198199
os.environ["VLLM_MONARCH_WORKER_REGISTRY"] = serialized_registry
199200

201+
# Serialize TorchStore Controller reference for workers to access torchstore
202+
torchstore_controller = await get_torchstore_controller()
203+
serialized_controller = base64.b64encode(
204+
cloudpickle.dumps(torchstore_controller)
205+
).decode("utf-8")
206+
os.environ["VLLM_TORCHSTORE_CONTROLLER"] = serialized_controller
207+
200208
# Force 'spawn' multiprocessing method for Monarch actors.
201209
# This follows vLLM's Ray integration pattern.
202210
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
203211

204-
# Set the executor backend to MonarchExecutor via string path
212+
# Set the executor backend to ForgeMonarchExecutor via string path
205213
# This avoids import deadlock when vLLM spawns EngineCore subprocess
206214
self.vllm_config.parallel_config.distributed_executor_backend = (
207-
"forge.actors.vllm.v1.monarch_executor.MonarchExecutor"
215+
"forge.actors.vllm.v1.forge_executor.ForgeMonarchExecutor"
208216
)
209217
from vllm.v1.executor.abstract import Executor
210218

@@ -309,6 +317,61 @@ async def shutdown(cls, actor):
309317

310318
logger.info("shutdown() complete")
311319

320+
@endpoint
321+
async def update_weights(
322+
self,
323+
version: Optional[int] = None,
324+
) -> None:
325+
"""Update weights on the generator from torchstore.
326+
327+
This method:
328+
1. Pauses generation and waits for in-flight requests to complete
329+
2. Updates weights on workers from torchstore
330+
3. Resumes generation
331+
332+
Note: This is NOT the standard vLLM weight update approach. vLLM typically
333+
uses `collective_rpc` on EngineClient, which internally routes calls to
334+
workers via the executor. However, `collective_rpc` uses msgspec/msgpack
335+
serialization which does not support arbitrary Python objects by default
336+
(only with VLLM_ALLOW_INSECURE_SERIALIZATION=1). This makes it difficult to
337+
pass complex objects like torchstore storage handles. Instead, we use a
338+
monarch-native approach where the Generator actor directly calls the worker
339+
mesh (`self.workers.update_weights`) via Monarch RPC, which uses cloudpickle
340+
and natively supports Monarch actor references for torchstore integration.
341+
342+
Args:
343+
version: Policy version to load from torchstore
344+
"""
345+
if self.llm is None:
346+
raise RuntimeError("Generator not initialized. Call setup() first.")
347+
348+
logger.info(f"Starting weight update to v{version}")
349+
350+
await self.llm.pause_generation(
351+
wait_for_inflight_requests=True, clear_cache=True
352+
)
353+
354+
try:
355+
await self.workers.update_weights.call(version)
356+
self.generator_version = version
357+
logger.info(f"Updated weights from torchstore v{version}")
358+
finally:
359+
await self.llm.resume_generation()
360+
361+
logger.info(f"Weight update complete, now v{version}")
362+
363+
@endpoint
364+
async def save_model_params(self):
365+
"""Save model parameters before weight update, used for testing purposes only."""
366+
logger.info("save model parameters for testing.")
367+
await self.workers.save_model_params.call()
368+
369+
@endpoint
370+
async def validate_model_params(self, validate_fn):
371+
"""Validate updated model params using validate_fn."""
372+
logger.info("start validating model parameters.")
373+
return await self.workers.validate_model_params.call(validate_fn)
374+
312375
def _to_completions(
313376
self, request_output: RequestOutput, prompt: str
314377
) -> list[Completion]:

0 commit comments

Comments (0)