Skip to content

Commit a111042

Browse files
felipemello1 (Felipe Mello)
and co-author authored
[FIX] Drop only the sample truncated instead of the entire group (#744)
Co-authored-by: Felipe Mello <felipemello@fb.com>
1 parent 76abbcc commit a111042

File tree

1 file changed

+15
-20
lines changed

1 file changed

+15
-20
lines changed

apps/grpo/main.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -232,37 +232,19 @@ async def continuous_rollouts():
232232
"episode/min_response_tokens", response_tokens, Reduce.MIN
233233
)
234234

235-
# drop episodes if
236-
# 1> reward std-dev is very small (including all 0s and all 1s)
237-
# 2> any response was truncated (didn't end with EOS)
238-
# TODO: change it to filter only truncated episodes instead of dropping entire group
235+
# Drop entire group if reward std-dev is very small (no learning signal)
239236
rewards = [e.reward for e in episodes]
240237
rewards_std = torch.std(torch.tensor(rewards))
241238
is_low_variance = rewards_std < 1e-3
242-
num_truncated = sum(
243-
1 for e in episodes if e.completion.stop_reason == "length"
244-
)
245-
is_truncated = num_truncated > 0
246-
drop = is_low_variance or is_truncated
247239

248240
n = len(episodes)
249241
record_metric(
250242
"main/continuous_rollouts/episodes_dropped/low_variance",
251243
n if is_low_variance else 0,
252244
Reduce.SUM,
253245
)
254-
record_metric(
255-
"main/continuous_rollouts/episodes_dropped/truncated",
256-
num_truncated,
257-
Reduce.SUM,
258-
)
259-
record_metric(
260-
"main/continuous_rollouts/episodes_dropped/total",
261-
n if drop else 0,
262-
Reduce.SUM,
263-
)
246+
if is_low_variance:
264247

265-
if drop:
266248
del input_ids, episodes
267249
continue
268250

@@ -283,8 +265,16 @@ async def continuous_rollouts():
283265
del input_ids
284266

285267
advantages = await compute_advantages.compute.call_one(episodes)
268+
num_truncated = 0
286269
for episode, advantage in zip(episodes, advantages):
287270
episode.advantage = advantage
271+
272+
# Zero out loss_mask for truncated episodes so they don't contribute to gradient
273+
# TODO: evaluate if we should drop truncated episodes instead
274+
if episode.completion.stop_reason == "length":
275+
episode.loss_mask = torch.zeros_like(episode.loss_mask)
276+
num_truncated += 1
277+
288278
await replay_buffer.add.call_one(episode)
289279

290280
sample = episode.to_dict(
@@ -301,6 +291,11 @@ async def continuous_rollouts():
301291
sample,
302292
Reduce.SAMPLE,
303293
)
294+
record_metric(
295+
"main/continuous_rollouts/episodes_dropped/truncated",
296+
num_truncated,
297+
Reduce.SUM,
298+
)
304299

305300
rollout_count += 1
306301
record_metric(

0 commit comments

Comments (0)