|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +""" |
| 8 | +Forge-specific MonarchExecutor with TorchStore weight sync. |
| 9 | +
|
| 10 | +This module extends the upstream-compatible MonarchExecutor with TorchStore |
| 11 | +integration for weight synchronization in RL training loops. It provides: |
| 12 | +
|
| 13 | +- ForgeWorkerWrapper: Extends WorkerWrapper with TorchStore weight loading |
| 14 | +- ForgeMonarchExecutor: Extends MonarchExecutor with TorchStore Controller handling |
| 15 | +
|
| 16 | +Use this executor when you need weight updates from TorchStore (e.g., GRPO training). |
| 17 | +For inference-only workloads, use the base MonarchExecutor directly. |
| 18 | +
|
| 19 | +TODO: Add shared memory weight prefetch support (prefetch_weights_to_shm, n_fetcher_procs) |
| 20 | + as in v0 Generator for faster weight loading. |
| 21 | +""" |
| 22 | + |
| 23 | +from __future__ import annotations |
| 24 | + |
| 25 | +import asyncio |
| 26 | +import base64 |
| 27 | +import logging |
| 28 | +import os |
| 29 | +from typing import Optional |
| 30 | + |
| 31 | +import cloudpickle |
| 32 | +from forge.actors._torchstore_utils import extract_param_name, get_param_prefix |
| 33 | +from forge.actors.vllm.v1.monarch_executor import MonarchExecutor, WorkerWrapper |
| 34 | +from monarch.actor import endpoint |
| 35 | +from torchstore.client import LocalClient |
| 36 | + |
| 37 | +logger = logging.getLogger(__name__) |
| 38 | + |
| 39 | + |
| 40 | +class ForgeWorkerWrapper(WorkerWrapper): |
| 41 | + """Worker wrapper with TorchStore weight sync capabilities.""" |
| 42 | + |
| 43 | + def __init__(self, vllm_config): |
| 44 | + super().__init__(vllm_config) |
| 45 | + self._torchstore_controller = None |
| 46 | + self._torchstore_client: Optional[LocalClient] = None |
| 47 | + |
| 48 | + @endpoint |
| 49 | + def set_torchstore_controller(self, controller) -> None: |
| 50 | + """Store TorchStore Controller reference for weight updates. |
| 51 | +
|
| 52 | + Workers run in a subprocess with a different _controller_controller, |
| 53 | + so they can't find the Controller via get_or_spawn_controller. |
| 54 | + The Controller reference is passed explicitly from ForgeMonarchExecutor. |
| 55 | + """ |
| 56 | + self._torchstore_controller = controller |
| 57 | + self._torchstore_client = None # Reset cached client |
| 58 | + |
| 59 | + @endpoint |
| 60 | + def update_weights(self, version: int) -> int: |
| 61 | + """Load weights directly from torchstore. |
| 62 | +
|
| 63 | + Args: |
| 64 | + version: Policy version to load from torchstore |
| 65 | +
|
| 66 | + Returns: |
| 67 | + Number of parameters loaded |
| 68 | + """ |
| 69 | + return asyncio.run(self._load_from_torchstore(version)) |
| 70 | + |
| 71 | + async def _get_torchstore_client(self) -> LocalClient: |
| 72 | + """Get or create a LocalClient using the passed Controller reference. |
| 73 | +
|
| 74 | + Workers can't use ts.client() directly because they're in a subprocess |
| 75 | + with a different _controller_controller. Instead, we create a LocalClient |
| 76 | + using the Controller reference passed from ForgeMonarchExecutor. |
| 77 | + """ |
| 78 | + if self._torchstore_client is not None: |
| 79 | + return self._torchstore_client |
| 80 | + |
| 81 | + if self._torchstore_controller is None: |
| 82 | + raise RuntimeError( |
| 83 | + "TorchStore Controller not set. " |
| 84 | + "ForgeMonarchExecutor must call set_torchstore_controller before weight updates." |
| 85 | + ) |
| 86 | + |
| 87 | + strategy = await self._torchstore_controller.get_controller_strategy.call_one() |
| 88 | + self._torchstore_client = LocalClient( |
| 89 | + controller=self._torchstore_controller, |
| 90 | + strategy=strategy, |
| 91 | + ) |
| 92 | + return self._torchstore_client |
| 93 | + |
| 94 | + async def _load_from_torchstore(self, version: int) -> int: |
| 95 | + """Async helper to load from torchstore using the passed Controller.""" |
| 96 | + client = await self._get_torchstore_client() |
| 97 | + prefix = get_param_prefix(version) |
| 98 | + matching_keys = await client.keys(prefix) |
| 99 | + model = self.worker.model_runner.model |
| 100 | + loaded_count = 0 |
| 101 | + for key in matching_keys: |
| 102 | + name = extract_param_name(key) |
| 103 | + param = await client.get(key) |
| 104 | + model.load_weights([(name, param.cuda())]) |
| 105 | + del param |
| 106 | + loaded_count += 1 |
| 107 | + return loaded_count |
| 108 | + |
| 109 | + @endpoint |
| 110 | + def save_model_params(self): |
| 111 | + """Save model parameters before weight update, used for testing purposes only.""" |
| 112 | + logger.info("[WorkerWrapper] save model parameters for testing.") |
| 113 | + if not hasattr(self, "_test_prev_params"): |
| 114 | + self._test_prev_params = {} |
| 115 | + for name, param in self.worker.model_runner.model.named_parameters(): |
| 116 | + self._test_prev_params[name] = param.detach().cpu() |
| 117 | + logger.info( |
| 118 | + "[WorkerWrapper] finished saving model parameters, len = %d", |
| 119 | + len(self._test_prev_params), |
| 120 | + ) |
| 121 | + |
| 122 | + @endpoint |
| 123 | + def validate_model_params(self, validate_fn): |
| 124 | + """Validate updated model params using validate_fn.""" |
| 125 | + logger.info("[WorkerWrapper] start validating model parameters.") |
| 126 | + if not hasattr(self, "_test_prev_params"): |
| 127 | + self._test_prev_params = {} |
| 128 | + return validate_fn( |
| 129 | + self._test_prev_params, self.worker.model_runner.model, logger |
| 130 | + ) |
| 131 | + |
| 132 | + |
| 133 | +class ForgeMonarchExecutor(MonarchExecutor): |
| 134 | + """MonarchExecutor with TorchStore integration for weight sync. |
| 135 | +
|
| 136 | + Extends the base MonarchExecutor to: |
| 137 | + - Deserialize TorchStore Controller from environment |
| 138 | + - Pass Controller to workers for direct weight loading |
| 139 | + - Use ForgeWorkerWrapper instead of base WorkerWrapper |
| 140 | + """ |
| 141 | + |
| 142 | + worker_class = ForgeWorkerWrapper |
| 143 | + |
| 144 | + def _init_executor(self) -> None: |
| 145 | + """Initialize executor and deserialize TorchStore Controller.""" |
| 146 | + super()._init_executor() |
| 147 | + |
| 148 | + controller_str = os.environ.get("VLLM_TORCHSTORE_CONTROLLER") |
| 149 | + if controller_str: |
| 150 | + logger.info( |
| 151 | + "[ForgeMonarchExecutor] Deserializing TorchStore Controller from environment..." |
| 152 | + ) |
| 153 | + self.torchstore_controller = cloudpickle.loads( |
| 154 | + base64.b64decode(controller_str) |
| 155 | + ) |
| 156 | + logger.info( |
| 157 | + f"[ForgeMonarchExecutor] TorchStore Controller deserialized: {self.torchstore_controller}" |
| 158 | + ) |
| 159 | + self.workers.set_torchstore_controller.call( |
| 160 | + self.torchstore_controller |
| 161 | + ).get() |
| 162 | + |
| 163 | + else: |
| 164 | + self.torchstore_controller = None |
| 165 | + logger.warning( |
| 166 | + "[ForgeMonarchExecutor] No TorchStore Controller found in environment. " |
| 167 | + "Weight updates via torchstore will not work." |
| 168 | + ) |
0 commit comments