
Commit 1351212

Loss refactor (#699)

Authored by felipemello1 (Felipe Mello)
Co-authored-by: Felipe Mello <felipemello@fb.com>
1 parent b7c9c31 · commit 1351212

34 files changed: +2,096 additions, −783 deletions

apps/grpo/llama3_8b.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -39,6 +39,7 @@ generator:
   max_tokens: ${max_res_tokens}
   temperature: 1.0
   top_p: 1.0
+  logprobs: 1 # returns log probabilities for sampled tokens
 
 # Trainer configuration
 trainer:
```
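For context, `logprobs: 1` follows vLLM-style sampling-parameter semantics: the engine returns per-step logprobs, and the sampled token's own entry is always included. A minimal sketch of reading these per-token logprobs with vLLM directly; the Generator service wraps an engine of this kind, and the model name here is only a placeholder:

```python
# Sketch assuming vLLM-style sampling params; the model name is a placeholder.
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
params = SamplingParams(max_tokens=8, temperature=1.0, top_p=1.0, logprobs=1)
completion = llm.generate(["2 + 2 ="], params)[0].outputs[0]

# completion.logprobs[i] maps token ids to Logprob objects for step i;
# the sampled token's own logprob is always present.
for tok_id, step_logprobs in zip(completion.token_ids, completion.logprobs):
    print(tok_id, step_logprobs[tok_id].logprob)
```

The same line is added to the other GRPO configs below, since the rollout loop now hard-fails when `Completion.logprobs` is missing.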

apps/grpo/main.py

Lines changed: 81 additions & 60 deletions
```diff
@@ -24,65 +24,16 @@
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 from forge.rl import collate, ComputeAdvantages, Episode, RewardActor
+from forge.rl.loss import GRPOLoss
 from forge.types import LauncherConfig, ProvisionerConfig
 from forge.util.checkpoint import drop_weights
 from forge.util.config import parse
 from forge.util.logging import get_logger
-from forge.util.ops import compute_logprobs
 from omegaconf import DictConfig, OmegaConf
 
 logger = get_logger("INFO")
 
 
-# TODO (T245547773): Consolidate with SimpleGRPOLoss in losses/grpo_loss.py
-# Currently duplicated because of function signature differences:
-# - This function takes logits + response, computes logprobs internally
-# - SimpleGRPOLoss takes pre-computed logprobs
-# - TitanTrainer passes logits, so would need wrapper or signature change
-# Consider refactoring TitanTrainer's loss interface to standardize this.
-def simple_grpo_loss(
-    logits: torch.Tensor,
-    response: torch.Tensor,
-    ref_logprobs: torch.Tensor,
-    advantages: torch.Tensor,
-    padding_mask: torch.Tensor,
-    beta: float = 1e-6,
-) -> torch.Tensor:
-    logprobs: torch.Tensor = compute_logprobs(logits, response)
-    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
-    per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
-
-    # Compute mean KL per valid token
-    mean_kl = (
-        ((kl * padding_mask).sum(dim=1)) / (padding_mask.sum(dim=1).clamp(min=1.0))
-    ).mean()
-
-    # Compute mean policy loss per valid token
-    mean_policy_loss = (
-        ((per_token_policy_loss * padding_mask).sum(dim=1))
-        / (padding_mask.sum(dim=1).clamp(min=1.0))
-    ).mean()
-
-    # Compute loss using the means (mathematically equivalent)
-    loss = -(mean_policy_loss - beta * mean_kl)
-
-    # Log metrics
-    # TODO: Better design - have loss function return all metrics as a dict,
-    # then record them in rl_trainer so all training metrics are in one namespace
-    # and we avoid doing .item here, which is not compile friendly
-    record_metric("grpo_loss/kl_divergence_mean", mean_kl.item(), Reduce.MEAN)
-    record_metric(
-        "grpo_loss/kl_divergence_max", (kl * padding_mask).max().item(), Reduce.MAX
-    )
-    record_metric(
-        "grpo_loss/policy_gradient_loss", mean_policy_loss.item(), Reduce.MEAN
-    )
-    record_metric("grpo_loss/total_loss", loss.item(), Reduce.MEAN)
-    record_metric("grpo_loss/advantage_mean", advantages.mean().item(), Reduce.MEAN)
-    record_metric("grpo_loss/advantage_std", advantages.std().item(), Reduce.MEAN)
-    return loss
-
-
 async def main(cfg: DictConfig):
     """Main GRPO training loop with rollout and training processes."""
     # Convert OmegaConf config to plain dict
@@ -116,8 +67,32 @@ async def main(cfg: DictConfig):
         backend_config=metric_logging_cfg, run_config=run_config_for_logging
     )
 
+    # ---- Setup loss function ---- #
+    loss_fn = GRPOLoss(
+        clip_low=0.2,
+        clip_high=0.28,
+        beta=0.1,
+        agg_type="fixed_horizon",
+    )
+
+    # Fail-fast: Check loss/ref_model compatibility before spawning actors
+    uses_ref_model = cfg.get("services", {}).get("ref_model") is not None
+    if uses_ref_model and not isinstance(loss_fn, GRPOLoss):
+        raise ValueError(
+            f"ref_model is configured but {type(loss_fn).__name__} does not use ref_logprobs. "
+            "Either remove the ref_model service config or use GRPOLoss with beta > 0."
+        )
+    if isinstance(loss_fn, GRPOLoss) and loss_fn.beta > 0 and not uses_ref_model:
+        raise ValueError(
+            f"GRPOLoss with beta={loss_fn.beta} requires ref_logprobs, but ref_model is not configured. "
+            "Either add ref_model to services config or set beta=0."
+        )
+
     # ---- Setup services ---- #
 
+    async def noop():
+        return None
+
     (
         dataloader,
         generator,
```
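`GRPOLoss` itself is not shown in this view (it lives in the new `forge.rl.loss` module, added elsewhere in the commit). For orientation, a clipped GRPO objective with a KL penalty and fixed-horizon aggregation plausibly looks like the sketch below; the signature, the k3 KL estimator, and the `fixed_horizon` handling are assumptions inferred from the call sites in this commit, not the actual implementation:

```python
# Minimal sketch of a clipped GRPO objective; names and details are assumptions.
import torch

def grpo_loss_sketch(
    logprobs: torch.Tensor,      # [B, T] policy logprobs for target tokens
    old_logprobs: torch.Tensor,  # [B, T] generator (behavior-policy) logprobs
    ref_logprobs: torch.Tensor,  # [B, T] reference-model logprobs
    advantages: torch.Tensor,    # [B, 1] per-sequence group advantages
    loss_mask: torch.Tensor,     # [B, T] 1.0 on response-token positions
    clip_low: float = 0.2,
    clip_high: float = 0.28,
    beta: float = 0.1,
) -> torch.Tensor:
    # PPO-style importance ratio against the generator's logprobs.
    ratio = torch.exp(logprobs - old_logprobs)
    clipped = torch.clamp(ratio, 1.0 - clip_low, 1.0 + clip_high)
    per_token = torch.minimum(ratio * advantages, clipped * advantages)
    if beta > 0:
        # k3 estimator of KL(policy || reference), as in the deleted
        # simple_grpo_loss above.
        kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
        per_token = per_token - beta * kl
    # "fixed_horizon" aggregation: divide by a constant horizon (here the
    # padded length) instead of each sequence's own token count, so long and
    # short responses receive equal per-token weight.
    horizon = loss_mask.shape[1]
    return -((per_token * loss_mask).sum(dim=1) / horizon).mean()
```

The asymmetric `clip_low=0.2` / `clip_high=0.28` pair matches the "clip-higher" setting popularized by DAPO, which loosens the upper ratio bound so low-probability tokens can still be upweighted.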
```diff
@@ -130,13 +105,17 @@
         DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
         Generator.options(**cfg.services.generator).as_service(**cfg.generator),
         TitanTrainer.options(**cfg.actors.trainer).as_actor(
-            **cfg.trainer, loss=simple_grpo_loss
+            **cfg.trainer, loss=loss_fn
         ),
         ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(
             **cfg.replay_buffer, collate=collate
         ),
         ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),
-        ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
+        (
+            ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model)
+            if uses_ref_model
+            else noop()
+        ),
         RewardActor.options(**cfg.services.reward_actor).as_service(
             reward_functions=[MathReward(), ThinkingReward()]
         ),
@@ -187,7 +166,34 @@ async def continuous_rollouts():
                 (group_size, max_req_tokens + max_res_tokens),
                 dtype=torch.long,
             )
+            seq_len = max_req_tokens + max_res_tokens
+
             for i, response in enumerate(responses):
+                # Validate logprobs exist
+                if response.logprobs is None:
+                    raise ValueError(
+                        "Completion.logprobs is None. "
+                        "Ensure Generator returns logprobs by setting 'logprobs: 1' in sampling_params config."
+                    )
+
+                # Prepare generator_logprobs
+                # Shift by -1 to align with next-token prediction
+                actual_response_len = response.token_ids.shape[0]
+                generator_logprobs = torch.zeros(seq_len, dtype=response.logprobs.dtype)
+                generator_logprobs[
+                    max_req_tokens : max_req_tokens + actual_response_len
+                ] = response.logprobs
+                generator_logprobs = torch.roll(generator_logprobs, shifts=-1, dims=0)
+                generator_logprobs[-1] = 0.0
+
+                # Prepare loss_mask
+                response_mask = torch.zeros(seq_len, dtype=torch.float32)
+                response_mask[max_req_tokens : max_req_tokens + actual_response_len] = (
+                    1.0
+                )
+                loss_mask = torch.roll(response_mask, shifts=-1, dims=0)
+                loss_mask[-1] = 0.0
+
                 episode = Episode(
                     episode_id=str(uuid.uuid4()),
                     pad_id=pad_id,
```
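The `shifts=-1` roll exists because trainer logits at position `i` predict token `i + 1`: the logprob the generator assigned to a response token has to sit at the position whose logits produce that token. A tiny self-contained illustration (values made up):

```python
# Demonstrates the alignment performed in the hunk above.
import torch

max_req_tokens, seq_len = 3, 6                   # 3 prompt slots, 3 response slots
response_logprobs = torch.tensor([-0.5, -1.2])   # 2 sampled response tokens

gen_lp = torch.zeros(seq_len)
gen_lp[max_req_tokens : max_req_tokens + 2] = response_logprobs
gen_lp = torch.roll(gen_lp, shifts=-1, dims=0)
gen_lp[-1] = 0.0  # nothing is predicted past the end of the sequence
print(gen_lp)     # tensor([ 0.0,  0.0, -0.5, -1.2,  0.0,  0.0])
```

The first response token occupies index 3, but the logits that predict it live at index 2 (the last prompt position), which is where its logprob lands after the roll. `loss_mask` receives the identical shift so the two stay aligned.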
```diff
@@ -197,7 +203,10 @@
                     request=prompt,
                     response=response.text,
                     completion=response,
+                    generator_logprobs=generator_logprobs,
+                    loss_mask=loss_mask,
                 )
+
                 (
                     episode.reward_breakdown,
                     episode.reward,
@@ -263,21 +272,33 @@
 
             t.step("reward_evaluation")
 
-            ref_logprobs = await ref_model.forward.route(
-                input_ids, max_req_tokens, return_logprobs=True
-            )
-            t.step("reference_model_calculate_logprobs")
+            # Compute ref_logprobs only if ref_model is configured
+            if ref_model is not None:
+                ref_logprobs = await ref_model.forward.route(
+                    input_ids, return_logprobs=True
+                )
+                t.step("reference_model_calculate_logprobs")
+
+                for i, episode in enumerate(episodes):
+                    episode.ref_logprobs = ref_logprobs[i]  # [seq_len]
 
-            for i, episode in enumerate(episodes):
-                episode.ref_logprobs = ref_logprobs[i]
-            del ref_logprobs, input_ids
+                del ref_logprobs
+
+            del input_ids
 
             advantages = await compute_advantages.compute.call_one(episodes)
             for episode, advantage in zip(episodes, advantages):
                 episode.advantage = advantage
                 await replay_buffer.add.call_one(episode)
 
-                sample = episode.to_dict(exclude=["ref_logprobs", "completion"])
+                sample = episode.to_dict(
+                    exclude=[
+                        "completion",
+                        "loss_mask",
+                        "generator_logprobs",
+                        "ref_logprobs",
+                    ]
+                )
                 sample["score"] = sample["reward"]
                 record_metric(
                     "main_samples/continuous_rollouts/sample_table",
```

apps/grpo/qwen3_1_7b.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -43,6 +43,7 @@ generator:
   max_tokens: ${max_res_tokens}
   temperature: 1.0
   top_p: 1.0
+  logprobs: 1 # returns log probabilities for sampled tokens
 
 # Trainer configuration
 trainer:
```

apps/grpo/qwen3_8b.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -39,6 +39,7 @@ generator:
   max_tokens: ${max_res_tokens}
   temperature: 1.0
   top_p: 1.0
+  logprobs: 1 # returns log probabilities for sampled tokens
 
 # Trainer configuration
 trainer:
```

src/forge/actors/reference_model.py

Lines changed: 20 additions & 32 deletions
```diff
@@ -15,10 +15,9 @@
 from forge.controller import ForgeActor
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
-from forge.util.ops import compute_logprobs
+from forge.rl.loss import compute_logprobs, create_shifted_targets
 from monarch.actor import current_rank, current_size, endpoint
 from torch.distributed.tensor import DTensor
-from torch.distributed.tensor.parallel import loss_parallel
 from torchtitan.config.job_config import (
     Checkpoint,
     Comm,
@@ -96,9 +95,9 @@ def __post_init__(self):
         self.rank = current_rank().rank
         self.size = math.prod(current_size().values())
 
-        self.compute_log_probs = compute_logprobs
+        self.compute_logprobs = compute_logprobs
         if self.compile.enable:
-            self.compute_log_probs = torch.compile(self.compute_log_probs)
+            self.compute_logprobs = torch.compile(self.compute_logprobs)
 
         env = {
             "RANK": str(self.rank),
@@ -128,23 +127,20 @@ async def setup(self):
 
     @endpoint
     async def forward(
-        self, input_ids: torch.Tensor, max_req_tokens: int, return_logprobs: bool
+        self, input_ids: torch.Tensor, return_logprobs: bool = True
     ) -> torch.Tensor:
         """
         Args:
-            input_ids (torch.Tensor): input token ids with shape [group_size, req + res length].
-            max_req_tokens (int): maximum request length.
+            input_ids (torch.Tensor): input token ids with shape [group_size, seq_len].
             return_logprobs (bool): whether to return log probabilities instead of raw logits.
 
         return_logprobs flag significantly impacts the amount of data transferred to the caller:
-        - When False: Returns logits with shape [group_size, req + res_length, vocab_size].
+        - When False: Returns logits with shape [group_size, seq_len, vocab_size].
           This includes the full vocabulary distribution for each token position.
 
-        - When True: Returns log probabilities with shape [group_size, req_length].
-          This only includes probabilities for the request tokens, significantly reducing memory
-          usage and transfer overhead.
+        - When True: Returns log probabilities with shape [group_size, seq_len].
+          Prompt positions will have logprobs = 0.
         """
-        # Record reference model metrics
         record_metric("reference_perf/forward/count_forward_passes", 1, Reduce.SUM)
 
         t = Tracer("reference_perf/forward", timer="gpu", track_memory=True)
@@ -175,24 +171,16 @@ async def forward(
         with self.engine.maybe_enable_amp:
             with torch.inference_mode():
                 logits = self.model(input_ids)
-        self.step += 1
 
-        if not return_logprobs:
-            if isinstance(logits, DTensor):
-                logits = logits.full_tensor()
-            t.stop()
-            return logits
-        else:
-            response_tokens = input_ids[:, max_req_tokens:]
-            if parallel_dims.tp_enabled and isinstance(logits, DTensor):
-                with loss_parallel():
-                    logprobs = self.compute_log_probs(logits, response_tokens)
-
-                # loss_parallel produces Replicated output - to_local() returns the full tensor
-                logprobs = logprobs.to_local()
-            else:
-                if isinstance(logits, DTensor):
-                    logits = logits.full_tensor()
-                logprobs = self.compute_log_probs(logits, response_tokens)
-            t.stop()
-            return logprobs
+        if return_logprobs:
+            target_ids = create_shifted_targets(input_ids)
+            logprobs, _ = self.compute_logprobs(logits, target_ids)
+
+        out = logprobs if return_logprobs else logits
+
+        if isinstance(out, DTensor):
+            out = out.full_tensor()
+
+        self.step += 1
+        t.stop()
+        return out
```
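`create_shifted_targets` and `compute_logprobs` come from the new `forge.rl.loss` module, which this view doesn't include. Below is a sketch of plausible behavior, inferred only from the call sites (next-token targets, a tuple return whose second element is discarded here, and zero logprobs at positions excluded from the loss); the ignore sentinel and the second return value are assumptions:

```python
# Hypothetical reconstruction of the forge.rl.loss helpers; the real
# implementations may differ.
import torch

IGNORE_IDX = -100  # assumed sentinel for positions excluded from the loss

def create_shifted_targets(
    input_ids: torch.Tensor, loss_mask: torch.Tensor | None = None
) -> torch.Tensor:
    # target_ids[i] = input_ids[i + 1]; the final position has no target.
    target_ids = torch.roll(input_ids, shifts=-1, dims=-1)
    target_ids[..., -1] = IGNORE_IDX
    if loss_mask is not None:
        target_ids = target_ids.masked_fill(loss_mask == 0, IGNORE_IDX)
    return target_ids

def compute_logprobs(
    logits: torch.Tensor, target_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    # Log-probability of each target token under the model's distribution.
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    safe_targets = target_ids.clamp(min=0)  # keep gather in range at IGNORE_IDX
    token_logprobs = logprobs.gather(-1, safe_targets.unsqueeze(-1)).squeeze(-1)
    valid = target_ids != IGNORE_IDX
    return token_logprobs * valid, valid  # ignored positions -> 0.0
```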

src/forge/actors/trainer/titan.py

Lines changed: 22 additions & 2 deletions
```diff
@@ -19,6 +19,7 @@
 from forge.data.utils import batch_to_device
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
+from forge.rl.loss import create_shifted_targets
 from monarch.actor import endpoint
 from torch import Tensor
 from torch.distributed.checkpoint._nested_dict import flatten_state_dict
@@ -122,15 +123,34 @@ def forward_backward(
         model_parts = self.engine.model_parts
         parallel_dims = self.engine.parallel_dims
         optional_context_parallel_ctx = None
+
+        # Create shifted target_ids for next-token prediction
+        # target_ids[i] = input_ids[i+1], with loss_mask applied
+        targets["target_ids"] = create_shifted_targets(
+            inputs["tokens"], targets.get("loss_mask")
+        )
+
         if parallel_dims.pp_enabled:
             raise NotImplementedError("PP not implemented yet")
         else:
             with self.engine.train_context(optional_context_parallel_ctx):
                 assert len(model_parts) == 1
                 with self.engine.maybe_enable_amp:
                     logits = model_parts[0](**inputs)
-                    loss = self.loss(logits, **targets)
-                del logits  # Free to before bwd to avoid peaking memory
+                    loss_output = self.loss(logits, **targets)
+                    loss = loss_output.loss
+
+                    # Record metrics from loss output
+                    for metric in loss_output.metrics:
+                        value = (
+                            metric.value.item()
+                            if isinstance(metric.value, torch.Tensor)
+                            else metric.value
+                        )
+                        record_metric(metric.key, value, metric.reduction, metric.timestamp)
+
+                # Free to before bwd to avoid peaking memory
+                del logits, loss_output.metrics
                 loss.backward()
         self._accumulated_microbatches += 1
         return loss
```
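`forward_backward` now treats the loss callable's return value as structured: a scalar loss plus a list of metrics the trainer records on the loss function's behalf. The class is not part of this diff; a minimal shape consistent with the attribute accesses above (`.loss`, `.metrics`, and per-metric `.key`, `.value`, `.reduction`, `.timestamp`) might be:

```python
# Hypothetical shape of the loss output consumed by TitanTrainer above;
# field names mirror the attribute accesses in forward_backward.
from dataclasses import dataclass, field

import torch

from forge.observability.metrics import Reduce

@dataclass
class LossMetric:
    key: str
    value: torch.Tensor | float
    reduction: Reduce
    timestamp: float | None = None

@dataclass
class LossOutput:
    loss: torch.Tensor
    metrics: list[LossMetric] = field(default_factory=list)
```

Whatever its exact shape, this contract resolves the TODO in the deleted `simple_grpo_loss`: metrics flow back to the trainer and are recorded in one namespace, instead of the loss calling `record_metric` (and `.item()`) internally.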

src/forge/losses/grpo_loss.py

Lines changed: 0 additions & 29 deletions
This file was deleted.
