Skip to content

Commit db8f78f

Browse files
committed
refactors
1 parent 66a4eff commit db8f78f

9 files changed

Lines changed: 302 additions & 61 deletions

File tree

tests/e2e-local/test_cli_agent_commands.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,54 @@ def seeded_dir():
6868
mm.MEDIA_DIR, mu.MEDIA_DIR, tu.MEDIA_DIR, ss.MEDIA_DIR = orig_media
6969

7070

71+
@pytest.fixture(scope="module")
def seeded_dir_with_unfinished():
    """Seed a temporary trackio dir with two finished runs and one unfinished run."""
    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as workdir:
        import trackio.media.media as mm
        import trackio.media.utils as mu
        import trackio.sqlite_storage as ss
        import trackio.utils as tu

        saved_trackio_dir = ss.TRACKIO_DIR
        saved_media_dirs = [mm.MEDIA_DIR, mu.MEDIA_DIR, tu.MEDIA_DIR, ss.MEDIA_DIR]
        base = Path(workdir)
        ss.TRACKIO_DIR = base
        media_dir = base / "media"
        mm.MEDIA_DIR = media_dir
        mu.MEDIA_DIR = media_dir
        tu.MEDIA_DIR = media_dir
        ss.MEDIA_DIR = media_dir

        # Clear any ambient run/project/server state left over from other tests.
        state_vars = (
            context_vars.current_run,
            context_vars.current_project,
            context_vars.current_server,
            context_vars.current_space_id,
        )
        for var in state_vars:
            var.set(None)

        proj = "filter_test"

        # Two runs that log five steps each and finish cleanly.
        for run_name, lr, start, slope in (
            ("done-run", 0.01, 1.0, -0.1),
            ("also-done", 0.1, 2.0, -0.1),
        ):
            trackio.init(project=proj, name=run_name, config={"lr": lr})
            for step in range(5):
                trackio.log({"val/loss": start + step * slope}, step=step)
            trackio.finish()

        # One run that logs but deliberately never calls finish().
        trackio.init(project=proj, name="still-running", config={"lr": 1.0})
        for step in range(5):
            trackio.log({"val/loss": 5.0 + step * 0.5}, step=step)

        for var in state_vars:
            var.set(None)

        yield (workdir, proj)

        ss.TRACKIO_DIR = saved_trackio_dir
        mm.MEDIA_DIR, mu.MEDIA_DIR, tu.MEDIA_DIR, ss.MEDIA_DIR = saved_media_dirs
117+
118+
71119
def _cli(args, env_dir):
72120
env = os.environ.copy()
73121
env["TRACKIO_DIR"] = env_dir
@@ -108,6 +156,29 @@ def test_best(seeded_dir):
108156
assert json.loads(r2.stdout)["best_run"] == "run-lr0.01"
109157

110158

159+
def test_best_excludes_unfinished_by_default(seeded_dir_with_unfinished):
    """By default, `best` should rank only runs that have finished."""
    env_dir, project = seeded_dir_with_unfinished
    result = _cli(
        ["best", "--project", project, "--metric", "val/loss", "--json"], env_dir
    )
    assert result.returncode == 0
    ranking = json.loads(result.stdout)["ranking"]
    names = [entry["run"] for entry in ranking]
    assert len(names) == 2
    assert "still-running" not in names
167+
168+
169+
def test_best_include_all(seeded_dir_with_unfinished):
    """With --include-all, the unfinished run participates in the ranking."""
    env_dir, project = seeded_dir_with_unfinished
    args = [
        "best",
        "--project",
        project,
        "--metric",
        "val/loss",
        "--include-all",
        "--json",
    ]
    result = _cli(args, env_dir)
    assert result.returncode == 0
    names = [entry["run"] for entry in json.loads(result.stdout)["ranking"]]
    assert len(names) == 3
    assert "still-running" in names
