2626from trackio .sqlite_storage import SQLiteStorage
2727from trackio .table import Table
2828from trackio .trace import Trace
29- from trackio .typehints import AlertEntry , LogEntry , SystemLogEntry , UploadEntry
29+ from trackio .typehints import (
30+ AlertEntry ,
31+ LogEntry ,
32+ RunStatusEntry ,
33+ SystemLogEntry ,
34+ UploadEntry ,
35+ )
3036from trackio .utils import MEDIA_DIR , _emit_nonfatal_warning , _get_default_namespace
3137
3238BATCH_SEND_INTERVAL = 0.5
@@ -139,6 +145,7 @@ def __init__(
139145 self ._queued_system_logs : list [SystemLogEntry ] = []
140146 self ._queued_uploads : list [UploadEntry ] = []
141147 self ._queued_alerts : list [AlertEntry ] = []
148+ self ._queued_run_status : list [RunStatusEntry ] = []
142149 self ._stop_flag = threading .Event ()
143150 self ._config_logged = False
144151 max_step = self ._safe_get_max_step_for_run ()
@@ -166,7 +173,20 @@ def __init__(
166173 description = "remote Trackio logging thread" ,
167174 )
168175
169- SQLiteStorage .set_run_status (self .project , self .name , "running" , run_id = self .id )
176+ if self ._is_local :
177+ SQLiteStorage .set_run_status (
178+ self .project , self .name , "running" , run_id = self .id
179+ )
180+ else :
181+ with self ._client_lock :
182+ self ._queued_run_status .append (
183+ {
184+ "project" : self .project ,
185+ "run" : self .name ,
186+ "run_id" : self .id ,
187+ "status" : "running" ,
188+ }
189+ )
170190 self ._finished = False
171191
172192 self ._gpu_monitor : "GpuMonitor | AppleGpuMonitor | None" = None
@@ -435,6 +455,7 @@ def _batch_sender(self):
435455 or len (self ._queued_system_logs ) > 0
436456 or len (self ._queued_uploads ) > 0
437457 or len (self ._queued_alerts ) > 0
458+ or len (self ._queued_run_status ) > 0
438459 or self ._has_local_buffer
439460 ):
440461 if not self ._stop_flag .is_set ():
@@ -522,6 +543,23 @@ def _batch_sender(self):
522543 self ._write_alerts_to_sqlite (alerts_to_send )
523544 failed = True
524545
546+ if self ._queued_run_status :
547+ status_to_send = self ._queued_run_status .copy ()
548+ self ._queued_run_status .clear ()
549+ try :
550+ for entry in status_to_send :
551+ self ._client .predict (
552+ api_name = "/set_run_status" ,
553+ project = entry ["project" ],
554+ run = entry ["run" ],
555+ status = entry ["status" ],
556+ run_id = entry .get ("run_id" ),
557+ hf_token = self ._hf_token_for_remote (),
558+ )
559+ except Exception :
560+ self ._queued_run_status [0 :0 ] = status_to_send
561+ failed = True
562+
525563 if failed :
526564 consecutive_failures += 1
527565 else :
@@ -965,11 +1003,24 @@ def log_system(self, metrics: dict):
9651003 except Exception as e :
9661004 _emit_nonfatal_warning (f"trackio.log_system() failed: { e } " )
9671005
968- def finish (self , status : str = "finished" ):
1006+ def finish (self , status : str = "finished" , _atexit : bool = False ):
9691007 if self ._finished :
9701008 return
9711009 self ._finished = True
9721010
1011+ join_timeout = 2 if _atexit else 30
1012+
1013+ if not self ._is_local :
1014+ with self ._client_lock :
1015+ self ._queued_run_status .append (
1016+ {
1017+ "project" : self .project ,
1018+ "run" : self .name ,
1019+ "run_id" : self .id ,
1020+ "status" : status ,
1021+ }
1022+ )
1023+
9731024 try :
9741025 if self ._gpu_monitor is not None :
9751026 try :
@@ -984,24 +1035,28 @@ def finish(self, status: str = "finished"):
9841035
9851036 if self ._is_local :
9861037 if self ._local_sender_thread is not None :
987- print ("* Run finished. Uploading logs to Trackio (please wait...)" )
988- self ._local_sender_thread .join (timeout = 30 )
1038+ if not _atexit :
1039+ print (
1040+ "* Run finished. Uploading logs to Trackio (please wait...)"
1041+ )
1042+ self ._local_sender_thread .join (timeout = join_timeout )
9891043 if self ._local_sender_thread .is_alive ():
9901044 _emit_nonfatal_warning (
991- "Could not flush all logs within 30s . Some data may be buffered locally."
1045+ f "Could not flush all logs within { join_timeout } s . Some data may be buffered locally."
9921046 )
9931047 else :
9941048 with self ._client_lock :
9951049 self ._flush_queues_inline ()
9961050 else :
9971051 if self ._client_thread is not None :
998- print (
999- "* Run finished. Uploading logs to the remote Trackio server (please wait...)"
1000- )
1001- self ._client_thread .join (timeout = 30 )
1052+ if not _atexit :
1053+ print (
1054+ "* Run finished. Uploading logs to the remote Trackio server (please wait...)"
1055+ )
1056+ self ._client_thread .join (timeout = join_timeout )
10021057 if self ._client_thread .is_alive ():
10031058 _emit_nonfatal_warning (
1004- "Could not flush all logs within 30s . Some data may be buffered locally."
1059+ f "Could not flush all logs within { join_timeout } s . Some data may be buffered locally."
10051060 )
10061061 else :
10071062 with self ._client_lock :
@@ -1029,4 +1084,13 @@ def finish(self, status: str = "finished"):
10291084 except Exception as e :
10301085 _emit_nonfatal_warning (f"trackio.finish() failed: { e } " )
10311086
1032- SQLiteStorage .set_run_status (self .project , self .name , status , run_id = self .id )
1087+ if self ._is_local :
1088+ SQLiteStorage .set_run_status (
1089+ self .project , self .name , status , run_id = self .id
1090+ )
1091+ elif self ._queued_run_status :
1092+ self ._warn_once (
1093+ "finish-remote-status" ,
1094+ f"trackio.finish() could not record run status '{ status } ' on the remote server. "
1095+ "The dashboard may not reflect the final state of this run." ,
1096+ )
0 commit comments