Commit a3ae18b

HosseinKaviani-H (Hossein Kavianihamedani) and felipemello1 authored

Train batch generic (#724)

Co-authored-by: Hossein Kavianihamedani <hosseinkh@fb.com>
Co-authored-by: Felipe Mello <fmellomascarenhas@gmail.com>
1 parent bd40c5e commit a3ae18b

6 files changed (+92 −63 lines)

apps/grpo/main.py

Lines changed: 1 addition & 2 deletions
@@ -332,8 +332,7 @@ async def continuous_training():
         else:
             t.step("waiting_for_buffer")

-            inputs, targets = batch
-            await trainer.train_step.call(inputs, targets)
+            await trainer.train_step.call(batch)
             training_step += 1
             t.step("train_step")

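With collate now returning TrainBatch objects, the call site shrinks to a single argument. A minimal sketch of the surrounding loop body after this change; the None-check and the sleep are assumptions based on the visible context, not code from this diff:

    # Hypothetical sketch: `batch` is list[TrainBatch] | None, sampled upstream.
    if batch is None:
        await asyncio.sleep(0.1)  # assumption: wait for the buffer to fill
    else:
        t.step("waiting_for_buffer")
        # One TrainBatch per data-parallel rank; forwarded without unpacking.
        await trainer.train_step.call(batch)
        training_step += 1
        t.step("train_step")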
src/forge/actors/trainer/titan.py

Lines changed: 11 additions & 15 deletions
@@ -20,6 +20,7 @@
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 from forge.rl.loss import create_shifted_targets
+from forge.types import TrainBatch
 from monarch.actor import endpoint
 from torch import Tensor
 from torch.distributed.checkpoint._nested_dict import flatten_state_dict
@@ -117,17 +118,15 @@ async def setup(self):
         self.engine.checkpointer.load(step=self.step)
         self.engine.optimizers.zero_grad()

-    def forward_backward(
-        self, inputs: dict[str, Tensor], targets: dict[str, Tensor]
-    ) -> Tensor:
+    def forward_backward(self, batch: TrainBatch) -> Tensor:
         model_parts = self.engine.model_parts
         parallel_dims = self.engine.parallel_dims
         optional_context_parallel_ctx = None

         # Create shifted target_ids for next-token prediction
         # target_ids[i] = input_ids[i+1], with loss_mask applied
-        targets["target_ids"] = create_shifted_targets(
-            inputs["tokens"], targets.get("loss_mask")
+        batch.loss_inputs["target_ids"] = create_shifted_targets(
+            batch.model_inputs["tokens"], batch.loss_inputs.get("loss_mask")
         )

         if parallel_dims.pp_enabled:
@@ -136,8 +135,8 @@ def forward_backward(
         with self.engine.train_context(optional_context_parallel_ctx):
             assert len(model_parts) == 1
             with self.engine.maybe_enable_amp:
-                logits = model_parts[0](**inputs)
-                loss_output = self.loss(logits, **targets)
+                logits = model_parts[0](**batch.model_inputs)
+                loss_output = self.loss(logits, **batch.loss_inputs)
             loss = loss_output.loss

         # Record metrics from loss output
@@ -156,19 +155,16 @@ def forward_backward(
         return loss

     @endpoint
-    async def train_step(
-        self, inputs: list[dict[str, Tensor]], targets: list[dict[str, Tensor]]
-    ) -> float:
+    async def train_step(self, batches: list[TrainBatch]) -> float:
         t = Tracer("rl_trainer_perf/step", timer="gpu", track_memory=True)
         t.start()

         self.engine.gc_handler.run(self.step)
-        local_inputs = inputs[self.engine.dp_rank]
-        local_targets = targets[self.engine.dp_rank]
-        batch_to_device(local_inputs, self.engine.device)
-        batch_to_device(local_targets, self.engine.device)
+        batch = batches[self.engine.dp_rank]
+        batch_to_device(batch.model_inputs, self.engine.device)
+        batch_to_device(batch.loss_inputs, self.engine.device)

-        loss = self.forward_backward(local_inputs, local_targets)
+        loss = self.forward_backward(batch)
         torch.distributed.all_reduce(loss)

         t.step("forward_backward")

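The trainer is now agnostic to batch contents: the model is called with **batch.model_inputs and the loss with **batch.loss_inputs. A self-contained sketch of that kwarg-splat contract with toy stand-ins (ToyModel and loss_fn are illustrative, not Forge APIs; the real path goes through the Titan engine, AMP, and pipeline parallelism):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    from forge.types import TrainBatch  # added in this commit


    class ToyModel(nn.Module):
        """Illustrative stand-in for the Titan model parts."""

        def __init__(self, vocab_size: int = 100, dim: int = 32):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, dim)
            self.head = nn.Linear(dim, vocab_size)

        def forward(self, tokens: torch.Tensor) -> torch.Tensor:
            # The parameter name must match the model_inputs key, because the
            # trainer calls model_parts[0](**batch.model_inputs).
            return self.head(self.embed(tokens))


    def loss_fn(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
        # Loss kwargs come straight from loss_inputs the same way.
        return F.cross_entropy(logits.flatten(0, 1), target_ids.flatten())


    model = ToyModel()
    batch = TrainBatch(
        model_inputs={"tokens": torch.randint(0, 100, (2, 16))},
        loss_inputs={"target_ids": torch.randint(0, 100, (2, 16))},
    )
    loss = loss_fn(model(**batch.model_inputs), **batch.loss_inputs)
    loss.backward()

The splat convention is what makes the batch generic: a new loss input becomes a new dict key, not a new parameter in the trainer signature.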
src/forge/rl/collate.py

Lines changed: 14 additions & 14 deletions
@@ -4,21 +4,17 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any
-
 import torch
 from forge.rl.types import Group
+from forge.types import TrainBatch


-def collate(
-    batches: list[Group],
-) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
+def collate(batches: list[Group]) -> list[TrainBatch]:
     """
-    Collates a list of batches into a single batch of inputs and targets.
+    Collates a list of batches into TrainBatch objects.
     Each batch is a list of episodes, and each episode is a dict of tensors.
     """
-    inputs = []
-    targets = []
+    result = []
     for batch in batches:
         request = [e.request_tensor for e in batch]
         request = torch.stack(request)  # [b x s]
@@ -41,14 +37,18 @@ def collate(
         generator_logprobs = torch.stack([e.generator_logprobs for e in batch])
         loss_mask = torch.stack([e.loss_mask for e in batch])

-        input = {"tokens": input_ids}
-        target = {
+        loss_inputs = {
             "generator_logprobs": generator_logprobs,
             "loss_mask": loss_mask,
             "advantages": advantages,
         }
         if ref_logprobs is not None:
-            target["ref_logprobs"] = ref_logprobs
-        inputs.append(input)
-        targets.append(target)
-    return inputs, targets
+            loss_inputs["ref_logprobs"] = ref_logprobs
+
+        result.append(
+            TrainBatch(
+                model_inputs={"tokens": input_ids},
+                loss_inputs=loss_inputs,
+            )
+        )
+    return result

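Callers that previously unpacked inputs, targets = collate(...) now receive one list. A hedged consumption sketch; constructing the groups argument (list[Group]) is outside this diff and only named here:

    from forge.rl.collate import collate
    from forge.types import TrainBatch

    train_batches: list[TrainBatch] = collate(groups)  # `groups` built elsewhere

    for tb in train_batches:
        tokens = tb.model_inputs["tokens"]        # stacked input ids, [b x s]
        advantages = tb.loss_inputs["advantages"]
        # Present only when a reference model supplied logprobs:
        ref_logprobs = tb.loss_inputs.get("ref_logprobs")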
src/forge/types.py

Lines changed: 31 additions & 0 deletions
@@ -126,3 +126,34 @@ class ProvisionerConfig:
     """A config for the forge provisioner."""

     launcher_config: LauncherConfig
+
+
+@dataclass
+class TrainBatch:
+    """Universal training batch for all Forge training modes.
+
+    Usage:
+        logits = model(**batch.model_inputs)
+        loss = loss_fn(logits, **batch.loss_inputs)
+
+    Attributes:
+        model_inputs (dict[str, Any]): Inputs for model forward pass (e.g., input_ids, attention_mask).
+        loss_inputs (dict[str, Any]): Inputs for loss computation (e.g., target_ids, advantages, beta).
+        meta (dict[str, Any]): Any extra metadata that is not a model or loss input.
+
+    Example:
+        >>> # SFT
+        >>> batch = TrainBatch(
+        >>>     model_inputs={"input_ids": ids, "attention_mask": mask},
+        >>>     loss_inputs={"target_ids": targets},
+        >>> )
+        >>> # RL (GRPO)
+        >>> batch = TrainBatch(
+        >>>     model_inputs={"input_ids": ids},
+        >>>     loss_inputs={"target_ids": targets, "advantages": adv, "ref_logprobs": ref},
+        >>> )
+    """
+
+    model_inputs: dict[str, Any]
+    loss_inputs: dict[str, Any]
+    meta: dict[str, Any] = field(default_factory=dict)

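Nothing in this commit reads the meta field yet; it is reserved for bookkeeping that must not be splatted into the model or the loss. A hedged sketch of intended use (the keys are hypothetical):

    import torch
    from forge.types import TrainBatch

    batch = TrainBatch(
        model_inputs={"tokens": torch.randint(0, 32000, (4, 128))},
        loss_inputs={"advantages": torch.randn(4, 1)},
        # Hypothetical keys: `meta` simply rides along with the batch.
        meta={"policy_version": 7, "source": "replay_buffer"},
    )

    # meta never reaches model(**batch.model_inputs) or loss(**batch.loss_inputs).
    assert "policy_version" not in batch.model_inputs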
tests/sandbox/rl_trainer/main.py

Lines changed: 18 additions & 16 deletions
@@ -21,6 +21,7 @@
     ProcessConfig,
     ProvisionerConfig,
     ServiceConfig,
+    TrainBatch,
 )
 from forge.util.config import parse
 from omegaconf import DictConfig
@@ -75,13 +76,12 @@ def generate_random_batch(
     vocab_size: int = 32000,
     device: str = "cuda",
     dp_size: int = 1,
-):
+) -> list[TrainBatch]:
     """
-    Generate random input and target tensors matching GRPO data format
-    Creates one batch per data parallel rank
+    Generate random TrainBatch objects matching GRPO data format.
+    Creates one batch per data parallel rank.
     """
-    inputs = []
-    targets = []
+    batches = []

     # Create one batch for each data parallel rank
     for _ in range(dp_size):
@@ -109,17 +109,19 @@
         )
         advantages = torch.randn((local_batch_size, 1), device=device)
         input_tokens = torch.cat([request, response], dim=1)
-        inputs.append({"tokens": input_tokens})
-        targets.append(
-            {
-                "response": response,
-                "ref_logprobs": ref_logprobs,
-                "advantages": advantages,
-                "padding_mask": padding_mask,
-            }
+        batches.append(
+            TrainBatch(
+                model_inputs={"tokens": input_tokens},
+                loss_inputs={
+                    "response": response,
+                    "ref_logprobs": ref_logprobs,
+                    "advantages": advantages,
+                    "padding_mask": padding_mask,
+                },
+            )
         )

-    return inputs, targets
+    return batches


 async def main(cfg: DictConfig):
@@ -201,7 +203,7 @@ async def continuous_training():
         t = Tracer("trainer/continuous_training")
         t.start()

-        inputs, targets = generate_random_batch(
+        batches = generate_random_batch(
             local_batch_size=local_batch_size,
             request_len=request_len,
             response_len=response_len,
@@ -211,7 +213,7 @@
         t.step("generate_random_data")

         # Perform training step
-        await trainer.train_step.call(inputs, targets)
+        await trainer.train_step.call(batches)
         training_step += 1
         t.step("train_step")

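A condensed, CPU-runnable sketch of what the sandbox helper builds for one data-parallel rank. Tensor construction is simplified relative to the real script (padding_mask here is all ones), so the field layout, not the values, is the point:

    import torch
    from forge.types import TrainBatch


    def random_train_batch(
        local_batch_size: int = 2,
        request_len: int = 8,
        response_len: int = 8,
        vocab_size: int = 32000,
        device: str = "cpu",
    ) -> TrainBatch:
        request = torch.randint(0, vocab_size, (local_batch_size, request_len), device=device)
        response = torch.randint(0, vocab_size, (local_batch_size, response_len), device=device)
        return TrainBatch(
            model_inputs={"tokens": torch.cat([request, response], dim=1)},
            loss_inputs={
                "response": response,
                "ref_logprobs": torch.randn(local_batch_size, response_len, device=device),
                "advantages": torch.randn(local_batch_size, 1, device=device),
                "padding_mask": torch.ones(local_batch_size, response_len, device=device),
            },
        )


    # One TrainBatch per DP rank, as the script does with dp_size.
    batches = [random_train_batch() for _ in range(2)]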
tests/sandbox/weight_sync/main.py

Lines changed: 17 additions & 16 deletions
@@ -24,7 +24,7 @@
 from forge.actors.trainer import RLTrainer
 from forge.controller.provisioner import init_provisioner, shutdown
 from forge.observability.metric_actors import get_or_create_metric_logger
-from forge.types import LauncherConfig, ProvisionerConfig
+from forge.types import LauncherConfig, ProvisionerConfig, TrainBatch
 from forge.util.config import parse
 from omegaconf import DictConfig
 from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -37,13 +37,12 @@ def generate_random_batch(
     vocab_size: int = 32000,
     device: str = "cuda",
     dp_size: int = 1,
-):
+) -> list[TrainBatch]:
     """
-    Generate random input and target tensors for a single training step.
+    Generate random TrainBatch objects for a single training step.
     Creates one batch per data parallel rank.
     """
-    inputs = []
-    targets = []
+    batches = []

     # Create one batch for each data parallel rank
     for _ in range(dp_size):
@@ -71,17 +70,19 @@
         )
         advantages = torch.randn((local_batch_size, 1), device=device)
         input_tokens = torch.cat([request, response], dim=1)
-        inputs.append({"tokens": input_tokens})
-        targets.append(
-            {
-                "response": response,
-                "ref_logprobs": ref_logprobs,
-                "advantages": advantages,
-                "padding_mask": padding_mask,
-            }
+        batches.append(
+            TrainBatch(
+                model_inputs={"tokens": input_tokens},
+                loss_inputs={
+                    "response": response,
+                    "ref_logprobs": ref_logprobs,
+                    "advantages": advantages,
+                    "padding_mask": padding_mask,
+                },
+            )
         )

-    return inputs, targets
+    return batches


 async def main(cfg: DictConfig):
@@ -147,15 +148,15 @@ async def main(cfg: DictConfig):
     print("Running single training step...")
     step_start = time.time()

-    inputs, targets = generate_random_batch(
+    batches = generate_random_batch(
         local_batch_size=local_batch_size,
         request_len=request_len,
         response_len=response_len,
         vocab_size=vocab_size,
         dp_size=dp_size,
     )

-    await trainer.train_step.call(inputs, targets)
+    await trainer.train_step.call(batches)
     step_time = time.time() - step_start
     print(f"Finished train step in ({step_time:.2f}s)\n")

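Both sandbox scripts build dp_size batches because train_step broadcasts the same list to every rank and each rank keeps only its own shard (batch = batches[self.engine.dp_rank] in the titan.py hunk above). A small sketch of that caller-side invariant; world_size and the awaited call are assumptions mirroring the scripts:

    batches = generate_random_batch(dp_size=world_size)
    assert len(batches) == world_size       # one TrainBatch per DP rank
    await trainer.train_step.call(batches)  # each rank indexes batches[dp_rank]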