Skip to content

Commit 73f17d0

Browse files
committed
misc fixes
1 parent 81e2561 commit 73f17d0

5 files changed

Lines changed: 76 additions & 30 deletions

File tree

.changeset/every-spoons-smash.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,11 @@
22
"trackio": minor
33
---
44

5-
feat:Add additional support for autonomous ML experiments
5+
feat: Add additional support for autonomous ML experiments
6+
7+
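- `trackio.watch()` / `trackio.should_stop()`: `watch()` registers metric watchers (NaN/Inf, threshold, spike, stagnation, custom fn) that are evaluated on every `trackio.log()` call and fire alerts automatically; `should_stop()` reports whether any watcher has triggered a stop condition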
8+
- `AlertReason` constants for programmatic alert filtering
9+
- Run lifecycle status tracking (`running` → `finished` / `failed`) persisted in SQLite
10+
- New CLI commands: `trackio best`, `trackio compare`, `trackio summary`
11+
- `Run.status`, `Run.final_metrics`, `Run.metrics()`, `Run.history()` on the Python API
12+
- `alerts.data` column (SQL migration) for structured alert metadata

docs/source/alerts.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ Watcher-generated alerts are stored, displayed in the dashboard, and delivered t
7373
|---|---|---|---|
7474
| `metric` | `str` | *(required)* | The metric name to watch (e.g., `"train/loss"`). |
7575
| `nan` | `bool` | `True` | Fire an ERROR alert if the value becomes NaN or Inf. |
76-
| `spike_factor` | `float \| None` | `None` | Fire a WARN alert when the value deviates from the recent moving average by this factor (e.g., `3.0` = 3× the average). |
76+
| `spike_factor` | `float \| None` | `None` | Fire a WARN alert when `\|value − recent_avg\| > (spike_factor − 1) × \|recent_avg\|` (e.g., `3.0` triggers when the deviation exceeds 2× `\|avg\|`). Symmetric — drops trigger too. |
7777
| `patience` | `int \| None` | `None` | Fire a WARN alert if no improvement is seen for this many log steps. Also sets `should_stop()` to `True`. |
7878
| `min_delta` | `float` | `0.0` | Minimum change to count as an improvement (used with `patience`). |
7979
| `max_value` | `float \| None` | `None` | Fire an ERROR alert if the value exceeds this threshold. Also sets `should_stop()` to `True`. |
@@ -93,7 +93,7 @@ trackio.watch("train/loss", nan=True)
9393

9494
#### Max / Min Thresholds
9595

96-
`max_value` fires an **ERROR** alert (and stops) when the metric exceeds the threshold. `min_value` fires a **WARN** alert when it falls below. Each alert fires once when the threshold is crossed and resets if the value recovers.
96+
`max_value` fires an **ERROR** alert (and stops) when the metric exceeds the threshold. `min_value` fires a **WARN** alert when it falls below, but — unlike `max_value` — does **not** set `should_stop()`. Each alert fires once when the threshold is crossed and resets if the value recovers.
9797

9898
```python
9999
trackio.watch("train/loss", max_value=20.0)
@@ -102,7 +102,7 @@ trackio.watch("val/accuracy", min_value=0.5)
102102

103103
#### Spike Detection
104104

105-
Fires a **WARN** alert when the value deviates from the recent moving average by more than `(spike_factor - 1) × avg`. The alert resets automatically once the value returns to normal.
105+
Fires a **WARN** alert when the value deviates from the recent moving average by more than `(spike_factor - 1) × |recent_avg|` — that is, when `|value − recent_avg| > (spike_factor − 1) × |recent_avg|`. Detection is symmetric: sudden drops trigger the alert in addition to sudden rises. With `spike_factor=3.0` and a recent average of `1.0`, the alert fires once `|value − 1.0| > 2.0`. The alert resets automatically once the value returns to normal.
106106

107107
```python
108108
trackio.watch("train/loss", spike_factor=3.0, window=10)
@@ -118,7 +118,7 @@ trackio.watch("val/accuracy", patience=50, min_delta=0.001, mode="max")
118118

119119
### Early Stopping
120120

121-
[`should_stop`] returns `True` if any watcher has triggered a stop condition (NaN/Inf, `max_value` exceeded, or `patience` exhausted):
121+
[`should_stop`] returns `True` if any watcher has triggered a stop condition (NaN/Inf, `max_value` exceeded, `patience` exhausted, or a custom watcher returned `{"stop": True}`):
122122

123123
```python
124124
for step in range(1000):

trackio/__init__.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ def init(
684684
globals()["config"] = run.config
685685
if _watcher_manager._watchers:
686686
_emit_nonfatal_warning(
687-
"trackio.init() cleared existing metric watchers. Call trackio.watch() after trackio.init()."
687+
"trackio.init() will clear existing metric watchers. Call trackio.watch() after trackio.init()."
688688
)
689689
_watcher_manager.clear()
690690

@@ -835,6 +835,7 @@ def alert(
835835
run.alert(title=title, text=text, level=level, webhook_url=webhook_url, data=data)
836836

837837

838+
# Not thread-safe: concurrent trackio.log() from multiple threads may race on watcher state.
838839
_watcher_manager = WatcherManager()
839840

840841

@@ -852,17 +853,19 @@ def watch(
852853
) -> None:
853854
"""
854855
Register a metric watcher that automatically fires alerts when conditions
855-
are met during ``trackio.log()`` calls. Must be called after
856-
``trackio.init()`` — watchers are cleared when a new run starts.
856+
are met during ``trackio.log()`` calls. Typically called after
857+
``trackio.init()`` — watchers registered earlier will persist until the
858+
next ``trackio.init()`` clears them.
857859
858860
Args:
859861
metric (`str`):
860862
The metric name to watch (e.g., ``"train/loss"``).
861863
nan (`bool`, *optional*, defaults to `True`):
862864
Fire an ERROR alert if the metric becomes NaN or Inf.
863865
spike_factor (`float`, *optional*):
864-
Fire a WARN alert if the value exceeds the recent moving average
865-
by this factor (e.g., ``3.0`` means 3x the recent average).
866+
Fire a WARN alert if the absolute deviation from the recent moving
867+
average exceeds ``(spike_factor - 1) * |recent_avg|`` (e.g.,
868+
``3.0`` triggers when ``|value - avg| > 2 * |avg|``).
866869
patience (`int`, *optional*):
867870
Fire a WARN alert if no improvement is seen for this many log
868871
steps. Also sets ``should_stop()`` to True.
@@ -877,14 +880,18 @@ def watch(
877880
Number of recent values to use for spike detection averaging.
878881
mode (`str`, *optional*, defaults to ``"min"``):
879882
Whether lower (``"min"``) or higher (``"max"``) values are better.
880-
Affects patience-based stagnation detection.
883+
Must be ``"min"`` or ``"max"``. Affects patience-based stagnation
884+
detection.
881885
fn (`Callable[[float, int | None], bool | list[dict] | None]`, *optional*):
882886
A custom condition called as ``fn(value, step)`` on every
883-
``trackio.log()`` call. Return ``True`` to fire a default WARN
884-
alert, a list of alert dicts for full control, or a falsy value
885-
for no alert. Include ``"stop": True`` in a returned dict to
886-
also set ``should_stop()`` to ``True``.
887+
``trackio.log()`` call (where ``value`` is the most recent metric
888+
value and ``step`` is the log step or ``None``). Return ``True``
889+
to fire a default WARN alert, a list of alert dicts for full
890+
control, or a falsy value for no alert. Include ``"stop": True``
891+
in a returned dict to also set ``should_stop()`` to ``True``.
887892
"""
893+
if mode not in ("min", "max"):
894+
raise ValueError(f"trackio.watch(): mode={mode!r}; expected 'min' or 'max'.")
888895
watcher = MetricWatcher(
889896
metric_name=metric,
890897
nan=nan,

trackio/cli.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,10 @@ def _handle_config(args):
150150

151151

152152
def _extract_reports(
153-
run: str, logs: list[dict], report_name: str | None = None
153+
run: str,
154+
logs: list[dict],
155+
report_name: str | None = None,
156+
run_id: str | None = None,
154157
) -> list[dict]:
155158
reports = []
156159
for log in logs:
@@ -165,6 +168,7 @@ def _extract_reports(
165168
reports.append(
166169
{
167170
"run": run,
171+
"run_id": run_id,
168172
"report": key,
169173
"step": step,
170174
"timestamp": timestamp,
@@ -887,9 +891,16 @@ def main():
887891
if trailing_globals.hf_token is not None:
888892
args.hf_token = trailing_globals.hf_token
889893

890-
if args.command in ("show", "status", "sync", "freeze", "skills") and _get_space(
891-
args
892-
):
894+
if args.command in (
895+
"show",
896+
"status",
897+
"sync",
898+
"freeze",
899+
"skills",
900+
"best",
901+
"compare",
902+
"summary",
903+
) and _get_space(args):
893904
error_exit(
894905
f"The '{args.command}' command does not support --space (remote mode)."
895906
)
@@ -1064,21 +1075,33 @@ def main():
10641075
run_records = remote.predict(
10651076
args.project, api_name="/get_runs_for_project"
10661077
)
1067-
runs = [r["name"] if isinstance(r, dict) else r for r in run_records]
1078+
records = [
1079+
r if isinstance(r, dict) else {"name": r, "id": r}
1080+
for r in run_records
1081+
]
10681082
else:
10691083
_require_project(args.project)
1070-
runs = SQLiteStorage.get_runs(args.project)
1071-
if args.run and args.run not in runs:
1084+
records = SQLiteStorage.get_run_records(args.project)
1085+
1086+
run_names = [r["name"] for r in records]
1087+
if args.run and args.run not in run_names:
10721088
error_exit(f"Run '{args.run}' not found in project '{args.project}'.")
10731089

1074-
target_runs = [args.run] if args.run else runs
1090+
target_records = (
1091+
[r for r in records if r["name"] == args.run] if args.run else records
1092+
)
1093+
target_names = [r["name"] for r in target_records]
1094+
has_dupes = len(target_names) != len(set(target_names))
1095+
10751096
all_reports = []
1076-
for run_name in target_runs:
1097+
for rec in target_records:
1098+
run_name = rec["name"]
1099+
run_id = rec.get("id")
10771100
if remote:
10781101
logs = remote.predict(args.project, run_name, api_name="/get_logs")
10791102
else:
1080-
logs = SQLiteStorage.get_logs(args.project, run_name)
1081-
all_reports.extend(_extract_reports(run_name, logs))
1103+
logs = SQLiteStorage.get_logs(args.project, run_name, run_id=run_id)
1104+
all_reports.extend(_extract_reports(run_name, logs, run_id=run_id))
10821105

10831106
if args.json:
10841107
print(
@@ -1091,10 +1114,14 @@ def main():
10911114
)
10921115
)
10931116
else:
1094-
report_lines = [
1095-
f"{entry['run']} | {entry['report']} | step={entry['step']} | {entry['timestamp']}"
1096-
for entry in all_reports
1097-
]
1117+
report_lines = []
1118+
for entry in all_reports:
1119+
label = entry["run"]
1120+
if has_dupes and entry.get("run_id"):
1121+
label += f" ({entry['run_id'][:8]})"
1122+
report_lines.append(
1123+
f"{label} | {entry['report']} | step={entry['step']} | {entry['timestamp']}"
1124+
)
10981125
if args.run:
10991126
print(
11001127
format_list(

trackio/cli_helpers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ def format_compare(
190190
f"Comparing {len(comparison)} runs across {len(metric_names)} metrics\n"
191191
)
192192

193+
max_run_name_w = 40
194+
for e in comparison:
195+
if len(e["run"]) > max_run_name_w:
196+
e["run"] = e["run"][: max_run_name_w - 1] + "…"
197+
193198
run_w = max((len(e["run"]) for e in comparison), default=3)
194199
run_w = max(run_w, 3)
195200
status_w = 10

0 commit comments

Comments
 (0)