# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import shutil
import time
from collections.abc import Mapping
from dataclasses import dataclass, field, fields
from typing import Callable

import torch
import torch.distributed.checkpoint as dcp
import torchstore as ts
from monarch.actor import endpoint
from torch import Tensor
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torchtitan.config.job_config import (
    ActivationCheckpoint,
    Checkpoint,
    Comm,
    Compile,
    Job,
    LRScheduler,
    MemoryEstimation,
    Model,
    Optimizer,
    Parallelism,
    Quantize,
    Training,
)
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

from forge.actors._torchstore_utils import (
    DcpHandle,
    get_dcp_whole_state_dict_key,
    get_param_key,
    rdma_available,
)
from forge.controller import ForgeActor
from forge.data.utils import batch_to_device
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def cleanup_old_weight_versions(
    state_dict_key: str,
    delim: str,
    current_policy_version: int,
) -> None:
    """Delete old weight versions, keeping only the current and N-1 versions.

    TODO - issues/194: provide a more robust way to handle eviction.

    Args:
        state_dict_key: The base key for state dict storage.
        delim: The delimiter used between key and version.
        current_policy_version: The current policy version to keep.
    """
    if current_policy_version <= 1:
        return  # No cleanup needed for versions 0 or 1

    prefix = f"{state_dict_key}{delim}"
    current_weights = f"{prefix}{current_policy_version}"
    previous_weights = f"{prefix}{current_policy_version - 1}"

    # Find all weight directories that match our pattern
    parent_dir = os.path.dirname(prefix) or "."
    if os.path.exists(parent_dir):
        for item in os.listdir(parent_dir):
            item_path = os.path.join(parent_dir, item)
            if (
                item.startswith(os.path.basename(prefix))
                and item != os.path.basename(current_weights)
                and item != os.path.basename(previous_weights)
                and os.path.isdir(item_path)
            ):
                try:
                    shutil.rmtree(item_path, ignore_errors=True)
                    logger.debug(f"Removed old weights at {item_path}")
                except OSError as e:
                    logger.debug(f"Error deleting {item_path}: {e}")


@dataclass
class RLTrainer(ForgeActor):
    """A reinforcement learning trainer actor for policy optimization training.

    Built on top of TorchTitan's training engine, this actor provides a complete training
    loop for reinforcement learning. It performs forward and backward passes with gradient
    computation, optimization steps, and checkpoint management. Unlike the ReferenceModel
    actor, which only runs forward passes, RLTrainer actively updates the policy model
    parameters through gradient descent.

    The trainer supports the same distributed training strategies that TorchTitan does,
    including, but not limited to, tensor parallelism, data parallelism, and FSDP
    (Fully Sharded Data Parallel). It is typically used in conjunction with ReferenceModel
    for policy optimization algorithms like GRPO (Group Relative Policy Optimization),
    where it optimizes the policy against a loss that includes KL divergence penalties
    from the reference model.

    The trainer handles:
    - Forward and backward propagation with automatic mixed precision (AMP)
    - Optimizer steps with learning rate scheduling
    """

    job: Job = field(default_factory=Job)
    model: Model = field(default_factory=Model)
    optimizer: Optimizer = field(default_factory=Optimizer)
    lr_scheduler: LRScheduler = field(default_factory=LRScheduler)
    training: Training = field(default_factory=Training)
    parallelism: Parallelism = field(default_factory=Parallelism)
    checkpoint: Checkpoint = field(default_factory=Checkpoint)
    activation_checkpoint: ActivationCheckpoint = field(
        default_factory=ActivationCheckpoint
    )
    compile: Compile = field(default_factory=Compile)
    quantize: Quantize = field(default_factory=Quantize)
    comm: Comm = field(default_factory=Comm)
    memory_estimation: MemoryEstimation = field(default_factory=MemoryEstimation)

    # Non JobConfig-related fields
    loss: Callable = lambda logits, **targets: logits
    state_dict_key: str = "model_state_dict"
    use_dcp: bool = not rdma_available()
    dcp_path: str = "forge_dcp_tmp"

    def __post_init__(self):
        """Initializes config types and env variables.

        torchrun normally handles env variables, but we need to do it ourselves
        in monarch for now.
        """
        super().__init__()
        if self.use_dcp:
            # DCP-specific optimization
            torch.serialization.set_crc32_options(False)

        # Instantiate dict fields
        for f in fields(self):
            attr = getattr(self, f.name)
            if isinstance(attr, Mapping):
                setattr(self, f.name, f.type(**attr))
            elif not isinstance(attr, f.type):
                raise TypeError(
                    f"{f.name} should be a {f.type} type or a dict-like object"
                )

        self.step = 1  # fragile contract.
        self.num_training_steps = self.training.steps
        self.gradient_accumulation_steps = 1
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
        logger.info("Compiling loss")
        self.loss = torch.compile(self.loss)
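
    # Construction sketch (hedged): dict-valued fields are coerced to their
    # dataclass types in __post_init__, so a plain mapping works for any config
    # section. Field names below are illustrative; `training.steps` is the one
    # this class reads directly.
    #
    #   trainer = RLTrainer(
    #       training={"steps": 1000},  # coerced to Training(**{...})
    #       loss=my_grpo_loss,         # hypothetical loss callable
    #   )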

    @endpoint
    async def setup(self):
        # TODO: update ForgeEngine to not use ForgeJobConfig
        engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
        for key in {
            "loss",
            "state_dict_key",
            "use_dcp",
            "dcp_path",
        }:
            engine_config.pop(key)  # Not part of job config
        self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
        self.engine.checkpointer.load(step=self.step)
        self.engine.optimizers.zero_grad()

    def forward_backward(
        self, inputs: dict[str, Tensor], targets: dict[str, Tensor]
    ) -> Tensor:
        model_parts = self.engine.model_parts
        parallel_dims = self.engine.parallel_dims

        # Apply context parallelism if cp is enabled.
        # Ensure CP handles the separate freqs_cis buffer for each pp stage.
        # if getattr(self.engine.model_args, "use_flex_attn", False):
        #     cp_mesh = (
        #         parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
        #     )
        #     init_attention_mask(
        #         inputs, self.engine.tokenizer.base_tokenizer.eos_id, cp_mesh
        #     )
        # optional_context_parallel_ctx = (
        #     dist_utils.create_context_parallel_ctx(
        #         cp_mesh=parallel_dims.world_mesh["cp"],
        #         cp_buffers=[inputs, targets] + [m.freqs_cis for m in model_parts],
        #         cp_seq_dims=[1, 1] + [0 for _ in model_parts],
        #         cp_no_restore_buffers={inputs, targets},
        #         cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
        #     )
        #     if parallel_dims.cp_enabled
        #     else None
        # )
        optional_context_parallel_ctx = None

        if parallel_dims.pp_enabled:
            raise NotImplementedError("PP not implemented yet")
            # TODO: implement PP
            # # Pipeline-parallel forward / backward inside step() call
            # with self.train_context(optional_context_parallel_ctx):
            #     targets, losses = (
            #         (labels, []) if self.pp_has_last_stage else (None, None)
            #     )
            #     if self.pp_has_first_stage:
            #         self.pp_schedule.step(
            #             inputs, target=targets, losses=losses, input_batch=inputs
            #         )
            #     else:
            #         self.pp_schedule.step(
            #             target=targets, losses=losses, input_batch=inputs
            #         )
            #
            #     # Accumulate losses across pipeline microbatches.
            #     # TODO: PP+FSDP unexpectedly puts the loss back to the CPU.
            #     loss = (
            #         torch.mean(torch.stack(losses)).to(self.device)
            #         if self.pp_has_last_stage
            #         else torch.tensor([-1.0], device=self.device)
            #     )
        else:
            # Non-PP forward / backward
            with self.engine.train_context(optional_context_parallel_ctx):
                assert len(model_parts) == 1
                with self.engine.maybe_enable_amp:
                    logits = model_parts[0](**inputs)
                    loss = self.loss(logits, **targets)
                # Free logits before backward to avoid peaking memory.
                del logits
                loss.backward()
        return loss
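
    # A minimal loss sketch matching the `loss(logits, **targets)` contract used
    # in forward_backward (hypothetical target key "labels"; a real GRPO loss
    # would also consume e.g. reference logprobs and advantages):
    #
    #   def simple_ce_loss(logits: Tensor, labels: Tensor, **_ignored) -> Tensor:
    #       return torch.nn.functional.cross_entropy(
    #           logits.flatten(0, 1).float(), labels.flatten()
    #       )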

    @endpoint
    async def train_step(
        self, inputs: list[dict[str, Tensor]], targets: list[dict[str, Tensor]]
    ) -> float:
        # Trace step timing and memory
        t = Tracer("rl_trainer_perf/step", timer="gpu", track_memory=True)
        t.start()
        self.engine.gc_handler.run(self.step)
        local_inputs = inputs[self.engine.dp_rank]
        local_targets = targets[self.engine.dp_rank]
        batch_to_device(local_inputs, self.engine.device)
        batch_to_device(local_targets, self.engine.device)

        # Compute policy logprobs
        # TODO: implement gradient accumulation
        # with GradientAccumulation(
        #     self.gradient_accumulation_steps,
        #     self.model,
        #     self.data_parallel_size,
        # ) as grad_acc:
        loss = self.forward_backward(local_inputs, local_targets)
        torch.distributed.all_reduce(loss)  # sum across ranks (default op)
        t.step("forward_backward")

        # Get learning rate from scheduler
        current_lr = (
            self.engine.lr_schedulers.get_last_lr()[0]
            if hasattr(self.engine.lr_schedulers, "get_last_lr")
            else 0.001
        )
        record_metric("rl_trainer/learning_rate", current_lr, Reduce.MIN)

        self.engine.optimizers.step()
        self.engine.optimizers.zero_grad()
        self.engine.lr_schedulers.step()
        t.step("optimizer_step")

        # Record training metrics
        # TODO: delete item() to avoid cpu-gpu sync
        loss = loss.detach().cpu().item()
        record_metric("rl_trainer/count_training_steps", 1, Reduce.SUM)
        record_metric("rl_trainer/avg_grpo_loss", loss, Reduce.MEAN)
        # TODO: Extract actual KL divergence and policy entropy from the loss computation.
        # These are placeholder values until the loss function exposes these metrics.
        # record_metric("rl_trainer/step/avg_kl_divergence", 0.0, Reduce.MEAN)
        # record_metric("rl_trainer/step/std_kl_divergence", 0.0, Reduce.STD)
        # record_metric("rl_trainer/step/avg_policy_entropy", 0.0, Reduce.MEAN)

        self.step += 1
        self.engine.checkpointer.save(
            curr_step=self.step,
            last_step=self.step == self.num_training_steps,
        )
        t.step("save_checkpoint")
        t.stop()
        return loss
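
    # Input-shape sketch (hypothetical keys): `inputs` and `targets` each carry
    # one dict per data-parallel rank; rank r consumes inputs[r] / targets[r].
    #
    #   inputs  = [{"tokens": tokens_rank0}, {"tokens": tokens_rank1}]
    #   targets = [{"labels": labels_rank0}, {"labels": labels_rank1}]
    #   loss = await trainer.train_step.call_one(inputs, targets)
    #   # (exact invocation depends on the Monarch actor endpoint API)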

    @endpoint
    async def push_weights(self, policy_version: int) -> None:
        """Push weights to torchstore in HF format."""
        t = Tracer("rl_trainer_perf/push_weights", timer="gpu", track_memory=True)
        t.start()
        logger.info(f"Pushing weights for policy version {policy_version}")
        start_time = time.perf_counter()
        if "model" not in self.engine.checkpointer.states:
            raise RuntimeError("Model state not found in checkpointer state")
        sd = self.engine.checkpointer.states["model"].state_dict()
        flattened_state_dict, _ = flatten_state_dict(sd)
        t.step("flatten_state_dict")

        if self.engine.checkpointer.sd_adapter is None:
            raise RuntimeError(
                "Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
            )
        hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
        t.step("to_hf")

        if self.use_dcp:
            key = get_dcp_whole_state_dict_key(policy_version)
            dcp_id = f"{self.dcp_path}/{key}"
            storage_writer = torch.distributed.checkpoint.FileSystemWriter(
                dcp_id, single_file_per_rank=False, thread_count=8
            )
            metadata = dcp.save(storage_writer=storage_writer, state_dict=hf_state_dict)
            dcp_handle = DcpHandle(
                checkpoint_id=dcp_id,
                metadata=metadata,
                param_names=hf_state_dict.keys(),
            )
            await ts.put(key, dcp_handle)
            t.step("dcp_save")
        else:
            for name, param in hf_state_dict.items():
                key = get_param_key(policy_version, name)
                await ts.put(key, param)
            t.step("ts_save")
        t.stop()
        end_time = time.perf_counter()
        logger.info("Completed weights push in %.2f seconds", end_time - start_time)

    @endpoint
    async def cleanup(self) -> None:
        if self.engine.checkpointer:
            self.engine.checkpointer.close()


def _shard_and_concat(sources: list[torch.Tensor], dim: int, tp: int) -> torch.Tensor:
    """Shard and concatenate tensors along a given dimension.

    Each source tensor is chunked into `tp` shards; shards with the same index
    are then concatenated, so each tensor-parallel rank's slice of the result
    contains its shard of every source.

    Args:
        sources (list[torch.Tensor]): List of tensors to shard and concatenate.
        dim (int): Dimension along which to shard and concatenate.
        tp (int): Number of tensor parallel groups.

    Returns:
        torch.Tensor: Concatenated tensor.
    """
    sharded_sources = []
    for source in sources:
        sharded_sources.append(torch.chunk(source, tp, dim=dim))

    combined_shards = []
    for shard_idx in range(tp):
        combined = torch.cat([s[shard_idx] for s in sharded_sources], dim=dim)
        combined_shards.append(combined)
    return torch.cat(combined_shards, dim=dim)
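
# Worked example for _shard_and_concat (toy sizes): with tp=2 and q/k/v having
# 4/2/2 rows, each tensor is chunked into 2 row-shards and regrouped by shard
# index, so the fused rows come out as [q0, k0, v0, q1, k1, v1]; each TP rank's
# row-slice of the fused tensor then contains its own q, k, and v shards.
#
#   q, k, v = torch.ones(4, 8), 2 * torch.ones(2, 8), 3 * torch.ones(2, 8)
#   fused = _shard_and_concat([q, k, v], dim=0, tp=2)  # shape (8, 8)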


def _qwen3_hf_to_vllm(
    sd: dict[str, torch.Tensor], num_layers: int, vllm_tp: int
) -> dict[str, torch.Tensor]:
    """Convert a transformers state dict to vLLM format. Specifically, this fuses
    the QKV projection and MLP gate_up_proj layers.

    Args:
        sd (dict): State dict from the HF model.
        num_layers (int): Number of layers in the model.
        vllm_tp (int): Tensor parallel degree of the target vLLM model.

    Returns:
        dict: State dict in vLLM format.
    """
    load_sd = {}

    def unwrap(t):
        """Unwrap a DTensor to a Tensor."""
        return t.full_tensor() if isinstance(t, torch.distributed.tensor.DTensor) else t

    for key in sd.keys():
        sd[key] = unwrap(sd[key]).cpu()

    # Copy over directly mapped keys
    for k in sd:
        if any(
            x in k
            for x in [
                "down_proj",
                "input_layernorm",
                "post_attention_layernorm",
                "o_proj",
                "norm.weight",
                "embed_tokens.weight",
                "lm_head.weight",
            ]
        ):
            load_sd[k] = sd[k]

    for i in range(num_layers):
        prefix = f"model.layers.{i}."

        # QKV fusion
        q = sd[prefix + "self_attn.q_proj.weight"]
        k = sd[prefix + "self_attn.k_proj.weight"]
        v = sd[prefix + "self_attn.v_proj.weight"]
        load_sd[prefix + "self_attn.qkv_proj.weight"] = _shard_and_concat(
            [q, k, v], dim=0, tp=vllm_tp
        )

        # Untested: QKV fusion - handle bias if present
        q_bias_key = prefix + "self_attn.q_proj.bias"
        k_bias_key = prefix + "self_attn.k_proj.bias"
        v_bias_key = prefix + "self_attn.v_proj.bias"
        if all(key in sd for key in [q_bias_key, k_bias_key, v_bias_key]):
            q_bias = sd[q_bias_key]
            k_bias = sd[k_bias_key]
            v_bias = sd[v_bias_key]
            load_sd[prefix + "self_attn.qkv_proj.bias"] = _shard_and_concat(
                [q_bias, k_bias, v_bias], dim=0, tp=vllm_tp
            )

        # MLP gate_up_proj fusion
        gate = sd[prefix + "mlp.gate_proj.weight"]
        up = sd[prefix + "mlp.up_proj.weight"]
        load_sd[prefix + "mlp.gate_up_proj.weight"] = _shard_and_concat(
            [gate, up], dim=0, tp=vllm_tp
        )

        # Untested: MLP gate_up_proj fusion - handle bias if present
        gate_bias_key = prefix + "mlp.gate_proj.bias"
        up_bias_key = prefix + "mlp.up_proj.bias"
        if all(key in sd for key in [gate_bias_key, up_bias_key]):
            gate_bias = sd[gate_bias_key]
            up_bias = sd[up_bias_key]
            # The same sharding has to happen here
            load_sd[prefix + "mlp.gate_up_proj.bias"] = _shard_and_concat(
                [gate_bias, up_bias], dim=0, tp=vllm_tp
            )
    return load_sd
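
# Usage sketch (hypothetical layer count): fuse a Qwen3 HF state dict for a
# vLLM worker running with tensor parallel degree 2.
#
#   vllm_sd = _qwen3_hf_to_vllm(hf_sd, num_layers=28, vllm_tp=2)
#   # vllm_sd now holds fused "self_attn.qkv_proj.weight" and
#   # "mlp.gate_up_proj.weight" entries per layer.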