Introduce Generator v1 with AsyncLLM #709

Merged
JenniferWang merged 1 commit into main from export-D90280578 on Jan 22, 2026

Conversation


@JenniferWang JenniferWang commented Jan 17, 2026

Summary:

This diff introduces vLLM v1 integration for forge & Monarch, targeting vLLM versions > 0.13.0.

Functionality-wise, this diff implements:

  • Single-node TP (unoptimized, TCP-based proc communication)
  • Multi-node TP (same TCP mechanism)

Pending work (next diffs in the stack), focusing first on single-node TP:

  • Unix socket-based communication (instead of TCP)
  • Weight sync integration
  • Logging integration

After that, we can introduce Pipeline Parallelism:

  • Extend executor to capture stage graph (DAG-like execution pattern)

Decision 1: Integration Layer -- AsyncLLM

We integrate at the AsyncLLM layer (https://blog.vllm.ai/2025/01/27/v1-alpha-release.html), which sits higher in the stack than our v0 approach, which disassembled EngineCore and integrated at the Worker level. We pick this layer for two main reasons:

  1. Reduced maintenance cost: vLLM v1 refactored internals significantly (new EngineCore, Scheduler, KVCacheManager). Integrating at AsyncLLM isolates us from these changes -- we only need to implement the Executor interface, not patch internal scheduling or memory management.
  2. Better fit for agentic RL: The offline LLM class batches requests synchronously via llm.generate([prompts]). AsyncLLM exposes an async generator interface (async for output in llm.generate(prompt)) that supports streaming, priority scheduling, and concurrent request handling required for online RL rollouts.
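The streaming/concurrency contrast above can be sketched with plain asyncio. This is a toy stand-in, not vLLM's actual API: `generate` mimics the shape of an async generator that streams partial outputs per request, and `rollout` shows how concurrent requests interleave instead of being batched synchronously.

```python
import asyncio

# Toy stand-in for an AsyncLLM-style streaming generator.
# Names and output fields are illustrative, not vLLM's exact API.
async def generate(prompt: str, request_id: str):
    text = ""
    for token in ["Hello", " world", "!"]:
        await asyncio.sleep(0)  # yield control, as a real engine would between steps
        text += token
        yield {"request_id": request_id, "text": text, "finished": False}
    yield {"request_id": request_id, "text": text, "finished": True}

async def rollout(prompt: str, request_id: str) -> str:
    final = ""
    # The streaming interface: consume partial outputs as they arrive.
    async for output in generate(prompt, request_id):
        final = output["text"]
    return final

async def main():
    # Concurrent requests interleave on one event loop,
    # unlike synchronous llm.generate([prompts]) batching.
    return await asyncio.gather(
        rollout("p0", "req-0"),
        rollout("p1", "req-1"),
    )

print(asyncio.run(main()))  # → ['Hello world!', 'Hello world!']
```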

Note that VeRL (https://github.com/volcengine/verl/tree/main/verl) integrates with vLLM at the LLM layer via external_launcher, where vLLM assumes procs are managed by VeRL (i.e., ActorRolloutRefWorker). This, however, limits the generator to "offline-inference" mode, which does not fit the agentic RL workflow well; most importantly, "offline-inference" does not support continuous batching. Therefore, despite the much simpler implementation of the VeRL route, we choose the current one. Keep in mind, though, that we don't have hard numbers on the difference.

Decision 2: Extension Points -- Executor + WorkerWrapperBase

| Class | Base Class | Location | Purpose |
| --- | --- | --- | --- |
| MonarchExecutor | vllm.v1.executor.abstract.Executor | monarch_executor.py | Creates ProcMesh from HostMesh, spawns workers, manages collective_rpc() dispatch. |
| WorkerWrapper | vllm.v1.worker.worker_base.WorkerWrapperBase + Actor | monarch_executor.py | Dual-inheritance wrapper exposing vLLM worker methods as Monarch endpoints. |
| ForgeMonarchExecutor (next diff) | MonarchExecutor | forge_executor.py | Extends executor with TorchStore Controller handling for weight updates. |
| ForgeWorkerWrapper (next diff) | WorkerWrapper | forge_executor.py | Extends worker with TorchStore weight loading capabilities. |
| Generator | ForgeActor | generator.py | Forge-specific orchestration: provisions hosts, allocates GPUs, manages AsyncLLM. |

MonarchExecutor and WorkerWrapper are designed to be upstreamed to vLLM alongside the existing RayDistributedExecutor, enabling Monarch as a first-class distributed backend.
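The executor's role in the table can be sketched as a fan-out dispatcher. This is a minimal, self-contained sketch of the executor-owns-workers dispatch pattern; the `Worker` class, `MonarchExecutorSketch`, and method names are toy stand-ins, not the real vLLM Executor base class or Monarch ProcMesh APIs.

```python
# Hypothetical sketch of collective_rpc()-style dispatch across owned workers.
class Worker:
    def __init__(self, rank: int):
        self.rank = rank

    def execute_model(self, batch: str) -> str:
        return f"rank {self.rank} ran {batch}"

class MonarchExecutorSketch:
    def __init__(self, world_size: int):
        # In the real executor this would spawn worker actors on a ProcMesh
        # created from the caller-provided HostMesh; here we just construct them.
        self.workers = [Worker(r) for r in range(world_size)]

    def collective_rpc(self, method: str, *args):
        # Fan the named method out to every worker and gather the results,
        # analogous to the collective_rpc() contract described above.
        return [getattr(w, method)(*args) for w in self.workers]

executor = MonarchExecutorSketch(world_size=2)
print(executor.collective_rpc("execute_model", "batch-0"))
# → ['rank 0 ran batch-0', 'rank 1 ran batch-0']
```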

Decision 3: Executor-Owns-Workers Pattern

The architecture follows vLLM's Ray executor pattern where:

  • Caller (Generator) owns HostMesh: Resource allocation (hosts, GPU IDs)
  • Executor owns ProcMesh + Workers: Execution lifecycle
    ┌───────────────────────────────────────────────────────────────────────┐
    │                              Host Mesh                                │
    │                                                                       │
    │  ┌─────────────────────────────────────────────────────────────────┐  │
    │  │ Caller process                                                  │  │
    │  │                                                                 │  │
    │  │  ┌─────────────────────┐       ┌─────────────────────────────┐  │  │
    │  │  │ AsyncLLM            │       │ WorkerRegistry (actor)      │  │  │
    │  │  └─────────────────────┘       └─────────────────────────────┘  │  │
    │  │            │                                                    │  │
    │  │            │ serialize host_mesh & registry to env vars         │  │
    │  │            ▼                                                    │  │
    │  │  ┌───────────────────────────────────────────────────────────┐  │  │
    │  │  │ EngineCore subprocess                                     │  │  │
    │  │  │                                                           │  │  │
    │  │  │ MonarchExecutor                                           │  │  │
    │  │  │   ├── deserialize host_mesh                               │  │  │
    │  │  │   ├── create proc_mesh from host_mesh (owns lifecycle) ───│──│──│──┐
    │  │  │   ├── spawn worker actors on proc_mesh                    │  │  │  │
    │  │  │   └── register workers in WorkerRegistry                  │  │  │  │
    │  │  └───────────────────────────────────────────────────────────┘  │  │  │
    │  └─────────────────────────────────────────────────────────────────┘  │  │
    │                                                                       │  │
    │  ┌─────────────────────────────────────────────────────────────────┐  │  │
    │  │ GPU ProcMesh (owned by MonarchExecutor)                         │  │  │
    │  │                                                                 │  │  │
    │  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐         │  │  │
    │  │  │ Worker 0 │  │ Worker 1 │  │ Worker 2 │  │ Worker 3 │  ... ◀──│──│──┘
    │  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘         │  │
    │  │                   ◀──── NCCL (tensor parallel) ────▶            │  │
    │  └─────────────────────────────────────────────────────────────────┘  │
    └───────────────────────────────────────────────────────────────────────┘

Design: Caller owns host_mesh (resource allocation); executor owns proc_mesh + workers (execution). This mirrors vLLM's Ray executor pattern. Since we want to collocate the Generator actor with the worker host mesh, it is simpler for the caller to keep owning the host mesh.

WorkerRegistry bridges the process boundary -- MonarchExecutor (in subprocess) registers workers there, Generator queries it after AsyncLLM initialization.
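The env-var handoff across the EngineCore subprocess boundary can be sketched as follows. Everything here is illustrative: the env-var names (MONARCH_HOST_MESH, MONARCH_REGISTRY_NAME) and the `REGISTRIES` lookup are made up to model an actor reachable by name, not the real Monarch serialization.

```python
import json
import os

# Toy registry bridging the caller process and the EngineCore subprocess.
class WorkerRegistry:
    def __init__(self):
        self._workers = {}

    def register(self, rank: int, handle: str) -> None:
        self._workers[rank] = handle

    def get_all(self) -> dict:
        return dict(self._workers)

REGISTRIES = {}  # stand-in for an actor namespace resolvable by name

def caller_side() -> WorkerRegistry:
    # Caller serializes the host mesh and a registry reference into env vars
    # before the EngineCore subprocess is forked.
    registry = WorkerRegistry()
    REGISTRIES["gen-0"] = registry
    os.environ["MONARCH_REGISTRY_NAME"] = "gen-0"
    os.environ["MONARCH_HOST_MESH"] = json.dumps({"hosts": 1, "gpus_per_host": 4})
    return registry

def executor_side() -> None:
    # Inside the subprocess: deserialize the mesh, spawn workers, register them.
    mesh = json.loads(os.environ["MONARCH_HOST_MESH"])
    registry = REGISTRIES[os.environ["MONARCH_REGISTRY_NAME"]]
    for rank in range(mesh["gpus_per_host"]):
        registry.register(rank, f"worker-{rank}")

registry = caller_side()
executor_side()
print(sorted(registry.get_all()))  # → [0, 1, 2, 3]
```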

Executor Cleanup Responsibility:
Since MonarchExecutor creates proc_mesh from host_mesh, it owns the cleanup:

  1. MonarchExecutor.shutdown() destroys process groups on workers (prevents NCCL errors)
  2. Stops proc_mesh
  3. Generator.shutdown() only needs to stop generator_proc
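The cleanup ordering above can be sketched with toy classes; none of these are the real Monarch or vLLM objects, they only model the sequencing (destroy process groups first, then stop the proc mesh, then the generator proc).

```python
# Hypothetical sketch of the shutdown ordering described above.
class ProcMesh:
    def __init__(self):
        self.stopped = False

    def stop(self) -> None:
        self.stopped = True

class MonarchExecutor:
    def __init__(self):
        self.proc_mesh = ProcMesh()
        self.pg_destroyed = False

    def shutdown(self) -> None:
        # 1. Destroy process groups on workers first so NCCL tears down cleanly.
        self.pg_destroyed = True
        # 2. Only then stop the proc_mesh the executor owns.
        self.proc_mesh.stop()

class Generator:
    def __init__(self, executor: MonarchExecutor):
        self.executor = executor
        self.generator_proc_stopped = False

    def shutdown(self) -> None:
        self.executor.shutdown()
        # 3. Generator only needs to stop its own proc.
        self.generator_proc_stopped = True

gen = Generator(MonarchExecutor())
gen.shutdown()
print(gen.executor.pg_destroyed, gen.executor.proc_mesh.stopped)  # → True True
```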

Limitations

  • TP: Supported (single-node and multi-node)
  • PP: NOT supported (would require DAG-like execution pattern)
  • Shared memory cache (mm_processor_cache_type='shm') not supported
  • Symmetric memory all-reduce disabled (VLLM_ALLREDUCE_USE_SYMM_MEM=0)

Test Plan

[-] Resource / Lifecycle: pytest tests/integration_tests/test_generator_lifecycle.py -v -s
[-] Single-node TP local benchmark throughput test: python -m benchmarks.generator.throughput --config apps/grpo/qwen3_1_7b.yaml benchmark.num_requests=10 benchmark.dataset=fixed benchmark.fixed_prompt="Tell me a joke" benchmark.num_samples=5, verifying vLLM instantiation on the local host.
[-] Single-node TP MAST benchmark throughput test verifying vLLM instantiation on a remote host: https://www.internalfb.com/msl/studio/runs/mast/qwen3_1_7b_mast-eh7o6d%3APRODUCTION%3A0/logs?attempt=0&taskGroups=client%3A0&statusFilter=PENDING%2CRUNNING%2CCOMPLETE%2CFAILED%2CABANDONED%2CSTOPPING&logarithm=%7B%22after%22%3A10%2C%22before%22%3A20%7D
[-] Multi-node (TP) MAST benchmark throughput test: https://www.internalfb.com/msl/studio/runs/mast/qwen3_1_7b_multinode_test-gr8aes%3APRODUCTION%3A0/logs?attempt=0&taskGroups=client%3A0&statusFilter=PENDING%2CRUNNING%2CCOMPLETE%2CFAILED%2CABANDONED%2CSTOPPING&logarithm=%7B%22after%22%3A10%2C%22before%22%3A20%7D

Reviewed By: allenwang28

Differential Revision: D90280578

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jan 17, 2026

meta-codesync bot commented Jan 17, 2026

@JenniferWang has exported this pull request. If you are a Meta employee, you can view the originating Diff in D90280578.


@allenwang28 allenwang28 left a comment


Review automatically exported from Phabricator review in Meta.

facebook-github-bot pushed a commit that referenced this pull request Jan 17, 2026
@JenniferWang JenniferWang linked an issue Jan 17, 2026 that may be closed by this pull request
facebook-github-bot pushed a commit that referenced this pull request Jan 20, 2026
facebook-github-bot pushed a commit that referenced this pull request Jan 21, 2026
Summary:

## Summary

This diff introduces vLLM v1 integration for forge & Monarch that works for version > 0.13.0. 

Functionality wise, this diff implements:
  - Single-node TP (unoptimized, TCP-based proc communication)
  - Multi-node TP (same TCP mechanism)

Pending work (next diff stack): First focus on Single-node TP
  - Unix socket-based communication (instead of TCP)
  - Weight sync integration
  - Logging integration

After that, we can introduce Pipeline Parallelism:
  - Extend executor to capture stage graph (DAG-like execution pattern)

## Decisions 1: Integration Layer -- `AsyncLLM`

We integrate at the AsyncLLM layer (https://blog.vllm.ai/2025/01/27/v1-alpha-release.html), which sits higher in the stack compared to our v0 approach that disassembled EngineCore and integrated at the Worker level. We pick this layer for these main reasons
1. Reduced maintenance cost: vLLM v1 refactored internals significantly (new EngineCore, Scheduler, KVCacheManager). Integrating at AsyncLLM isolates us from these changes -- we only need to implement the Executor interface, not patch internal scheduling or memory management.
2. Better fit for agentic RL: The offline LLM class batches requests synchronously via `llm.generate([prompts])`. AsyncLLM exposes an async generator interface (async for output in llm.generate(prompt)) that supports streaming, priority scheduling, and concurrent request handling required for online RL rollouts.

## Decision 2: Extension Points -- Executor + WorkerWrapperBase
| **Class**              | **Base Class**                                         | **Location**            | **Purpose**                                                                          |
|------------------------|--------------------------------------------------------|-------------------------|--------------------------------------------------------------------------------------|
| MonarchExecutor        | vllm.v1.executor.abstract.Executor                     | monarch_executor.py     | Creates ProcMesh from HostMesh, spawns workers, manages collective_rpc() dispatch.  |
| WorkerWrapper          | vllm.v1.worker.worker_base.WorkerWrapperBase + Actor   | monarch_executor.py     | Dual-inheritance wrapper exposing vLLM worker methods as Monarch endpoints.         |
| ForgeMonarchExecutor (next diff)   | MonarchExecutor                                        | forge_executor.py       | Extends executor with TorchStore Controller handling for weight updates.            |
| ForgeWorkerWrapper (next diff)     | WorkerWrapper                                          | forge_executor.py       | Extends worker with TorchStore weight loading capabilities.                         |
| Generator              | ForgeActor                                             | generator.py            | Forge-specific orchestration: provisions hosts, allocates GPUs, manages AsyncLLM.   |

**`MonarchExecutor` and `WorkerWrapper` are designed to be upstreamed to vLLM alongside the existing `RayDistributedExecutor`, enabling Monarch as a first-class distributed backend.**


## Decision 3: Executor Owns Workers Lifecycle

The architecture aligns closer with vLLM's Ray executor pattern where:
- **Caller (Generator) owns HostMesh**: Resource allocation (hosts, GPU IDs)
- **Executor owns ProcMesh + Workers**: Execution lifecycle

```
    ┌───────────────────────────────────────────────────────────────────────┐
    │                              Host Mesh                                │
    │                                                                       │
    │  ┌─────────────────────────────────────────────────────────────────┐  │
    │  │ Caller process                                                  │  │
    │  │                                                                 │  │
    │  │  ┌─────────────────────┐       ┌─────────────────────────────┐  │  │
    │  │  │ AsyncLLM            │       │ WorkerRegistry (actor)      │  │  │
    │  │  └─────────────────────┘       └─────────────────────────────┘  │  │
    │  │            │                                                    │  │
    │  │            │ serialize host_mesh & registry to env vars         │  │
    │  │            ▼                                                    │  │
    │  │  ┌───────────────────────────────────────────────────────────┐  │  │
    │  │  │ EngineCore subprocess                                     │  │  │
    │  │  │                                                           │  │  │
    │  │  │ MonarchExecutor                                           │  │  │
    │  │  │   ├── deserialize host_mesh                               │  │  │
    │  │  │   ├── create proc_mesh from host_mesh (owns lifecycle) ───│──│──│──┐
    │  │  │   ├── spawn worker actors on proc_mesh                    │  │  │  │
    │  │  │   └── register workers in WorkerRegistry                  │  │  │  │
    │  │  └───────────────────────────────────────────────────────────┘  │  │  │
    │  └─────────────────────────────────────────────────────────────────┘  │  │
    │                                                                       │  │
    │  ┌─────────────────────────────────────────────────────────────────┐  │  │
    │  │ GPU ProcMesh (owned by MonarchExecutor)                         │  │  │
    │  │                                                                 │  │  │
    │  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐         │  │  │
    │  │  │ Worker 0 │  │ Worker 1 │  │ Worker 2 │  │ Worker 3 │  ... ◀──│──│──┘
    │  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘         │  │
    │  │                   ◀──── NCCL (tensor parallel) ────▶            │  │
    │  └─────────────────────────────────────────────────────────────────┘  │
    └───────────────────────────────────────────────────────────────────────┘
```

**Design**: Caller owns host_mesh (resource allocation), executor owns proc_mesh + workers (execution). This mirrors vLLM's Ray executor pattern. Since we want to collocate Generator Actor with the worker host mesh, it's easier to stick to caller owning host mesh 

**WorkerRegistry** bridges the process boundary -- MonarchExecutor (in subprocess) registers workers there, Generator queries it after AsyncLLM initialization.

**Executor Cleanup Responsibility**: 
Since MonarchExecutor creates proc_mesh from host_mesh, it owns the cleanup:
1. `MonarchExecutor.shutdown()` destroys process groups on workers (prevents NCCL errors)
2. Stops proc_mesh
3. `Generator.shutdown()` only needs to stop generator_proc


## Limitations

- **TP**: Supported (single-node and multi-node)
- **PP**: NOT supported (would require DAG-like execution pattern)
- Shared memory cache (`mm_processor_cache_type='shm'`) not supported
- Symmetric memory all-reduce disabled (`VLLM_ALLREDUCE_USE_SYMM_MEM=0`)

## Test Plan
[-] Resource / Lifecycle: `pytest tests/integration_tests/test_generator_lifecycle.py -v -s`
[-] Single-node TP local benchmark throughput test (verifies vLLM instantiation on the local host): `python -m benchmarks.generator.throughput --config apps/grpo/qwen3_1_7b.yaml benchmark.num_requests=10 benchmark.dataset=fixed benchmark.fixed_prompt="Tell me a joke" benchmark.num_samples=5`
[-] Single-node TP MAST benchmark throughput test (verifies vLLM instantiation on a remote host): https://www.internalfb.com/msl/studio/runs/mast/qwen3_1_7b_mast-eh7o6d%3APRODUCTION%3A0/logs?attempt=0&taskGroups=client%3A0&statusFilter=PENDING%2CRUNNING%2CCOMPLETE%2CFAILED%2CABANDONED%2CSTOPPING&logarithm=%7B%22after%22%3A10%2C%22before%22%3A20%7D
[-] Multi-node TP MAST benchmark throughput test: https://www.internalfb.com/msl/studio/runs/mast/qwen3_1_7b_multinode_test-gr8aes%3APRODUCTION%3A0/logs?attempt=0&taskGroups=client%3A0&statusFilter=PENDING%2CRUNNING%2CCOMPLETE%2CFAILED%2CABANDONED%2CSTOPPING&logarithm=%7B%22after%22%3A10%2C%22before%22%3A20%7D

Reviewed By: allenwang28

Differential Revision: D90280578
facebook-github-bot pushed a commit that referenced this pull request Jan 21, 2026
Summary:

## Summary

This diff introduces vLLM v1 integration for forge & Monarch that works for version > 0.13.0. 

Functionality wise, this diff implements:
  - Single-node TP (unoptimized, TCP-based proc communication)
  - Multi-node TP (same TCP mechanism)

Pending work (next diff stack): First focus on Single-node TP
  - Unix socket-based communication (instead of TCP)
  - Weight sync integration
  - Logging integration

After that, we can introduce Pipeline Parallelism:
  - Extend executor to capture stage graph (DAG-like execution pattern)

## Decisions 1: Integration Layer -- `AsyncLLM`

We integrate at the AsyncLLM layer (https://blog.vllm.ai/2025/01/27/v1-alpha-release.html), which sits higher in the stack compared to our v0 approach that disassembled EngineCore and integrated at the Worker level. We pick this layer for these main reasons
1. Reduced maintenance cost: vLLM v1 refactored internals significantly (new EngineCore, Scheduler, KVCacheManager). Integrating at AsyncLLM isolates us from these changes -- we only need to implement the Executor interface, not patch internal scheduling or memory management.
2. Better fit for agentic RL: The offline LLM class batches requests synchronously via `llm.generate([prompts])`. AsyncLLM exposes an async generator interface (async for output in llm.generate(prompt)) that supports streaming, priority scheduling, and concurrent request handling required for online RL rollouts.

## Decision 2: Extension Points -- Executor + WorkerWrapperBase
| **Class**              | **Base Class**                                         | **Location**            | **Purpose**                                                                          |
|------------------------|--------------------------------------------------------|-------------------------|--------------------------------------------------------------------------------------|
| MonarchExecutor        | vllm.v1.executor.abstract.Executor                     | monarch_executor.py     | Creates ProcMesh from HostMesh, spawns workers, manages collective_rpc() dispatch.  |
| WorkerWrapper          | vllm.v1.worker.worker_base.WorkerWrapperBase + Actor   | monarch_executor.py     | Dual-inheritance wrapper exposing vLLM worker methods as Monarch endpoints.         |
| ForgeMonarchExecutor (next diff)   | MonarchExecutor                                        | forge_executor.py       | Extends executor with TorchStore Controller handling for weight updates.            |
| ForgeWorkerWrapper (next diff)     | WorkerWrapper                                          | forge_executor.py       | Extends worker with TorchStore weight loading capabilities.                         |
| Generator              | ForgeActor                                             | generator.py            | Forge-specific orchestration: provisions hosts, allocates GPUs, manages AsyncLLM.   |

**`MonarchExecutor` and `WorkerWrapper` are designed to be upstreamed to vLLM alongside the existing `RayDistributedExecutor`, enabling Monarch as a first-class distributed backend.**


## Decision 3: Executor Owns Workers Lifecycle

The architecture aligns closer with vLLM's Ray executor pattern where:
- **Caller (Generator) owns HostMesh**: Resource allocation (hosts, GPU IDs)
- **Executor owns ProcMesh + Workers**: Execution lifecycle

```
    ┌───────────────────────────────────────────────────────────────────────┐
    │                              Host Mesh                                │
    │                                                                       │
    │  ┌─────────────────────────────────────────────────────────────────┐  │
    │  │ Caller process                                                  │  │
    │  │                                                                 │  │
    │  │  ┌─────────────────────┐       ┌─────────────────────────────┐  │  │
    │  │  │ AsyncLLM            │       │ WorkerRegistry (actor)      │  │  │
    │  │  └─────────────────────┘       └─────────────────────────────┘  │  │
    │  │            │                                                    │  │
    │  │            │ serialize host_mesh & registry to env vars         │  │
    │  │            ▼                                                    │  │
    │  │  ┌───────────────────────────────────────────────────────────┐  │  │
    │  │  │ EngineCore subprocess                                     │  │  │
    │  │  │                                                           │  │  │
    │  │  │ MonarchExecutor                                           │  │  │
    │  │  │   ├── deserialize host_mesh                               │  │  │
    │  │  │   ├── create proc_mesh from host_mesh (owns lifecycle) ───│──│──│──┐
    │  │  │   ├── spawn worker actors on proc_mesh                    │  │  │  │
    │  │  │   └── register workers in WorkerRegistry                  │  │  │  │
    │  │  └───────────────────────────────────────────────────────────┘  │  │  │
    │  └─────────────────────────────────────────────────────────────────┘  │  │
    │                                                                       │  │
    │  ┌─────────────────────────────────────────────────────────────────┐  │  │
    │  │ GPU ProcMesh (owned by MonarchExecutor)                         │  │  │
    │  │                                                                 │  │  │
    │  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐         │  │  │
    │  │  │ Worker 0 │  │ Worker 1 │  │ Worker 2 │  │ Worker 3 │  ... ◀──│──│──┘
    │  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘         │  │
    │  │                   ◀──── NCCL (tensor parallel) ────▶            │  │
    │  └─────────────────────────────────────────────────────────────────┘  │
    └───────────────────────────────────────────────────────────────────────┘
```

**Design**: Caller owns host_mesh (resource allocation), executor owns proc_mesh + workers (execution). This mirrors vLLM's Ray executor pattern. Since we want to collocate Generator Actor with the worker host mesh, it's easier to stick to caller owning host mesh 

**WorkerRegistry** bridges the process boundary -- MonarchExecutor (in subprocess) registers workers there, Generator queries it after AsyncLLM initialization.

**Executor Cleanup Responsibility**: 
Since MonarchExecutor creates proc_mesh from host_mesh, it owns the cleanup:
1. `MonarchExecutor.shutdown()` destroys process groups on workers (prevents NCCL errors)
2. Stops proc_mesh
3. `Generator.shutdown()` only needs to stop generator_proc


## Limitations

- **TP**: Supported (single-node and multi-node)
- **PP**: NOT supported (would require DAG-like execution pattern)
- Shared memory cache (`mm_processor_cache_type='shm'`) not supported
- Symmetric memory all-reduce disabled (`VLLM_ALLREDUCE_USE_SYMM_MEM=0`)

## Test Plan
[-] Resource / Lifecycle: `pytest tests/integration_tests/test_generator_lifecycle.py -v -s`
[-] Single node TP local benchmark throughput test: `python -m benchmarks.generator.throughput --config apps/grpo/qwen3_1_7b.yaml benchmark.num_requests=10 benchmark.dataset=fixed benchmark.fixed_prompt="Tell me a joke" benchmark.num_samples=5` to verify the vllm instantiation on local host.
[-] Single node TP MAST benchmark throughput test to verify vllm instantiation on remote host:  https://www.internalfb.com/msl/studio/runs/mast/qwen3_1_7b_mast-eh7o6d%3APRODUCTION%3A0/logs?attempt=0&taskGroups=client%3A0&statusFilter=PENDING%2CRUNNING%2CCOMPLETE%2CFAILED%2CABANDONED%2CSTOPPING&logarithm=%7B%22after%22%3A10%2C%22before%22%3A20%7D
[-] Multi-node (TP) MAST benchmark throughput test: https://www.internalfb.com/msl/studio/runs/mast/qwen3_1_7b_multinode_test-gr8aes%3APRODUCTION%3A0/logs?attempt=0&taskGroups=client%3A0&statusFilter=PENDING%2CRUNNING%2CCOMPLETE%2CFAILED%2CABANDONED%2CSTOPPING&logarithm=%7B%22after%22%3A10%2C%22before%22%3A20%7D

Reviewed By: allenwang28

Differential Revision: D90280578
facebook-github-bot pushed a commit that referenced this pull request Jan 21, 2026
Summary:

## Summary

This diff introduces vLLM v1 integration for forge & Monarch that works for version > 0.13.0. 

Functionality wise, this diff implements:
  - Single-node TP (unoptimized, TCP-based proc communication)
  - Multi-node TP (same TCP mechanism)

Pending work (next diff stack): First focus on Single-node TP
  - Unix socket-based communication (instead of TCP)
  - Weight sync integration
  - Logging integration

After that, we can introduce Pipeline Parallelism:
  - Extend executor to capture stage graph (DAG-like execution pattern)

## Decisions 1: Integration Layer -- `AsyncLLM`

We integrate at the AsyncLLM layer (https://blog.vllm.ai/2025/01/27/v1-alpha-release.html), which sits higher in the stack compared to our v0 approach that disassembled EngineCore and integrated at the Worker level. We pick this layer for these main reasons
1. Reduced maintenance cost: vLLM v1 refactored internals significantly (new EngineCore, Scheduler, KVCacheManager). Integrating at AsyncLLM isolates us from these changes -- we only need to implement the Executor interface, not patch internal scheduling or memory management.
2. Better fit for agentic RL: The offline LLM class batches requests synchronously via `llm.generate([prompts])`. AsyncLLM exposes an async generator interface (async for output in llm.generate(prompt)) that supports streaming, priority scheduling, and concurrent request handling required for online RL rollouts.

## Decision 2: Extension Points -- Executor + WorkerWrapperBase
| **Class**              | **Base Class**                                         | **Location**            | **Purpose**                                                                          |
|------------------------|--------------------------------------------------------|-------------------------|--------------------------------------------------------------------------------------|
| MonarchExecutor        | vllm.v1.executor.abstract.Executor                     | monarch_executor.py     | Creates ProcMesh from HostMesh, spawns workers, manages collective_rpc() dispatch.  |
| WorkerWrapper          | vllm.v1.worker.worker_base.WorkerWrapperBase + Actor   | monarch_executor.py     | Dual-inheritance wrapper exposing vLLM worker methods as Monarch endpoints.         |
| ForgeMonarchExecutor (next diff)   | MonarchExecutor                                        | forge_executor.py       | Extends executor with TorchStore Controller handling for weight updates.            |
| ForgeWorkerWrapper (next diff)     | WorkerWrapper                                          | forge_executor.py       | Extends worker with TorchStore weight loading capabilities.                         |
| Generator              | ForgeActor                                             | generator.py            | Forge-specific orchestration: provisions hosts, allocates GPUs, manages AsyncLLM.   |

**`MonarchExecutor` and `WorkerWrapper` are designed to be upstreamed to vLLM alongside the existing `RayDistributedExecutor`, enabling Monarch as a first-class distributed backend.**


## Decision 3: Executor Owns Workers Lifecycle

The architecture aligns closer with vLLM's Ray executor pattern where:
- **Caller (Generator) owns HostMesh**: Resource allocation (hosts, GPU IDs)
- **Executor owns ProcMesh + Workers**: Execution lifecycle

```
    ┌───────────────────────────────────────────────────────────────────────┐
    │                              Host Mesh                                │
    │                                                                       │
    │  ┌─────────────────────────────────────────────────────────────────┐  │
    │  │ Caller process                                                  │  │
    │  │                                                                 │  │
    │  │  ┌─────────────────────┐       ┌─────────────────────────────┐  │  │
    │  │  │ AsyncLLM            │       │ WorkerRegistry (actor)      │  │  │
    │  │  └─────────────────────┘       └─────────────────────────────┘  │  │
    │  │            │                                                    │  │
    │  │            │ serialize host_mesh & registry to env vars         │  │
    │  │            ▼                                                    │  │
    │  │  ┌───────────────────────────────────────────────────────────┐  │  │
    │  │  │ EngineCore subprocess                                     │  │  │
    │  │  │                                                           │  │  │
    │  │  │ MonarchExecutor                                           │  │  │
    │  │  │   ├── deserialize host_mesh                               │  │  │
    │  │  │   ├── create proc_mesh from host_mesh (owns lifecycle) ───│──│──│──┐
    │  │  │   ├── spawn worker actors on proc_mesh                    │  │  │  │
    │  │  │   └── register workers in WorkerRegistry                  │  │  │  │
    │  │  └───────────────────────────────────────────────────────────┘  │  │  │
    │  └─────────────────────────────────────────────────────────────────┘  │  │
    │                                                                       │  │
    │  ┌─────────────────────────────────────────────────────────────────┐  │  │
    │  │ GPU ProcMesh (owned by MonarchExecutor)                         │  │  │
    │  │                                                                 │  │  │
    │  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐         │  │  │
    │  │  │ Worker 0 │  │ Worker 1 │  │ Worker 2 │  │ Worker 3 │  ... ◀──│──│──┘
    │  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘         │  │
    │  │                   ◀──── NCCL (tensor parallel) ────▶            │  │
    │  └─────────────────────────────────────────────────────────────────┘  │
    └───────────────────────────────────────────────────────────────────────┘
```

**Design**: Caller owns host_mesh (resource allocation), executor owns proc_mesh + workers (execution). This mirrors vLLM's Ray executor pattern. Since we want to collocate Generator Actor with the worker host mesh, it's easier to stick to caller owning host mesh 

**WorkerRegistry** bridges the process boundary -- MonarchExecutor (in subprocess) registers workers there, Generator queries it after AsyncLLM initialization.

**Executor Cleanup Responsibility**: 
Since MonarchExecutor creates proc_mesh from host_mesh, it owns the cleanup:
1. `MonarchExecutor.shutdown()` destroys process groups on workers (prevents NCCL errors)
2. Stops proc_mesh
3. `Generator.shutdown()` only needs to stop generator_proc


## Limitations

- **TP**: Supported (single-node and multi-node)
- **PP**: NOT supported (would require DAG-like execution pattern)
- Shared memory cache (`mm_processor_cache_type='shm'`) not supported
- Symmetric memory all-reduce disabled (`VLLM_ALLREDUCE_USE_SYMM_MEM=0`)

## Test Plan
[-] Resource / Lifecycle: `pytest tests/integration_tests/test_generator_lifecycle.py -v -s`
[-] Single node TP local benchmark throughput test: `python -m benchmarks.generator.throughput --config apps/grpo/qwen3_1_7b.yaml benchmark.num_requests=10 benchmark.dataset=fixed benchmark.fixed_prompt="Tell me a joke" benchmark.num_samples=5` to verify the vllm instantiation on local host.
[-] Single node TP MAST benchmark throughput test to verify vllm instantiation on remote host:  https://www.internalfb.com/msl/studio/runs/mast/qwen3_1_7b_mast-eh7o6d%3APRODUCTION%3A0/logs?attempt=0&taskGroups=client%3A0&statusFilter=PENDING%2CRUNNING%2CCOMPLETE%2CFAILED%2CABANDONED%2CSTOPPING&logarithm=%7B%22after%22%3A10%2C%22before%22%3A20%7D
[-] Multi-node (TP) MAST benchmark throughput test: https://www.internalfb.com/msl/studio/runs/mast/qwen3_1_7b_multinode_test-gr8aes%3APRODUCTION%3A0/logs?attempt=0&taskGroups=client%3A0&statusFilter=PENDING%2CRUNNING%2CCOMPLETE%2CFAILED%2CABANDONED%2CSTOPPING&logarithm=%7B%22after%22%3A10%2C%22before%22%3A20%7D

Reviewed By: allenwang28

Differential Revision: D90280578
facebook-github-bot pushed a commit that referenced this pull request Jan 21, 2026
Summary:

## Summary

This diff introduces vLLM v1 integration for forge & Monarch that works for version > 0.13.0. 

Functionality wise, this diff implements:
  - Single-node TP (unoptimized, TCP-based proc communication)
  - Multi-node TP (same TCP mechanism)

Pending work (next diff stack): First focus on Single-node TP
  - Unix socket-based communication (instead of TCP)
  - Weight sync integration
  - Logging integration

After that, we can introduce Pipeline Parallelism:
  - Extend executor to capture stage graph (DAG-like execution pattern)

## Decisions 1: Integration Layer -- `AsyncLLM`

We integrate at the AsyncLLM layer (https://blog.vllm.ai/2025/01/27/v1-alpha-release.html), which sits higher in the stack compared to our v0 approach that disassembled EngineCore and integrated at the Worker level. We pick this layer for these main reasons
1. Reduced maintenance cost: vLLM v1 refactored internals significantly (new EngineCore, Scheduler, KVCacheManager). Integrating at AsyncLLM isolates us from these changes -- we only need to implement the Executor interface, not patch internal scheduling or memory management.
2. Better fit for agentic RL: The offline LLM class batches requests synchronously via `llm.generate([prompts])`. AsyncLLM exposes an async generator interface (async for output in llm.generate(prompt)) that supports streaming, priority scheduling, and concurrent request handling required for online RL rollouts.

## Decision 2: Extension Points -- Executor + WorkerWrapperBase
| **Class**              | **Base Class**                                         | **Location**            | **Purpose**                                                                          |
|------------------------|--------------------------------------------------------|-------------------------|--------------------------------------------------------------------------------------|
| MonarchExecutor        | vllm.v1.executor.abstract.Executor                     | monarch_executor.py     | Creates ProcMesh from HostMesh, spawns workers, manages collective_rpc() dispatch.  |
| WorkerWrapper          | vllm.v1.worker.worker_base.WorkerWrapperBase + Actor   | monarch_executor.py     | Dual-inheritance wrapper exposing vLLM worker methods as Monarch endpoints.         |
| ForgeMonarchExecutor (next diff)   | MonarchExecutor                                        | forge_executor.py       | Extends executor with TorchStore Controller handling for weight updates.            |
| ForgeWorkerWrapper (next diff)     | WorkerWrapper                                          | forge_executor.py       | Extends worker with TorchStore weight loading capabilities.                         |
| Generator              | ForgeActor                                             | generator.py            | Forge-specific orchestration: provisions hosts, allocates GPUs, manages AsyncLLM.   |

**`MonarchExecutor` and `WorkerWrapper` are designed to be upstreamed to vLLM alongside the existing `RayDistributedExecutor`, enabling Monarch as a first-class distributed backend.**


## Decision 3: Executor Owns Workers Lifecycle

The architecture aligns closer with vLLM's Ray executor pattern where:
- **Caller (Generator) owns HostMesh**: Resource allocation (hosts, GPU IDs)
- **Executor owns ProcMesh + Workers**: Execution lifecycle

```
    ┌───────────────────────────────────────────────────────────────────────┐
    │                              Host Mesh                                │
    │                                                                       │
    │  ┌─────────────────────────────────────────────────────────────────┐  │
    │  │ Caller process                                                  │  │
    │  │                                                                 │  │
    │  │  ┌─────────────────────┐       ┌─────────────────────────────┐  │  │
    │  │  │ AsyncLLM            │       │ WorkerRegistry (actor)      │  │  │
    │  │  └─────────────────────┘       └─────────────────────────────┘  │  │
    │  │            │                                                    │  │
    │  │            │ serialize host_mesh & registry to env vars         │  │
    │  │            ▼                                                    │  │
    │  │  ┌───────────────────────────────────────────────────────────┐  │  │
    │  │  │ EngineCore subprocess                                     │  │  │
    │  │  │                                                           │  │  │
    │  │  │ MonarchExecutor                                           │  │  │
    │  │  │   ├── deserialize host_mesh                               │  │  │
    │  │  │   ├── create proc_mesh from host_mesh (owns lifecycle) ───│──│──│──┐
    │  │  │   ├── spawn worker actors on proc_mesh                    │  │  │  │
    │  │  │   └── register workers in WorkerRegistry                  │  │  │  │
    │  │  └───────────────────────────────────────────────────────────┘  │  │  │
    │  └─────────────────────────────────────────────────────────────────┘  │  │
    │                                                                       │  │
    │  ┌─────────────────────────────────────────────────────────────────┐  │  │
    │  │ GPU ProcMesh (owned by MonarchExecutor)                         │  │  │
    │  │                                                                 │  │  │
    │  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐         │  │  │
    │  │  │ Worker 0 │  │ Worker 1 │  │ Worker 2 │  │ Worker 3 │  ... ◀──│──│──┘
    │  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘         │  │
    │  │                   ◀──── NCCL (tensor parallel) ────▶            │  │
    │  └─────────────────────────────────────────────────────────────────┘  │
    └───────────────────────────────────────────────────────────────────────┘
```

**Design**: Caller owns host_mesh (resource allocation), executor owns proc_mesh + workers (execution). This mirrors vLLM's Ray executor pattern. Since we want to collocate Generator Actor with the worker host mesh, it's easier to stick to caller owning host mesh 

**WorkerRegistry** bridges the process boundary -- MonarchExecutor (in subprocess) registers workers there, Generator queries it after AsyncLLM initialization.

**Executor Cleanup Responsibility**: 
Since MonarchExecutor creates proc_mesh from host_mesh, it owns the cleanup:
1. `MonarchExecutor.shutdown()` destroys process groups on workers (prevents NCCL errors)
2. Stops proc_mesh
3. `Generator.shutdown()` only needs to stop generator_proc


## Limitations

- **TP**: Supported (single-node and multi-node)
- **PP**: NOT supported (would require DAG-like execution pattern)
- Shared memory cache (`mm_processor_cache_type='shm'`) not supported
- Symmetric memory all-reduce disabled (`VLLM_ALLREDUCE_USE_SYMM_MEM=0`)

## Test Plan
[-] Resource / Lifecycle: `pytest tests/integration_tests/test_generator_lifecycle.py -v -s`
[-] Single node TP local benchmark throughput test: `python -m benchmarks.generator.throughput --config apps/grpo/qwen3_1_7b.yaml benchmark.num_requests=10 benchmark.dataset=fixed benchmark.fixed_prompt="Tell me a joke" benchmark.num_samples=5` to verify the vllm instantiation on local host.
[-] Single node TP MAST benchmark throughput test to verify vllm instantiation on remote host:  https://www.internalfb.com/msl/studio/runs/mast/qwen3_1_7b_mast-eh7o6d%3APRODUCTION%3A0/logs?attempt=0&taskGroups=client%3A0&statusFilter=PENDING%2CRUNNING%2CCOMPLETE%2CFAILED%2CABANDONED%2CSTOPPING&logarithm=%7B%22after%22%3A10%2C%22before%22%3A20%7D
[-] Multi-node (TP) MAST benchmark throughput test: https://www.internalfb.com/msl/studio/runs/mast/qwen3_1_7b_multinode_test-gr8aes%3APRODUCTION%3A0/logs?attempt=0&taskGroups=client%3A0&statusFilter=PENDING%2CRUNNING%2CCOMPLETE%2CFAILED%2CABANDONED%2CSTOPPING&logarithm=%7B%22after%22%3A10%2C%22before%22%3A20%7D

Reviewed By: allenwang28

Differential Revision: D90280578
facebook-github-bot pushed a commit that referenced this pull request Jan 22, 2026
Summary:

## Summary

This diff introduces vLLM v1 integration for forge & Monarch that works for version > 0.13.0. 

Functionality wise, this diff implements:
  - Single-node TP (unoptimized, TCP-based proc communication)
  - Multi-node TP (same TCP mechanism)

Pending work (next diff stack): First focus on Single-node TP
  - Unix socket-based communication (instead of TCP)
  - Weight sync integration
  - Logging integration

After that, we can introduce Pipeline Parallelism:
  - Extend executor to capture stage graph (DAG-like execution pattern)

## Decisions 1: Integration Layer -- `AsyncLLM`

We integrate at the AsyncLLM layer (https://blog.vllm.ai/2025/01/27/v1-alpha-release.html), which sits higher in the stack compared to our v0 approach that disassembled EngineCore and integrated at the Worker level. We pick this layer for these main reasons
1. Reduced maintenance cost: vLLM v1 refactored internals significantly (new EngineCore, Scheduler, KVCacheManager). Integrating at AsyncLLM isolates us from these changes -- we only need to implement the Executor interface, not patch internal scheduling or memory management.
2. Better fit for agentic RL: The offline LLM class batches requests synchronously via `llm.generate([prompts])`. AsyncLLM exposes an async generator interface (async for output in llm.generate(prompt)) that supports streaming, priority scheduling, and concurrent request handling required for online RL rollouts.

## Decision 2: Extension Points -- Executor + WorkerWrapperBase
| **Class**              | **Base Class**                                         | **Location**            | **Purpose**                                                                          |
|------------------------|--------------------------------------------------------|-------------------------|--------------------------------------------------------------------------------------|
| MonarchExecutor        | vllm.v1.executor.abstract.Executor                     | monarch_executor.py     | Creates ProcMesh from HostMesh, spawns workers, manages collective_rpc() dispatch.  |
| WorkerWrapper          | vllm.v1.worker.worker_base.WorkerWrapperBase + Actor   | monarch_executor.py     | Dual-inheritance wrapper exposing vLLM worker methods as Monarch endpoints.         |
| ForgeMonarchExecutor (next diff)   | MonarchExecutor                                        | forge_executor.py       | Extends executor with TorchStore Controller handling for weight updates.            |
| ForgeWorkerWrapper (next diff)     | WorkerWrapper                                          | forge_executor.py       | Extends worker with TorchStore weight loading capabilities.                         |
| Generator              | ForgeActor                                             | generator.py            | Forge-specific orchestration: provisions hosts, allocates GPUs, manages AsyncLLM.   |

**`MonarchExecutor` and `WorkerWrapper` are designed to be upstreamed to vLLM alongside the existing `RayDistributedExecutor`, enabling Monarch as a first-class distributed backend.**


## Decision 3: Executor Owns Workers Lifecycle

The architecture aligns closer with vLLM's Ray executor pattern where:
- **Caller (Generator) owns HostMesh**: Resource allocation (hosts, GPU IDs)
- **Executor owns ProcMesh + Workers**: Execution lifecycle

```
    ┌───────────────────────────────────────────────────────────────────────┐
    │                              Host Mesh                                │
    │                                                                       │
    │  ┌─────────────────────────────────────────────────────────────────┐  │
    │  │ Caller process                                                  │  │
    │  │                                                                 │  │
    │  │  ┌─────────────────────┐       ┌─────────────────────────────┐  │  │
    │  │  │ AsyncLLM            │       │ WorkerRegistry (actor)      │  │  │
    │  │  └─────────────────────┘       └─────────────────────────────┘  │  │
    │  │            │                                                    │  │
    │  │            │ serialize host_mesh & registry to env vars         │  │
    │  │            ▼                                                    │  │
    │  │  ┌───────────────────────────────────────────────────────────┐  │  │
    │  │  │ EngineCore subprocess                                     │  │  │
    │  │  │                                                           │  │  │
    │  │  │ MonarchExecutor                                           │  │  │
    │  │  │   ├── deserialize host_mesh                               │  │  │
    │  │  │   ├── create proc_mesh from host_mesh (owns lifecycle) ───│──│──│──┐
    │  │  │   ├── spawn worker actors on proc_mesh                    │  │  │  │
    │  │  │   └── register workers in WorkerRegistry                  │  │  │  │
    │  │  └───────────────────────────────────────────────────────────┘  │  │  │
    │  └─────────────────────────────────────────────────────────────────┘  │  │
    │                                                                       │  │
    │  ┌─────────────────────────────────────────────────────────────────┐  │  │
    │  │ GPU ProcMesh (owned by MonarchExecutor)                         │  │  │
    │  │                                                                 │  │  │
    │  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐         │  │  │
    │  │  │ Worker 0 │  │ Worker 1 │  │ Worker 2 │  │ Worker 3 │  ... ◀──│──│──┘
    │  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘         │  │
    │  │                   ◀──── NCCL (tensor parallel) ────▶            │  │
    │  └─────────────────────────────────────────────────────────────────┘  │
    └───────────────────────────────────────────────────────────────────────┘
```

**Design**: Caller owns host_mesh (resource allocation), executor owns proc_mesh + workers (execution). This mirrors vLLM's Ray executor pattern. Since we want to collocate Generator Actor with the worker host mesh, it's easier to stick to caller owning host mesh 

**WorkerRegistry** bridges the process boundary -- MonarchExecutor (in subprocess) registers workers there, Generator queries it after AsyncLLM initialization.

**Executor Cleanup Responsibility**:
Since MonarchExecutor creates the proc_mesh from the host_mesh, it owns the cleanup:
1. `MonarchExecutor.shutdown()` destroys process groups on workers (prevents NCCL errors)
2. `MonarchExecutor.shutdown()` then stops the proc_mesh
3. `Generator.shutdown()` only needs to stop generator_proc
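The cleanup ordering above can be sketched with stub classes. Treat this as a sketch of the ordering only: in reality the process-group teardown fans out over Monarch endpoints (ultimately hitting `torch.distributed.destroy_process_group` on each worker), and AsyncLLM's shutdown path is what invokes the executor; it is modeled as a direct call here for brevity.

```python
teardown_log = []  # records the order of teardown steps


class StubWorker:
    def destroy_process_group(self):
        # Step 1: tear down NCCL communicators before killing processes;
        # otherwise surviving ranks hit NCCL errors on the dead peers.
        teardown_log.append("destroy_pg")


class MonarchExecutor:
    """Owns proc_mesh + workers, so it owns their cleanup."""

    def __init__(self, workers):
        self.workers = workers

    def shutdown(self):
        for w in self.workers:  # collective_rpc-style fanout to all ranks
            w.destroy_process_group()
        teardown_log.append("stop_proc_mesh")  # step 2: executor stops its proc_mesh


class Generator:
    def __init__(self, executor):
        self.executor = executor

    def shutdown(self):
        # Modeled as a direct call; the real path goes through AsyncLLM shutdown.
        self.executor.shutdown()
        teardown_log.append("stop_generator_proc")  # step 3: only the generator proc


Generator(MonarchExecutor([StubWorker(), StubWorker()])).shutdown()
```

The key invariant is that process groups die before the proc_mesh does, and the generator proc goes last.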


## Limitations

- **TP**: Supported (single-node and multi-node)
- **PP**: NOT supported (would require DAG-like execution pattern)
- Shared memory cache (`mm_processor_cache_type='shm'`) not supported
- Symmetric memory all-reduce disabled (`VLLM_ALLREDUCE_USE_SYMM_MEM=0`)

## Test Plan
[-] Resource / Lifecycle: `pytest tests/integration_tests/test_generator_lifecycle.py -v -s`
[-] Single-node TP local benchmark throughput test to verify the vLLM instantiation on the local host: `python -m benchmarks.generator.throughput --config apps/grpo/qwen3_1_7b.yaml benchmark.num_requests=10 benchmark.dataset=fixed benchmark.fixed_prompt="Tell me a joke" benchmark.num_samples=5`
[-] Single-node TP MAST benchmark throughput test to verify the vLLM instantiation on a remote host: https://www.internalfb.com/msl/studio/runs/mast/qwen3_1_7b_mast-eh7o6d%3APRODUCTION%3A0/logs?attempt=0&taskGroups=client%3A0&statusFilter=PENDING%2CRUNNING%2CCOMPLETE%2CFAILED%2CABANDONED%2CSTOPPING&logarithm=%7B%22after%22%3A10%2C%22before%22%3A20%7D
[-] Multi-node TP MAST benchmark throughput test: https://www.internalfb.com/msl/studio/runs/mast/qwen3_1_7b_multinode_test-gr8aes%3APRODUCTION%3A0/logs?attempt=0&taskGroups=client%3A0&statusFilter=PENDING%2CRUNNING%2CCOMPLETE%2CFAILED%2CABANDONED%2CSTOPPING&logarithm=%7B%22after%22%3A10%2C%22before%22%3A20%7D

Reviewed By: allenwang28

Differential Revision: D90280578
@JenniferWang JenniferWang merged commit 759ab71 into main Jan 22, 2026
8 checks passed
Closes: [vLLM v0.13] Re-architect forge's integration with vLLM (generator.py)
