Commit 0754eff

add llama 8b example (#623)

Authored-by: Felipe Mello (felipemello1)
Co-authored-by: Felipe Mello <felipemello@fb.com>
1 parent 700b2f5 commit 0754eff

File tree: 4 files changed, +219 -21 lines

apps/grpo/README.md

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@

Last updated: 2025-12-05

# GRPO on GSM8K

Training with GRPO (Grouped Relative Policy Optimization) on GSM8K grade-school math word problems.
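The "grouped" part is what replaces a learned value baseline: each prompt is sampled `group_size` times, and each completion's advantage is its reward normalized against the other completions in the same group. The app's actual `ComputeAdvantages` actor is not reproduced here; a minimal sketch of the idea (function name and epsilon value are illustrative assumptions) looks like:

```python
from statistics import mean, stdev

def group_relative_advantages(rewards, eps=1e-4):
    """Normalize each reward against its own group:
    advantage = (r - mean(group)) / (std(group) + eps)."""
    mu = mean(rewards)
    sd = stdev(rewards) if len(rewards) > 1 else 0.0
    return [(r - mu) / (sd + eps) for r in rewards]

# One prompt sampled group_size=4 times, with per-completion rewards:
advs = group_relative_advantages([1.0, 0.0, 1.0, 0.0])
```

Correct completions get positive advantages and incorrect ones negative, so the policy gradient pushes toward the better samples within each group without needing a critic.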

## Dataset

GSM8K consists of grade-school math word problems that require multi-step arithmetic reasoning. Models generate solutions with chain-of-thought reasoning between `<think>` tags and provide final answers between `<answer>` tags.

**Example Input:**
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

Put all your scratchpad work between <think> and </think> tags.
Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.<|eot_id|><|start_header_id|>user<|end_header_id|>

Lucille is painting her room. Two of her walls are 3 meters wide and 2 meters tall. The third wall is 5 meters wide and 2 meters tall. The final wall is 4 meters wide and 2 meters tall. If each can of paint covers 2 square meters, how many cans of paint does Lucille need?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
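The special tokens in the example come from the Llama 3.1 instruct chat template; in practice the tokenizer renders them (e.g. via `apply_chat_template`), but the format itself can be sketched directly. This is a simplified rendering for illustration, not the full template (it omits the knowledge-cutoff and date headers):

```python
def to_llama31_prompt(messages: list[dict]) -> str:
    """Render chat messages in the Llama 3.1 instruct format shown above
    (simplified: no knowledge-cutoff or date headers)."""
    parts = ["<|begin_of_text|>"]
    for msg in messages:
        parts.append(
            f"<|start_header_id|>{msg['role']}<|end_header_id|>\n\n{msg['content']}<|eot_id|>"
        )
    # Open the assistant turn so the model generates the completion.
    parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(parts)

prompt = to_llama31_prompt(
    [
        {"role": "system", "content": "Put all your scratchpad work between <think> and </think> tags."},
        {"role": "user", "content": "What is 12 + 5?"},
    ]
)
```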

**Example Output:**
```
<think>
First, let's calculate the area of each wall:

- Two walls of 3 meters x 2 meters = 3 x 2 = 6 square meters per wall. Since there are two of these, 6 * 2 = 12 square meters.
- The third wall is 5 meters x 2 meters = 10 square meters.
- The final wall is 4 meters x 2 meters = 8 square meters.
Total wall area = 12 + 10 + 8 = 30 square meters.

Since each can covers 2 square meters, we need to divide the total wall area by the area one can covers: 30 / 2 = 15.

</think>
<answer>15</answer>
```
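Scoring hinges on pulling the number out of the `<answer>` tags and comparing it to the GSM8K target. The shipped `MathReward` is not reproduced here; a minimal stand-in (the function name and 1.0/0.0 scoring are assumptions, not forge's API) could look like:

```python
import re

def score_answer(completion: str, target: str) -> float:
    """Extract the <answer>...</answer> span and compare it numerically
    to the target. Returns 1.0 on a match, 0.0 otherwise."""
    m = re.search(r"<answer>(.*?)</answer>", completion, re.DOTALL)
    if m is None:
        return 0.0  # no tags: unscored, as the system prompt warns
    try:
        return 1.0 if float(m.group(1).strip()) == float(target) else 0.0
    except ValueError:
        return 0.0  # non-numeric answer

reward = score_answer("<think>30 / 2 = 15.</think>\n<answer>15</answer>", "15")
```

Comparing as floats rather than strings tolerates surface differences like `15` vs `15.0`.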

## Quick Start

**Llama 3.1 8B** (recommended for learning; as configured it requires 5 GPUs: 2 for the policy, 2 for the trainer, and 1 for the reference model; not optimized):
```bash
python -m apps.grpo.main --config apps/grpo/llama3_8b.yaml
```

**Qwen3 1.7B** (NOTE: Qwen3 is already saturated on GSM8K, so rewards will **not** increase; requires 3 GPUs, not optimized):
```bash
python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
```

## Expected Results

For **Llama 3.1 8B**, training rewards should rise above 0.8 within the first few steps as the model learns the task.

![Llama 3.1 8B Training Rewards](wandb_llama8b.png)

## Configurations

- `llama3_8b.yaml` - Meta Llama 3.1 8B Instruct
- `qwen3_1_7b.yaml` - Qwen3 1.7B
- `qwen3_8b.yaml` - Qwen3 8B
- `qwen3_32b.yaml` - Qwen3 32B

apps/grpo/llama3_8b.yaml

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@

```yaml
# Grouped Relative Policy Optimization (GRPO)
# >>> python -m apps.grpo.main --config apps/grpo/llama3_8b.yaml

# Global configuration
group_size: 4
local_batch_size: 4 # per-device batch size
max_req_tokens: 1024
max_res_tokens: 2048
model: "meta-llama/Meta-Llama-3.1-8B-Instruct"
off_by_n: 1 # Off by one by default

# Observability configuration
metric_logging:
  wandb:
    project: grpo-training
    group: grpo_exp_${oc.env:USER}
    logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
  console:
    logging_mode: global_reduce

# Dataset configuration
dataset:
  path: "openai/gsm8k"
  revision: "main"
  data_split: "train"
  streaming: true
  model: ${model}

# Policy configuration
policy:
  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
    model: ${model}
    tensor_parallel_size: 2
    pipeline_parallel_size: 1
    enforce_eager: false
  sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
    n: ${group_size}
    max_tokens: ${max_res_tokens}
    temperature: 1.0
    top_p: 1.0

# Trainer configuration
trainer:
  model:
    name: llama3
    flavor: 8B
    hf_assets_path: hf://${model}
  optimizer:
    name: AdamW
    lr: 1e-5
    eps: 1e-8
  lr_scheduler:
    warmup_steps: 1
  training:
    local_batch_size: ${local_batch_size}
    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
    max_norm: 1.0
    steps: 1000000
    dtype: bfloat16
    gc_freq: 1
  compile:
    enable: false
  parallelism:
    data_parallel_replicate_degree: 1
    data_parallel_shard_degree: -1
    tensor_parallel_degree: 1
    pipeline_parallel_degree: 1
    context_parallel_degree: 1
    expert_parallel_degree: 1
    disable_loss_parallel: true
  checkpoint:
    enable: true
    folder: ./checkpoint # The folder to save checkpoints to.
    initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists.
    initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
    last_save_in_hf: true
    interval: 500
    async_mode: "disabled"
  activation_checkpoint:
    mode: selective
    selective_ac_option: op

# Replay buffer configuration
replay_buffer:
  batch_size: ${local_batch_size}
  max_policy_age: ${off_by_n}
  # This should match the dp_size of TorchTitan.
  # Here it's set explicitly to 2, because we've set
  # 2 GPUs for the trainer and we're using full FSDP.
  dp_size: 2

# Reference model configuration
ref_model:
  model:
    name: llama3
    flavor: 8B
    hf_assets_path: hf://${model}
  training:
    seq_len: ${trainer.training.seq_len}
    dtype: bfloat16
    gc_freq: 1
  compile:
    enable: false
  parallelism:
    data_parallel_replicate_degree: 1
    data_parallel_shard_degree: 1
    tensor_parallel_degree: 1
    pipeline_parallel_degree: 1
    context_parallel_degree: 1
    expert_parallel_degree: 1
  checkpoint:
    initial_load_path: hf://${model}
    initial_load_in_hf: true

# All resource allocations
services:
  policy:
    procs: ${policy.engine_args.tensor_parallel_size}
    num_replicas: 1
    with_gpus: true
    mesh_name: policy
  ref_model:
    procs: 1
    num_replicas: 1
    with_gpus: true
    mesh_name: ref_model
  reward_actor:
    procs: 1
    num_replicas: 1
    with_gpus: false
    mesh_name: reward_actor

actors:
  dataset:
    procs: 1
    with_gpus: false
    mesh_name: dataset
  trainer:
    procs: 2
    with_gpus: true
    mesh_name: trainer
  replay_buffer:
    procs: 1
    with_gpus: false
    mesh_name: replay_buffer
  compute_advantages:
    procs: 1
    with_gpus: false
    mesh_name: compute_advantages
```

apps/grpo/main.py

Lines changed: 6 additions & 21 deletions

```diff
@@ -27,7 +27,7 @@
 from forge.actors.trainer import TitanTrainer
 from forge.controller.actor import ForgeActor
 from forge.controller.provisioner import init_provisioner, shutdown
-from forge.data.rewards import LanguageReward, MathReward, ThinkingReward
+from forge.data.rewards import MathReward, ThinkingReward
 from forge.data_models.completion import Completion
 from forge.observability.metric_actors import get_or_create_metric_logger
 from forge.observability.metrics import record_metric, Reduce
@@ -274,15 +274,10 @@ async def setup(self):
         self._epoch = 0

 def gsm8k_transform(sample):
-    system_prompt = """You are a helpful AI assistant that solves math problems.
-
-Please show your reasoning inside <思考></思考> tags, then provide your final numerical answer inside <answer></answer> tags.
-
-Example:
-Question: What is 12 + 5?
-<思考>12と5を足します。12 + 5 = 17です。</思考>
-<answer>17</answer>
-"""
+    system_prompt = """
+Put all your scratchpad work between <think> and </think> tags.
+Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.
+"""
     request: str = sample["question"]
     as_chat = [
         {"role": "system", "content": system_prompt},
@@ -409,17 +404,7 @@ async def main(cfg: DictConfig):
         ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),
         ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
         RewardActor.options(**cfg.services.reward_actor).as_service(
-            reward_functions=[
-                MathReward(),
-                ThinkingReward(tag="思考"),  # Use Japanese tag
-                LanguageReward(
-                    target_language="ja",
-                    tag="思考",
-                    match_reward=2.0,
-                    debug=False,  # set to true for verbose logging
-                    debug_sample_rate=0.1,
-                ),  # Japanese language reward with debug
-            ]
+            reward_functions=[MathReward(), ThinkingReward()]
         ),
     )
```
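After this change the actor scores each completion with two criteria: `MathReward` for answer correctness and `ThinkingReward` for using the `<think>` scratchpad. How `RewardActor` combines the list is not shown in this diff; a simple composition sketch (plain averaging and the stand-in reward functions below are assumptions, not forge's implementations) might look like:

```python
import re

# Stand-ins for forge.data.rewards.MathReward / ThinkingReward (assumed behavior).
def math_reward(completion: str, target: str) -> float:
    """1.0 if the <answer> span matches the target, else 0.0."""
    m = re.search(r"<answer>(.*?)</answer>", completion, re.DOTALL)
    return 1.0 if m and m.group(1).strip() == target else 0.0

def thinking_reward(completion: str, target: str) -> float:
    """1.0 for any non-empty <think>...</think> scratchpad, else 0.0."""
    m = re.search(r"<think>(.*?)</think>", completion, re.DOTALL)
    return 1.0 if m and m.group(1).strip() else 0.0

def evaluate(reward_functions, completion, target):
    """Average the per-function scores into a single scalar reward."""
    scores = [fn(completion, target) for fn in reward_functions]
    return sum(scores) / len(scores)

reward = evaluate(
    [math_reward, thinking_reward],
    "<think>30 / 2 = 15</think>\n<answer>15</answer>",
    "15",
)
```

A correct answer without a scratchpad earns only partial reward, which nudges the model toward the `<think>` format the system prompt asks for.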

apps/grpo/wandb_llama8b.png

166 KB

0 commit comments