Skip to content

Commit 75f93f6

Browse files
committed
server api
1 parent f6c4755 commit 75f93f6

5 files changed

Lines changed: 107 additions & 15 deletions

File tree

tests/unit/test_run.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,12 @@ def test_run_log_calls_client_for_spaces(temp_dir):
5959
run.log(metrics)
6060

6161
time.sleep(0.6)
62-
_, kwargs = client.predict.call_args
63-
assert kwargs["api_name"] == "/bulk_log"
62+
bulk_log_calls = [
63+
call for call in client.predict.call_args_list
64+
if call.kwargs.get("api_name") == "/bulk_log"
65+
]
66+
assert len(bulk_log_calls) >= 1
67+
kwargs = bulk_log_calls[-1].kwargs
6468
assert len(kwargs["logs"]) == 1
6569
assert kwargs["logs"][0]["project"] == "proj"
6670
assert kwargs["logs"][0]["run"] == "run1"

trackio/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _cleanup_current_run():
105105
if run is not None:
106106
try:
107107
if not run._finished:
108-
run.finish(status="failed")
108+
run.finish(status="failed", _atexit=True)
109109
except Exception:
110110
pass
111111

trackio/run.py

Lines changed: 76 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,13 @@
2626
from trackio.sqlite_storage import SQLiteStorage
2727
from trackio.table import Table
2828
from trackio.trace import Trace
29-
from trackio.typehints import AlertEntry, LogEntry, SystemLogEntry, UploadEntry
29+
from trackio.typehints import (
30+
AlertEntry,
31+
LogEntry,
32+
RunStatusEntry,
33+
SystemLogEntry,
34+
UploadEntry,
35+
)
3036
from trackio.utils import MEDIA_DIR, _emit_nonfatal_warning, _get_default_namespace
3137

3238
BATCH_SEND_INTERVAL = 0.5
@@ -139,6 +145,7 @@ def __init__(
139145
self._queued_system_logs: list[SystemLogEntry] = []
140146
self._queued_uploads: list[UploadEntry] = []
141147
self._queued_alerts: list[AlertEntry] = []
148+
self._queued_run_status: list[RunStatusEntry] = []
142149
self._stop_flag = threading.Event()
143150
self._config_logged = False
144151
max_step = self._safe_get_max_step_for_run()
@@ -166,7 +173,20 @@ def __init__(
166173
description="remote Trackio logging thread",
167174
)
168175

169-
SQLiteStorage.set_run_status(self.project, self.name, "running", run_id=self.id)
176+
if self._is_local:
177+
SQLiteStorage.set_run_status(
178+
self.project, self.name, "running", run_id=self.id
179+
)
180+
else:
181+
with self._client_lock:
182+
self._queued_run_status.append(
183+
{
184+
"project": self.project,
185+
"run": self.name,
186+
"run_id": self.id,
187+
"status": "running",
188+
}
189+
)
170190
self._finished = False
171191

172192
self._gpu_monitor: "GpuMonitor | AppleGpuMonitor | None" = None
@@ -435,6 +455,7 @@ def _batch_sender(self):
435455
or len(self._queued_system_logs) > 0
436456
or len(self._queued_uploads) > 0
437457
or len(self._queued_alerts) > 0
458+
or len(self._queued_run_status) > 0
438459
or self._has_local_buffer
439460
):
440461
if not self._stop_flag.is_set():
@@ -522,6 +543,23 @@ def _batch_sender(self):
522543
self._write_alerts_to_sqlite(alerts_to_send)
523544
failed = True
524545

546+
if self._queued_run_status:
547+
status_to_send = self._queued_run_status.copy()
548+
self._queued_run_status.clear()
549+
try:
550+
for entry in status_to_send:
551+
self._client.predict(
552+
api_name="/set_run_status",
553+
project=entry["project"],
554+
run=entry["run"],
555+
status=entry["status"],
556+
run_id=entry.get("run_id"),
557+
hf_token=self._hf_token_for_remote(),
558+
)
559+
except Exception:
560+
self._queued_run_status[0:0] = status_to_send
561+
failed = True
562+
525563
if failed:
526564
consecutive_failures += 1
527565
else:
@@ -965,11 +1003,24 @@ def log_system(self, metrics: dict):
9651003
except Exception as e:
9661004
_emit_nonfatal_warning(f"trackio.log_system() failed: {e}")
9671005

968-
def finish(self, status: str = "finished"):
1006+
def finish(self, status: str = "finished", _atexit: bool = False):
9691007
if self._finished:
9701008
return
9711009
self._finished = True
9721010

1011+
join_timeout = 2 if _atexit else 30
1012+
1013+
if not self._is_local:
1014+
with self._client_lock:
1015+
self._queued_run_status.append(
1016+
{
1017+
"project": self.project,
1018+
"run": self.name,
1019+
"run_id": self.id,
1020+
"status": status,
1021+
}
1022+
)
1023+
9731024
try:
9741025
if self._gpu_monitor is not None:
9751026
try:
@@ -984,24 +1035,28 @@ def finish(self, status: str = "finished"):
9841035

