---
# This is an experimental fork of qwen3_32b.yaml that enables the following:
# - shared memory based weight prefetching for weight updates
#
# Group Relative Policy Optimization (GRPO)
# >>> python -m apps.grpo.main --config <path/to>/qwen3_32b_experimental.yaml
# NOTE - This has not been tested for correctness yet! All testing so far has been only for infrastructure stability
# Global configuration
# Several of these scalars are reused by later sections through ${...} interpolation.
group_size: 16
local_batch_size: 32  # per-device batch size
max_req_tokens: 1024
max_res_tokens: 1024
model: "Qwen/Qwen3-32B"
off_by_n: 1  # Off by one by default

provisioner:
  launcher: slurm

# Main loop configuration
# Setting this to 4x the number of policy replicas seems to work well.
# NOTE(review): 32 is 8x services.policy.num_replicas (4) — confirm the intended ratio.
rollout_threads: 32
# Observability configuration
# One sub-mapping per logging backend; each backend carries its own settings.
metric_logging:
  wandb:
    project: "grpo-training"
    # ${oc.env:USER} is resolved from the environment when the config is loaded.
    group: "grpo_exp_${oc.env:USER}"
    reduce_across_ranks: true
  console:
    reduce_across_ranks: true
# Dataset configuration
dataset:
  path: "openai/gsm8k"
  revision: "main"
  data_split: "train"
  streaming: true
  # Resolves to the top-level `model` key.
  model: ${model}
# Policy configuration
policy:
  # NOTE(review): placed as a sibling of engine_args based on key order —
  # confirm against the policy schema.
  prefetch_weights: true
  # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
  engine_args:
    model: ${model}
    # services.policy.procs interpolates this value.
    tensor_parallel_size: 4
    pipeline_parallel_size: 1
    enforce_eager: false
  # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
  sampling_params:
    n: ${group_size}
    max_tokens: ${max_res_tokens}
    temperature: 1.0
    top_p: 1.0
# Trainer configuration
# NOTE(review): nesting reconstructed from key order — confirm against the trainer schema.
trainer:
  model:
    name: qwen3
    flavor: 32B
    hf_assets_path: hf://${model}
  optimizer:
    name: AdamW
    # Written with an explicit mantissa dot so YAML 1.1 loaders (e.g. PyYAML)
    # read these as floats rather than strings.
    lr: 1.0e-5
    eps: 1.0e-8
  lr_scheduler:
    warmup_steps: 1
  training:
    local_batch_size: ${local_batch_size}
    # NOTE(review): equals max_req_tokens + max_res_tokens (1024 + 1024) — presumably
    # must be kept in sync by hand; confirm.
    seq_len: 2048
    max_norm: 1.0
    steps: 1000000
    dtype: bfloat16
    gc_freq: 1
  compile:
    enable: false
  parallelism:
    data_parallel_replicate_degree: 1
    data_parallel_shard_degree: 1
    tensor_parallel_degree: 8
    pipeline_parallel_degree: 1
    context_parallel_degree: 1
    expert_parallel_degree: 1
    disable_loss_parallel: true
  checkpoint:
    enable: true
    initial_load_path: hf://${model}
    initial_load_in_hf: true
    last_save_in_hf: true
    interval: 500
    async_mode: "disabled"
  activation_checkpoint:
    mode: full
# Replay buffer configuration
replay_buffer:
  batch_size: ${local_batch_size}
  max_policy_age: ${off_by_n}
  # Must equal trainer DP degree. The interpolated form is kept for reference:
  # dp_size: ${trainer.parallelism.data_parallel_shard_degree}
  dp_size: 1
# Reference model configuration
# NOTE(review): nesting reconstructed from key order — confirm against the reference-model schema.
ref_model:
  model:
    name: qwen3
    flavor: 32B
    hf_assets_path: hf://${model}
  training:
    dtype: bfloat16
    gc_freq: 1
  compile:
    enable: false
  parallelism:
    data_parallel_replicate_degree: 1
    data_parallel_shard_degree: 1
    # services.ref_model.procs interpolates this value.
    tensor_parallel_degree: 4
    pipeline_parallel_degree: 1
    context_parallel_degree: 1
    expert_parallel_degree: 1
  checkpoint:
    enable: true
    initial_load_path: hf://${model}
    initial_load_in_hf: true
# All resource allocations
# One entry per service; `procs` for the GPU services is derived from the
# matching tensor-parallel setting elsewhere in this file.
services:
  policy:
    procs: ${policy.engine_args.tensor_parallel_size}
    num_replicas: 4
    hosts: 1
    with_gpus: true
    mesh_name: policy
  ref_model:
    procs: ${ref_model.parallelism.tensor_parallel_degree}
    num_replicas: 1
    with_gpus: true
    mesh_name: ref_model
  reward_actor:
    procs: 1
    num_replicas: 1
    with_gpus: false
    mesh_name: reward_actor
# Single-instance actors and their process/GPU allocations.
actors:
  dataset:
    procs: 1
    with_gpus: false
    mesh_name: dataset
  trainer:
    # NOTE(review): hard-coded; matches trainer.parallelism.tensor_parallel_degree (8)
    # with all other parallel degrees at 1 — presumably must be updated together; confirm.
    procs: 8
    hosts: 1
    with_gpus: true
    mesh_name: trainer
  replay_buffer:
    procs: 1
    with_gpus: false
    mesh_name: replay_buffer
  compute_advantages:
    procs: 1
    with_gpus: false
    mesh_name: compute_advantages