
Commit d3eb3bf

docs: clarify RL training GPU requirements by model config (#764)

1 parent: 3b233c1

File tree

4 files changed: +8, -9 lines

README.md

Lines changed: 1 addition & 1 deletion

@@ -79,7 +79,7 @@ pixi run install

 > **Note:** We are actively working on enabling pure `uv` installation. Currently, Conda is the recommended approach. `uv` support is not fully working at the moment but is being tracked in [issue #494](https://github.com/meta-pytorch/torchforge/issues/494).

-After install, you can run the following command and should see output confirming GRPO training is running (you need a minimum 3 GPU devices):
+After install, you can run the following command and should see output confirming GRPO training is running (you need a minimum 2 GPU devices):

 ```
 python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml

apps/grpo/README.md

Lines changed: 2 additions & 2 deletions

@@ -40,12 +40,12 @@ Since each can covers 2 square meters, we need to divide the total wall area by

 ## Quick Start

-**Llama 3.1 8B** (recommended for learning, requires 5 GPUs as is, not optimized):
+**Llama 3.1 8B** (recommended for learning, requires 4 GPUs as is, not optimized):
 ```bash
 python -m apps.grpo.main --config apps/grpo/llama3_8b.yaml
 ```

-**Qwen3 1.7B** (NOTE: Qwen3 is already saturated on GSM8K, so rewards will **not** increase. Requires 3 GPUs, not optimized):
+**Qwen3 1.7B** (NOTE: Qwen3 is already saturated on GSM8K, so rewards will **not** increase. Requires 2 GPUs, not optimized):
 ```bash
 python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
 ```
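The saturation note in the diff above follows from how GRPO builds its learning signal: rewards are normalized within each group of completions sampled for the same prompt, so a task the model already solves perfectly yields near-zero advantages. A minimal stdlib-only sketch of that normalization (a simplified illustration, not torchforge's implementation; the function name is hypothetical):

```python
def group_relative_advantages(rewards, eps=1e-6):
    """GRPO-style group normalization: a_i = (r_i - mean(r)) / (std(r) + eps).

    If every completion in the group already earns the maximum reward
    (a saturated task, like Qwen3 on GSM8K), the advantages all collapse
    to ~0 and the policy gradient carries no signal.
    """
    n = len(rewards)
    mean = sum(rewards) / n
    std = (sum((r - mean) ** 2 for r in rewards) / n) ** 0.5
    return [(r - mean) / (std + eps) for r in rewards]


# Saturated group: identical rewards -> zero advantages, nothing to learn.
print(group_relative_advantages([1.0, 1.0, 1.0, 1.0]))  # [0.0, 0.0, 0.0, 0.0]

# Mixed group: correct completions are pushed up, incorrect ones down.
print(group_relative_advantages([0.0, 1.0]))
```

This is why the Qwen3 config is still useful for verifying the training loop end to end, even though the reward curve stays flat.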

docs/source/getting_started.md

Lines changed: 4 additions & 5 deletions

@@ -11,7 +11,7 @@ Before installing TorchForge, ensure your system meets the following requirement

 | **Operating System** | Linux (Fedora/Ubuntu/Debian) | MacOS and Windows not currently supported |
 | **Python** | 3.10 or higher | Python 3.11 recommended |
 | **GPU** | NVIDIA with CUDA support | AMD GPUs not currently supported |
-| **Minimum GPUs** | 2+ for SFT, 3+ for GRPO | More GPUs enable larger models |
+| **Minimum GPUs** | 2+ for SFT; 2+ for GRPO | More GPUs enable training larger models; GRPO with KL (`beta > 0`) requires a reference model and increases the GPU requirement. |
 | **CUDA** | 12.8 | Required for GPU training |
 | **RAM** | 32GB+ recommended | Depends on model size |
 | **Disk Space** | 50GB+ free | For models, datasets, and checkpoints |

@@ -150,7 +150,7 @@ hf download meta-llama/Meta-Llama-3.1-8B-Instruct --local-dir /tmp/Meta-Llama-3.

 uv run forge run --nproc_per_node 2 \
   apps/sft/main.py --config apps/sft/llama3_8b.yaml

-# Run GRPO training (requires 3+ GPUs)
+# Run GRPO training (requires 2+ GPUs)
 python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
 ```

@@ -181,16 +181,15 @@ Fine-tune Llama 3 8B on your data. **Requires: 2+ GPUs**

 ### Example 2: GRPO Training

-Train a model using reinforcement learning with GRPO. **Requires: 3+ GPUs**
+Train a model using reinforcement learning with GRPO. **Requires: 2+ GPUs**

 ```bash
 python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
 ```

 **What's Happening:**
 - GPU 0: Trainer model (being trained, powered by TorchTitan)
-- GPU 1: Reference model (frozen baseline, powered by TorchTitan)
-- GPU 2: Policy model (scoring outputs, powered by vLLM)
+- GPU 1: Policy model (scoring outputs, powered by vLLM)
 - **Monarch** orchestrates all three components
 - **TorchStore** handles weight synchronization from training to inference
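Since the requirements in this diff hinge on how many GPUs the launched process can actually see, a quick pre-flight check can save a confusing mid-launch failure. Below is a hypothetical stdlib-only helper (not part of torchforge; `visible_gpu_count` is an invented name) that interprets `CUDA_VISIBLE_DEVICES` the way CUDA does:

```python
import os


def visible_gpu_count(env=None):
    """Count GPUs a CUDA process may use, per CUDA_VISIBLE_DEVICES.

    Returns None when the variable is unset (all driver-visible GPUs
    are available); otherwise the number of device entries listed.
    """
    env = os.environ if env is None else env
    raw = env.get("CUDA_VISIBLE_DEVICES")
    if raw is None:
        return None  # unrestricted; query the driver for the real count
    return len([d for d in raw.split(",") if d.strip()])


# Example: confirm at least 2 GPUs are visible before launching GRPO.
print(visible_gpu_count({"CUDA_VISIBLE_DEVICES": "0,1"}))  # 2
```

When PyTorch is installed, `torch.cuda.device_count()` is the authoritative check; the helper above only reflects the environment-variable restriction.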

docs/source/index.md

Lines changed: 1 addition & 1 deletion

@@ -185,7 +185,7 @@ Before starting significant work, signal your intention in the issue tracker to

 [Monarch](https://meta-pytorch.org/monarch), [vLLM](https://docs.vllm.ai/en/latest/),
 and [TorchTitan](https://github.com/pytorch/torchtitan).
 * **Multi-GPU Support**: Designed for distributed training
-  with minimum 3 GPU requirement for GRPO training
+  with minimum 2 GPU requirement for GRPO training
 * **Model Support**: Includes pre-configured setups for popular models
   like Llama3 8B and Qwen3.1 7B