9851036
if self._is_local:
9861037
if self._local_sender_thread is not None:
987-
print("* Run finished. Uploading logs to Trackio (please wait...)")
988-
self._local_sender_thread.join(timeout=30)
1038+
if not _atexit:
1039+
print(
1040+
"* Run finished. Uploading logs to Trackio (please wait...)"
1041+
)
1042+
self._local_sender_thread.join(timeout=join_timeout)
9891043
if self._local_sender_thread.is_alive():
9901044
_emit_nonfatal_warning(
991-
"Could not flush all logs within 30s. Some data may be buffered locally."
1045+
f"Could not flush all logs within {join_timeout}s. Some data may be buffered locally."
9921046
)
9931047
else:
9941048
with self._client_lock:
9951049
self._flush_queues_inline()
9961050
else:
9971051
if self._client_thread is not None:
998-
print(
999-
"* Run finished. Uploading logs to the remote Trackio server (please wait...)"
1000-
)
1001-
self._client_thread.join(timeout=30)
1052+
if not _atexit:
1053+
print(
1054+
"* Run finished. Uploading logs to the remote Trackio server (please wait...)"
1055+
)
1056+
self._client_thread.join(timeout=join_timeout)
10021057
if self._client_thread.is_alive():
10031058
_emit_nonfatal_warning(
1004-
"Could not flush all logs within 30s. Some data may be buffered locally."
1059+
f"Could not flush all logs within {join_timeout}s. Some data may be buffered locally."
10051060
)
10061061
else:
10071062
with self._client_lock:
@@ -1029,4 +1084,13 @@ def finish(self, status: str = "finished"):
10291084
except Exception as e:
10301085
_emit_nonfatal_warning(f"trackio.finish() failed: {e}")
10311086

1032-
SQLiteStorage.set_run_status(self.project, self.name, status, run_id=self.id)
1087+
if self._is_local:
1088+
SQLiteStorage.set_run_status(
1089+
self.project, self.name, status, run_id=self.id
1090+
)
1091+
elif self._queued_run_status:
1092+
self._warn_once(
1093+
"finish-remote-status",
1094+
f"trackio.finish() could not record run status '{status}' on the remote server. "
1095+
"The dashboard may not reflect the final state of this run.",
1096+
)

trackio/server.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,16 +619,19 @@ def bulk_alert(
619619
"steps": [],
620620
"timestamps": [],
621621
"alert_ids": [],
622+
"data_list": [],
622623
}
623624
alerts_by_run[key]["titles"].append(entry["title"])
624625
alerts_by_run[key]["texts"].append(entry.get("text"))
625626
alerts_by_run[key]["levels"].append(entry["level"])
626627
alerts_by_run[key]["steps"].append(entry.get("step"))
627628
alerts_by_run[key]["timestamps"].append(entry.get("timestamp"))
628629
alerts_by_run[key]["alert_ids"].append(entry.get("alert_id"))
630+
alerts_by_run[key]["data_list"].append(entry.get("data"))
629631

630632
for (project, run, run_id), data in alerts_by_run.items():
631633
has_alert_ids = any(aid is not None for aid in data["alert_ids"])
634+
has_data = any(d is not None for d in data["data_list"])
632635
payload = dict(
633636
project=project,
634637
run=run,
@@ -639,13 +642,26 @@ def bulk_alert(
639642
steps=data["steps"],
640643
timestamps=data["timestamps"],
641644
alert_ids=data["alert_ids"] if has_alert_ids else None,
645+
data_list=data["data_list"] if has_data else None,
642646
)
643647
try:
644648
SQLiteStorage.bulk_alert(**payload)
645649
except sqlite3.OperationalError:
646650
_enqueue_write("bulk_alert", payload)
647651

648652

653+
def set_run_status(
    request: Request,
    project: str,
    run: str,
    status: str,
    run_id: str | None,
    hf_token: str | None,
) -> None:
    """Store the status of a run on behalf of an API caller.

    The caller is checked for write permission first; only then is the
    status handed to SQLite storage.

    Args:
        request: Incoming request, forwarded to the write-permission check.
        project: Project the run belongs to.
        run: Name of the run whose status is being recorded.
        status: Status string to store. It is passed through unchanged;
            no validation happens in this function.
        run_id: Optional run identifier, forwarded to storage as ``run_id``.
        hf_token: Optional token, forwarded to the write-permission check.
    """
    # Permission gate comes before any write.
    # NOTE(review): presumably raises when the caller may not write - confirm
    # against assert_can_write_metrics.
    assert_can_write_metrics(request, hf_token)
    SQLiteStorage.set_run_status(project, run, status, run_id=run_id)
663+
664+
649665
def get_alerts(
650666
project: str,
651667
run: str | None = None,
@@ -949,6 +965,7 @@ def _api_registry() -> dict[str, Any]:
949965
"bulk_log": bulk_log,
950966
"bulk_log_system": bulk_log_system,
951967
"bulk_alert": bulk_alert,
968+
"set_run_status": set_run_status,
952969
"get_alerts": get_alerts,
953970
"get_metric_values": get_metric_values,
954971
"get_runs_for_project": get_runs_for_project,

trackio/typehints.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ class AlertEntry(TypedDict, total=False):
3535
data: dict | None
3636

3737

38+
class RunStatusEntry(TypedDict, total=False):
39+
project: str
40+
run: str
41+
run_id: str | None
42+
status: str
43+
44+
3845
class UploadEntry(TypedDict):
3946
project: str
4047
run: str | None

0 commit comments

Comments
 (0)