77# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
88
99import asyncio
10+ import traceback
1011import uuid
1112
1213import torch
@@ -311,7 +312,9 @@ async def continuous_training():
311312 training_step = 0
312313 restart_tracer = True # Flag to control when to restart tracer
313314
314- while max_steps == - 1 or training_step < max_steps :
315+ while (
316+ max_steps == - 1 or training_step < max_steps
317+ ) and not shutdown_event .is_set ():
315318 # Restart tracer when needed (initial start or after completing a training step)
316319 # Otherwise, we cannot measure time waiting for buffer
317320 if restart_tracer :
@@ -347,9 +350,12 @@ async def continuous_training():
347350 # Flush metrics every training step to WandB
348351 await mlogger .flush .call_one (training_step )
349352
350- print (
351- f"Reached training limit ({ max_steps } steps). Exiting continuous_training loop."
352- )
353+ if shutdown_event .is_set ():
354+ print ("Training stopped due to shutdown event (likely a task failure)." )
355+ else :
356+ print (
357+ f"Reached training limit ({ max_steps } steps). Exiting continuous_training loop."
358+ )
353359
354360 num_rollout_threads = cfg .get ("rollout_threads" , 1 )
355361 num_training_threads = cfg .get ("training_threads" , 1 )
@@ -361,10 +367,27 @@ async def continuous_training():
361367 ]
362368 training_task = asyncio .create_task (continuous_training ())
363369
370+ # Surface background task failures and trigger shutdown (fail-fast)
371+ def on_task_done (task : asyncio .Task , name : str ):
372+ if task .cancelled ():
373+ return
374+ exc = task .exception ()
375+ if exc :
376+ print (f"ERROR: { name } failed: { type (exc ).__name__ } : { exc } " )
377+ traceback .print_exception (type (exc ), exc , exc .__traceback__ )
378+ shutdown_event .set ()
379+
380+ for i , task in enumerate (rollout_tasks ):
381+ task .add_done_callback (lambda t , i = i : on_task_done (t , f"rollout_task_{ i } " ))
382+ training_task .add_done_callback (lambda t : on_task_done (t , "training_task" ))
383+
364384 try :
365385 await training_task
366386 except KeyboardInterrupt :
367387 print ("Training interrupted by user" )
388+ except Exception as e :
389+ print (f"ERROR: Training task failed: { type (e ).__name__ } : { e } " )
390+ traceback .print_exc ()
368391 finally :
369392 print ("Shutting down... (this may take a few seconds)" )
370393 shutdown_event .set ()
0 commit comments