Skip to content

Commit 2704679

Browse files
committed
support custom conditions
1 parent af8331c commit 2704679

4 files changed

Lines changed: 136 additions & 6 deletions

File tree

docs/source/alerts.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,34 @@ for alert in alerts:
152152
| `AlertReason.MIN_EXCEEDED` | `"min_exceeded"` | Metric dropped below `min_value` |
153153
| `AlertReason.SPIKE` | `"spike"` | Spike detected vs. recent average |
154154
| `AlertReason.STAGNATION` | `"stagnation"` | No improvement for `patience` steps |
155+
| `AlertReason.CUSTOM` | `"custom"` | Custom condition returned `True` |
156+
157+
### Custom Conditions
158+
159+
Pass `fn` to [`watch`] to define your own condition. The function receives `(value, step)` and should return `True` to fire a default WARN alert, a list of alert dicts for full control, or a falsy value for no alert:
160+
161+
```python
162+
def check_divergence(value, step):
163+
if value > 50.0:
164+
return [
165+
{
166+
"title": "Loss diverged",
167+
"level": trackio.AlertLevel.ERROR,
168+
"text": f"val_loss={value:.2f} at step {step}",
169+
"data": {"reason": "diverged", "threshold": 50.0, "value": value},
170+
"stop": True,
171+
}
172+
]
173+
return None
174+
175+
trackio.watch("val/loss", fn=check_divergence)
176+
```
177+
178+
Include `"stop": True` in a returned dict to set `should_stop()` to `True`. Custom conditions can be combined with built-in ones — both run independently on every `log()` call:
179+
180+
```python
181+
trackio.watch("train/loss", nan=True, fn=check_divergence)
182+
```
155183

156184
---
157185

tests/e2e-local/test_watchers.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
from trackio.watchers import AlertReason, MetricWatcher, WatcherManager
1+
from trackio.watchers import (
2+
AlertLevel,
3+
AlertReason,
4+
CustomMetricWatcher,
5+
MetricWatcher,
6+
WatcherManager,
7+
)
28

39

410
def test_nan_inf_triggers_stop():
@@ -108,3 +114,50 @@ def test_manager_should_stop_and_clear():
108114
def test_non_numeric_ignored():
109115
w = MetricWatcher("loss", max_value=10.0)
110116
assert len(w.check("not a number", step=0)) == 0
117+
118+
119+
def test_custom_watcher_bool_return():
120+
w = CustomMetricWatcher("loss", fn=lambda v, s: v > 10.0)
121+
assert len(w.check(5.0, step=0)) == 0
122+
alerts = w.check(15.0, step=1)
123+
assert len(alerts) == 1
124+
assert alerts[0]["level"] == AlertLevel.WARN
125+
assert alerts[0]["data"]["reason"] == AlertReason.CUSTOM
126+
assert not w.should_stop
127+
128+
129+
def test_custom_watcher_dict_return_with_stop():
130+
def fn(value, step):
131+
if value > 10.0:
132+
return [
133+
{
134+
"title": "too high",
135+
"level": AlertLevel.ERROR,
136+
"stop": True,
137+
"data": {},
138+
}
139+
]
140+
return None
141+
142+
w = CustomMetricWatcher("loss", fn=fn)
143+
assert len(w.check(5.0, step=0)) == 0
144+
assert not w.should_stop
145+
alerts = w.check(15.0, step=1)
146+
assert len(alerts) == 1
147+
assert alerts[0]["title"] == "too high"
148+
assert w.should_stop
149+
150+
151+
def test_custom_watcher_none_return():
152+
w = CustomMetricWatcher("loss", fn=lambda v, s: None)
153+
assert len(w.check(99.0, step=0)) == 0
154+
155+
156+
def test_manager_mixes_builtin_and_custom():
157+
mgr = WatcherManager()
158+
mgr.add(MetricWatcher("loss", max_value=20.0))
159+
mgr.add(CustomMetricWatcher("loss", fn=lambda v, s: v > 10.0))
160+
alerts = mgr.check({"loss": 15.0}, step=0)
161+
assert len(alerts) == 1
162+
alerts2 = mgr.check({"loss": 25.0}, step=1)
163+
assert len(alerts2) == 2