180+
181+
111182
def test_compare(seeded_dir):
112183
r = _cli(
113184
["compare", "--project", PROJECT, "--metrics", "val/loss,accuracy", "--json"],
@@ -136,6 +207,37 @@ def test_compare(seeded_dir):
136207
assert len(json.loads(r2.stdout)["runs"]) == 2
137208

138209

210+
def test_compare_excludes_unfinished_by_default(seeded_dir_with_unfinished):
    """By default, `compare` should only report finished runs."""
    env_dir, project = seeded_dir_with_unfinished
    result = _cli(
        ["compare", "--project", project, "--metrics", "val/loss", "--json"], env_dir
    )
    assert result.returncode == 0
    names = [entry["run"] for entry in json.loads(result.stdout)["runs"]]
    assert len(names) == 2
    assert "still-running" not in names
218+
219+
220+
def test_compare_include_all(seeded_dir_with_unfinished):
    """With --include-all, `compare` also reports the unfinished run."""
    env_dir, project = seeded_dir_with_unfinished
    args = [
        "compare",
        "--project",
        project,
        "--metrics",
        "val/loss",
        "--include-all",
        "--json",
    ]
    result = _cli(args, env_dir)
    assert result.returncode == 0
    names = [entry["run"] for entry in json.loads(result.stdout)["runs"]]
    assert len(names) == 3
    assert "still-running" in names
239+
240+
139241
def test_summary(seeded_dir):
140242
r = _cli(
141243
["summary", "--project", PROJECT, "--metric", "val/loss", "--json"], seeded_dir
@@ -155,6 +257,20 @@ def test_summary(seeded_dir):
155257
} <= run_entry.keys()
156258

157259

260+
def test_list_runs_json_includes_status(seeded_dir):
    """`list runs --json` reports both a name and a status for every run."""
    result = _cli(["list", "runs", "--project", PROJECT, "--json"], seeded_dir)
    assert result.returncode == 0
    payload = json.loads(result.stdout)
    assert "runs" in payload
    for entry in payload["runs"]:
        assert "name" in entry
        assert "status" in entry
    # All three seeded runs called finish(), so each must be "finished".
    by_name = {entry["name"]: entry["status"] for entry in payload["runs"]}
    for run_name in ("run-lr0.01", "run-lr0.1", "run-lr1.0"):
        assert by_name.get(run_name) == "finished"
272+
273+
158274
def test_best_error_cases(seeded_dir):
159275
assert (
160276
_cli(

tests/e2e-local/test_run_status.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,33 @@ def test_api_run_status(temp_dir):
4141
assert run.status == "finished"
4242

4343

44+
def test_api_run_final_metrics(temp_dir):
    """final_metrics returns the last logged value for each numeric metric."""
    trackio.init(project="final_metrics_test", name="run1")
    trackio.log({"loss": 1.0, "acc": 0.5}, step=0)
    trackio.log({"loss": 0.5, "acc": 0.8}, step=1)
    trackio.finish()

    run = trackio.Api().runs("final_metrics_test")[0]
    finals = run.final_metrics
    # Values from step 1 (the last log call) should win.
    for metric, expected in (("loss", 0.5), ("acc", 0.8)):
        assert abs(finals[metric] - expected) < 1e-6
54+
55+
56+
def test_api_run_history_with_metric_filter(temp_dir):
    """history() returns every log row; history(metric=...) narrows to one metric."""
    trackio.init(project="history_test", name="run1")
    for step in range(5):
        trackio.log({"loss": 1.0 - step * 0.1, "acc": step * 0.1}, step=step)
    trackio.finish()

    run = trackio.Api().runs("history_test")[0]
    assert len(run.history()) == 5

    filtered = run.history(metric="loss")
    assert len(filtered) == 5
    assert all("value" in row for row in filtered)
69+
70+
4471
def test_status_survives_multiple_runs(temp_dir):
4572
run1 = trackio.init(project="multi_status", name="run1")
4673
trackio.log({"loss": 0.5}, step=0)

tests/e2e-local/test_watchers.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
from trackio.watchers import MetricWatcher, WatcherManager
1+
from trackio.watchers import AlertReason, MetricWatcher, WatcherManager
22

33

44
def test_nan_inf_triggers_stop():
55
w = MetricWatcher("loss", nan=True)
66
alerts = w.check(float("nan"), step=10)
77
assert len(alerts) == 1
8-
assert alerts[0]["data"]["reason"] == "nan_inf"
8+
assert alerts[0]["data"]["reason"] == AlertReason.NAN_INF
99
assert w.should_stop
1010

1111
w2 = MetricWatcher("loss", nan=False)
@@ -17,7 +17,7 @@ def test_max_value_with_dedup():
1717
assert len(w.check(5.0, step=0)) == 0
1818
alerts = w.check(15.0, step=1)
1919
assert len(alerts) == 1
20-
assert alerts[0]["data"]["reason"] == "max_exceeded"
20+
assert alerts[0]["data"]["reason"] == AlertReason.MAX_EXCEEDED
2121
assert w.should_stop
2222
assert len(w.check(15.0, step=2)) == 0
2323
w.check(5.0, step=3)
@@ -27,7 +27,9 @@ def test_max_value_with_dedup():
2727
def test_min_value_with_dedup():
2828
w = MetricWatcher("acc", min_value=0.5)
2929
assert len(w.check(0.8, step=0)) == 0
30-
assert len(w.check(0.3, step=1)) == 1
30+
alerts = w.check(0.3, step=1)
31+
assert len(alerts) == 1
32+
assert alerts[0]["data"]["reason"] == AlertReason.MIN_EXCEEDED
3133
assert len(w.check(0.3, step=2)) == 0
3234
w.check(0.8, step=3)
3335
assert len(w.check(0.3, step=4)) == 1
@@ -39,13 +41,23 @@ def test_spike_detection_with_dedup_and_reset():
3941
w.check(1.0, step=i)
4042
alerts = w.check(10.0, step=3)
4143
assert len(alerts) == 1
42-
assert alerts[0]["data"]["reason"] == "spike"
44+
assert alerts[0]["data"]["reason"] == AlertReason.SPIKE
4345
assert len(w.check(10.0, step=4)) == 0
4446
for i in range(3):
4547
w.check(1.0, step=5 + i)
4648
assert len(w.check(10.0, step=8)) == 1
4749

4850

51+
def test_spike_detection_works_for_negative_metrics():
    """Spike alerts must fire on a large jump even when the baseline is negative."""
    watcher = MetricWatcher("reward", spike_factor=3.0, window=3)
    for step in range(3):
        watcher.check(-1.0, step=step)
    # A small move near the baseline stays quiet...
    assert len(watcher.check(-1.1, step=3)) == 0
    # ...but a jump across zero raises exactly one spike alert.
    fired = watcher.check(2.0, step=4)
    assert len(fired) == 1
    assert fired[0]["data"]["reason"] == AlertReason.SPIKE
59+
60+
4961
def test_patience_min_mode():
5062
w = MetricWatcher("loss", patience=3, mode="min")
5163
w.check(1.0, step=0)
@@ -54,7 +66,7 @@ def test_patience_min_mode():
5466
w.check(0.95, step=3)
5567
alerts = w.check(0.95, step=4)
5668
assert len(alerts) == 1
57-
assert alerts[0]["data"]["reason"] == "stagnation"
69+
assert alerts[0]["data"]["reason"] == AlertReason.STAGNATION
5870
assert w.should_stop
5971
assert len(w.check(0.95, step=5)) == 0
6072

trackio/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from trackio.table import Table
4040
from trackio.typehints import UploadEntry
4141
from trackio.utils import TRACKIO_DIR, TRACKIO_LOGO_DIR, _emit_nonfatal_warning
42-
from trackio.watchers import MetricWatcher, WatcherManager
42+
from trackio.watchers import AlertReason, MetricWatcher, WatcherManager
4343

4444
logging.getLogger("httpx").setLevel(logging.WARNING)
4545

@@ -63,6 +63,7 @@ def __repr__(self) -> str:
6363
"watch",
6464
"should_stop",
6565
"AlertLevel",
66+
"AlertReason",
6667
"show",
6768
"sync",
6869
"freeze",
@@ -97,8 +98,6 @@ def _cleanup_current_run():
9798
try:
9899
if not run._finished:
99100
run.finish(status="failed")
100-
else:
101-
run.finish()
102101
except Exception:
103102
pass
104103

trackio/api.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,40 +27,40 @@ def status(self) -> str | None:
2727
return SQLiteStorage.get_run_status(self.project, self.name, run_id=self.id)
2828

2929
@property
30-
def summary(self) -> dict:
31-
logs = SQLiteStorage.get_logs(self.project, self.name)
32-
final_values = {}
33-
for log_entry in logs:
34-
for key, value in log_entry.items():
35-
if key not in ("timestamp", "step") and isinstance(value, (int, float)):
36-
final_values[key] = value
37-
return final_values
30+
def final_metrics(self) -> dict:
31+
"""Last recorded value for each numeric metric, keyed by metric name."""
32+
metric_names = SQLiteStorage.get_all_metrics_for_run(self.project, self.name)
33+
result = {}
34+
for m in metric_names:
35+
rows = SQLiteStorage.get_final_metric_for_runs(
36+
self.project, m, mode="last", run_names=[self.name], status_filter=None
37+
)
38+
if rows:
39+
result[m] = rows[0]["value"]
40+
return result
3841

3942
def metrics(self) -> list[str]:
4043
return SQLiteStorage.get_all_metrics_for_run(self.project, self.name)
4144

42-
def history(self, metric: str | None = None) -> list[dict]:
43-
if metric is not None:
44-
return SQLiteStorage.get_metric_values(self.project, self.name, metric)
45-
return SQLiteStorage.get_logs(self.project, self.name)
46-
47-
def get_metric(
45+
def history(
4846
self,
49-
name: str,
47+
metric: str | None = None,
5048
step: int | None = None,
5149
around_step: int | None = None,
5250
at_time: str | None = None,
53-
window: int | float | None = None,
51+
window: int | None = None,
5452
) -> list[dict]:
55-
return SQLiteStorage.get_metric_values(
56-
self.project,
57-
self.name,
58-
name,
59-
step=step,
60-
around_step=around_step,
61-
at_time=at_time,
62-
window=window,
63-
)
53+
if metric is not None:
54+
return SQLiteStorage.get_metric_values(
55+
self.project,
56+
self.name,
57+
metric,
58+
step=step,
59+
around_step=around_step,
60+
at_time=at_time,
61+
window=window,
62+
)
63+
return SQLiteStorage.get_logs(self.project, self.name)
6464

6565
def alerts(self, level: str | None = None, since: str | None = None) -> list[dict]:
6666
return SQLiteStorage.get_alerts(

0 commit comments

Comments
 (0)