@@ -232,37 +232,19 @@ async def continuous_rollouts():
232232 "episode/min_response_tokens" , response_tokens , Reduce .MIN
233233 )
234234
235- # drop episodes if
236- # 1> reward std-dev is very small (including all 0s and all 1s)
237- # 2> any response was truncated (didn't end with EOS)
238- # TODO: change it to filter only truncated episodes instead of dropping entire group
235+ # Drop entire group if reward std-dev is very small (no learning signal)
239236 rewards = [e .reward for e in episodes ]
240237 rewards_std = torch .std (torch .tensor (rewards ))
241238 is_low_variance = rewards_std < 1e-3
242- num_truncated = sum (
243- 1 for e in episodes if e .completion .stop_reason == "length"
244- )
245- is_truncated = num_truncated > 0
246- drop = is_low_variance or is_truncated
247239
248240 n = len (episodes )
249241 record_metric (
250242 "main/continuous_rollouts/episodes_dropped/low_variance" ,
251243 n if is_low_variance else 0 ,
252244 Reduce .SUM ,
253245 )
254- record_metric (
255- "main/continuous_rollouts/episodes_dropped/truncated" ,
256- num_truncated ,
257- Reduce .SUM ,
258- )
259- record_metric (
260- "main/continuous_rollouts/episodes_dropped/total" ,
261- n if drop else 0 ,
262- Reduce .SUM ,
263- )
246+ if is_low_variance :
264247
265- if drop :
266248 del input_ids , episodes
267249 continue
268250
@@ -283,8 +265,16 @@ async def continuous_rollouts():
283265 del input_ids
284266
285267 advantages = await compute_advantages .compute .call_one (episodes )
268+ num_truncated = 0
286269 for episode , advantage in zip (episodes , advantages ):
287270 episode .advantage = advantage
271+
272+ # Zero out loss_mask for truncated episodes so they don't contribute to gradient
273+ # TODO: evaluate if we should drop truncated episodes instead
274+ if episode .completion .stop_reason == "length" :
275+ episode .loss_mask = torch .zeros_like (episode .loss_mask )
276+ num_truncated += 1
277+
288278 await replay_buffer .add .call_one (episode )
289279
290280 sample = episode .to_dict (
@@ -301,6 +291,11 @@ async def continuous_rollouts():
301291 sample ,
302292 Reduce .SAMPLE ,
303293 )
294+ record_metric (
295+ "main/continuous_rollouts/episodes_dropped/truncated" ,
296+ num_truncated ,
297+ Reduce .SUM ,
298+ )
304299
305300 rollout_count += 1
306301 record_metric (
0 commit comments