trackio/__init__.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import warnings
88
import webbrowser
99
from pathlib import Path
10-
from typing import Any
10+
from typing import Any, Callable
1111

1212
import huggingface_hub
1313
from gradio_client import handle_file
@@ -41,7 +41,12 @@
4141
from trackio.trace import Trace
4242
from trackio.typehints import UploadEntry
4343
from trackio.utils import TRACKIO_DIR, TRACKIO_LOGO_DIR, _emit_nonfatal_warning
44-
from trackio.watchers import AlertReason, MetricWatcher, WatcherManager
44+
from trackio.watchers import (
45+
AlertReason,
46+
CustomMetricWatcher,
47+
MetricWatcher,
48+
WatcherManager,
49+
)
4550

4651
logging.getLogger("httpx").setLevel(logging.WARNING)
4752

@@ -843,6 +848,7 @@ def watch(
843848
min_value: float | None = None,
844849
window: int = 5,
845850
mode: str = "min",
851+
fn: Callable | None = None,
846852
) -> None:
847853
"""
848854
Register a metric watcher that automatically fires alerts when conditions
@@ -872,6 +878,12 @@ def watch(
872878
mode (`str`, *optional*, defaults to ``"min"``):
873879
Whether lower (``"min"``) or higher (``"max"``) values are better.
874880
Affects patience-based stagnation detection.
881+
fn (`Callable[[float, int | None], bool | list[dict] | None]`, *optional*):
882+
A custom condition called as ``fn(value, step)`` on every
883+
``trackio.log()`` call. Return ``True`` to fire a default WARN
884+
alert, a list of alert dicts for full control, or a falsy value
885+
for no alert. Include ``"stop": True`` in a returned dict to
886+
also set ``should_stop()`` to ``True``.
875887
"""
876888
watcher = MetricWatcher(
877889
metric_name=metric,
@@ -885,6 +897,8 @@ def watch(
885897
mode=mode,
886898
)
887899
_watcher_manager.add(watcher)
900+
if fn is not None:
901+
_watcher_manager.add(CustomMetricWatcher(metric, fn))
888902

889903

890904
def should_stop() -> bool:

trackio/watchers.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import math
2525
from collections import deque
26-
from typing import Literal
26+
from typing import Callable, Literal
2727

2828
from trackio.alerts import AlertLevel
2929

@@ -43,6 +43,7 @@ class AlertReason:
4343
MIN_EXCEEDED = "min_exceeded"
4444
SPIKE = "spike"
4545
STAGNATION = "stagnation"
46+
CUSTOM = "custom"
4647

4748

4849
class MetricWatcher:
@@ -212,11 +213,45 @@ def should_stop(self) -> bool:
212213
return self._triggered_stop
213214

214215

216+
class CustomMetricWatcher:
217+
def __init__(self, metric_name: str, fn: Callable):
218+
self.metric_name = metric_name
219+
self._fn = fn
220+
self._triggered_stop = False
221+
222+
def check(self, value, step: int | None = None) -> list[dict]:
223+
result = self._fn(value, step)
224+
if not result:
225+
return []
226+
if result is True:
227+
return [
228+
{
229+
"title": f"Custom condition triggered for {self.metric_name}",
230+
"level": AlertLevel.WARN,
231+
"data": {
232+
"metric": self.metric_name,
233+
"value": value,
234+
"step": step,
235+
"reason": AlertReason.CUSTOM,
236+
},
237+
}
238+
]
239+
alerts = list(result)
240+
for a in alerts:
241+
if a.get("stop"):
242+
self._triggered_stop = True
243+
return alerts
244+
245+
@property
246+
def should_stop(self) -> bool:
247+
return self._triggered_stop
248+
249+
215250
class WatcherManager:
216251
def __init__(self):
217-
self._watchers: list[MetricWatcher] = []
252+
self._watchers: list[MetricWatcher | CustomMetricWatcher] = []
218253

219-
def add(self, watcher: MetricWatcher):
254+
def add(self, watcher: MetricWatcher | CustomMetricWatcher):
220255
self._watchers.append(watcher)
221256

222257
def check(self, metrics: dict, step: int | None = None) -> list[dict]:

0 commit comments

Comments
 (0)