Skip to content

Commit 937310a

Browse files
felipemello1 and Felipe Mello authored
surface errors (#738)
Co-authored-by: Felipe Mello <felipemello@fb.com>
1 parent 898be00 commit 937310a

File tree

1 file changed

+27
-4
lines changed

1 file changed

+27
-4
lines changed

apps/grpo/main.py

Lines changed: 27 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
88

99
import asyncio
10+
import traceback
1011
import uuid
1112

1213
import torch
@@ -311,7 +312,9 @@ async def continuous_training():
311312
training_step = 0
312313
restart_tracer = True # Flag to control when to restart tracer
313314

314-
while max_steps == -1 or training_step < max_steps:
315+
while (
316+
max_steps == -1 or training_step < max_steps
317+
) and not shutdown_event.is_set():
315318
# Restart tracer when needed (initial start or after completing a training step)
316319
# Otherwise, we cannot measure time waiting for buffer
317320
if restart_tracer:
@@ -347,9 +350,12 @@ async def continuous_training():
347350
# Flush metrics every training step to WandB
348351
await mlogger.flush.call_one(training_step)
349352

350-
print(
351-
f"Reached training limit ({max_steps} steps). Exiting continuous_training loop."
352-
)
353+
if shutdown_event.is_set():
354+
print("Training stopped due to shutdown event (likely a task failure).")
355+
else:
356+
print(
357+
f"Reached training limit ({max_steps} steps). Exiting continuous_training loop."
358+
)
353359

354360
num_rollout_threads = cfg.get("rollout_threads", 1)
355361
num_training_threads = cfg.get("training_threads", 1)
@@ -361,10 +367,27 @@ async def continuous_training():
361367
]
362368
training_task = asyncio.create_task(continuous_training())
363369

370+
# Surface background task failures and trigger shutdown (fail-fast)
371+
def on_task_done(task: asyncio.Task, name: str):
372+
if task.cancelled():
373+
return
374+
exc = task.exception()
375+
if exc:
376+
print(f"ERROR: {name} failed: {type(exc).__name__}: {exc}")
377+
traceback.print_exception(type(exc), exc, exc.__traceback__)
378+
shutdown_event.set()
379+
380+
for i, task in enumerate(rollout_tasks):
381+
task.add_done_callback(lambda t, i=i: on_task_done(t, f"rollout_task_{i}"))
382+
training_task.add_done_callback(lambda t: on_task_done(t, "training_task"))
383+
364384
try:
365385
await training_task
366386
except KeyboardInterrupt:
367387
print("Training interrupted by user")
388+
except Exception as e:
389+
print(f"ERROR: Training task failed: {type(e).__name__}: {e}")
390+
traceback.print_exc()
368391
finally:
369392
print("Shutting down... (this may take a few seconds)")
370393
shutdown_event.set()

0 commit comments

Comments
 (0)