Skip to content

Commit e58200b

Browse files
committed
fix test harnesses
1 parent 75f93f6 commit e58200b

2 files changed

Lines changed: 25 additions & 83 deletions

File tree

autonomous-experiments/test_harness/agent_runner.py

Lines changed: 17 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -13,15 +13,14 @@
1313
import subprocess
1414
import sys
1515
import time
16-
from datetime import datetime, timezone
1716
from pathlib import Path
1817

1918
SIMULATOR = str(Path(__file__).parent / "simulate_training.py")
2019

2120

2221
def run_cli(args_list):
2322
result = subprocess.run(
24-
["trackio"] + args_list + ["--json"],
23+
[sys.executable, "-m", "trackio.cli", *args_list, "--json"],
2524
capture_output=True,
2625
text=True,
2726
)
@@ -57,12 +56,10 @@ def run_training(project, run_name, **kwargs):
5756
return result.returncode
5857

5958

60-
def get_alerts(project, run_name=None, since=None):
59+
def get_alerts(project, run_name=None):
6160
args = ["list", "alerts", "--project", project]
6261
if run_name:
6362
args.extend(["--run", run_name])
64-
if since:
65-
args.extend(["--since", since])
6663
result = run_cli(args)
6764
if result and "alerts" in result:
6865
return result["alerts"]
@@ -91,9 +88,10 @@ def experiment_failure_recovery(project):
9188
error_alerts[0]["title"] if error_alerts else "non-zero exit code"
9289
)
9390
print(f" [AGENT] Attempt {attempt} failed: {error_msg}")
91+
prev_lr = lr
9492
lr *= 0.1
9593
print(f" [AGENT] Reducing LR to {lr}")
96-
attempts.append({"run": run_name, "status": "failed", "lr": lr * 10})
94+
attempts.append({"run": run_name, "status": "failed", "lr": prev_lr})
9795
else:
9896
result = run_cli(
9997
[
@@ -127,11 +125,10 @@ def experiment_failure_recovery(project):
127125
def experiment_long_monitoring(project):
128126
print("\n" + "=" * 60)
129127
print("EXPERIMENT: Long-Running Monitoring")
130-
print("Goal: Test alert polling with --since during active training")
128+
print("Goal: Test alert polling with alert_id dedup during active training")
131129
print("=" * 60)
132130

133131
run_name = "long-run"
134-
since = datetime.now(timezone.utc).isoformat()
135132

136133
cmd = [
137134
sys.executable,
@@ -154,23 +151,21 @@ def experiment_long_monitoring(project):
154151

155152
print(" [AGENT] Starting long training run in background...")
156153
proc = subprocess.Popen(
157-
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
154+
cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True
158155
)
159156

160-
all_alerts = []
157+
seen_ids: set[str] = set()
161158

162159
while proc.poll() is None:
163160
time.sleep(0.5)
164-
alerts = get_alerts(project, run_name, since=since)
165-
166-
new_alerts = [a for a in alerts if a not in all_alerts]
167-
if new_alerts:
168-
for alert in new_alerts:
169-
print(
170-
f" [AGENT] New alert: [{alert.get('level', '?')}] {alert.get('title', '?')}"
171-
)
172-
all_alerts.append(alert)
173-
since = datetime.now(timezone.utc).isoformat()
161+
alerts = get_alerts(project, run_name)
162+
new_alerts = [a for a in alerts if a.get("alert_id") not in seen_ids]
163+
for alert in new_alerts:
164+
print(
165+
f" [AGENT] New alert: [{alert.get('level', '?')}] {alert.get('title', '?')}"
166+
)
167+
if alert.get("alert_id") is not None:
168+
seen_ids.add(alert["alert_id"])
174169

175170
stdout, _ = proc.communicate()
176171
print(f" [AGENT] Training finished. Exit code: {proc.returncode}")
@@ -184,15 +179,14 @@ def experiment_long_monitoring(project):
184179
EXPERIMENTS = {
185180
"failure_recovery": experiment_failure_recovery,
186181
"long_monitoring": experiment_long_monitoring,
187-
"all": None,
188182
}
189183

190184

191185
def main():
192186
parser = argparse.ArgumentParser(description="Agent test runner for autonomous ML")
193187
parser.add_argument(
194188
"--experiment",
195-
choices=list(EXPERIMENTS.keys()),
189+
choices=[*EXPERIMENTS.keys(), "all"],
196190
default="all",
197191
help="Which experiment to run",
198192
)
@@ -203,10 +197,7 @@ def main():
203197
)
204198
args = parser.parse_args()
205199

206-
if args.experiment == "all":
207-
experiments = [k for k in EXPERIMENTS if k != "all"]
208-
else:
209-
experiments = [args.experiment]
200+
experiments = list(EXPERIMENTS) if args.experiment == "all" else [args.experiment]
210201

211202
results = {}
212203

autonomous-experiments/test_harness/simulate_training.py

Lines changed: 8 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
import argparse
1010
import math
1111
import random
12-
import sys
1312
import time
1413

1514
import trackio
@@ -90,70 +89,19 @@ def main():
9089

9190
trackio.init(project=args.project, name=args.run_name, config=config)
9291

93-
best_val_loss = float("inf")
94-
stagnation_count = 0
92+
trackio.watch("train/loss", nan=True, max_value=10.0, spike_factor=3.0, window=10)
9593

9694
for step in range(args.steps):
9795
train_loss = simulate_loss(
9896
step, args.steps, args.lr, args.depth, args.batch_size
9997
)
10098

101-
if args.spike_at_step and step == args.spike_at_step:
99+
if args.spike_at_step is not None and step == args.spike_at_step:
102100
train_loss *= 10.0
103-
trackio.alert(
104-
"Loss spike detected",
105-
text=f"Loss spiked to {train_loss:.4f} at step {step}",
106-
level=AlertLevel.WARN,
107-
)
108-
109-
if math.isnan(train_loss) or math.isinf(train_loss):
110-
trackio.alert(
111-
"NaN/Inf loss detected",
112-
text=f"Loss became {train_loss} at step {step}. Training is diverging.",
113-
level=AlertLevel.ERROR,
114-
)
115-
trackio.log({"train/loss": train_loss, "val/loss": train_loss}, step=step)
116-
trackio.finish()
117-
print(f"TERMINATED EARLY: NaN/Inf loss at step {step}")
118-
sys.exit(1)
119101

120102
val_loss = simulate_val_loss(train_loss, step, args.steps, args.depth)
121103
accuracy = simulate_accuracy(val_loss)
122104

123-
if val_loss < best_val_loss:
124-
best_val_loss = val_loss
125-
stagnation_count = 0
126-
else:
127-
stagnation_count += 1
128-
129-
if train_loss > 10.0 and step > 50:
130-
trackio.alert(
131-
"Training diverging",
132-
text=f"Loss {train_loss:.4f} is very high at step {step}. Learning rate may be too high.",
133-
level=AlertLevel.ERROR,
134-
)
135-
trackio.log(
136-
{
137-
"train/loss": round(train_loss, 4),
138-
"val/loss": round(val_loss, 4),
139-
"accuracy": round(accuracy, 4),
140-
"best_val_loss": round(best_val_loss, 4),
141-
"lr": args.lr,
142-
},
143-
step=step,
144-
)
145-
trackio.finish()
146-
print(f"TERMINATED EARLY: diverging at step {step}")
147-
sys.exit(1)
148-
149-
if stagnation_count >= 100 and step > 100:
150-
trackio.alert(
151-
"Training stagnated",
152-
text=f"Val loss has not improved for {stagnation_count} steps. Best: {best_val_loss:.4f}",
153-
level=AlertLevel.WARN,
154-
)
155-
stagnation_count = 0
156-
157105
if val_loss > train_loss * 1.5 and step > args.steps * 0.5:
158106
trackio.alert(
159107
"Overfitting detected",
@@ -166,22 +114,25 @@ def main():
166114
"train/loss": round(train_loss, 4),
167115
"val/loss": round(val_loss, 4),
168116
"accuracy": round(accuracy, 4),
169-
"best_val_loss": round(best_val_loss, 4),
170117
"lr": args.lr,
171118
},
172119
step=step,
173120
)
174121

122+
if trackio.should_stop():
123+
print(f"TERMINATED EARLY: watcher triggered stop at step {step}")
124+
break
125+
175126
if args.sleep > 0:
176127
time.sleep(args.sleep)
177128

178129
trackio.alert(
179130
"Training complete",
180-
text=f"Finished {args.steps} steps. Best val loss: {best_val_loss:.4f}",
131+
text=f"Finished at step {step}.",
181132
level=AlertLevel.INFO,
182133
)
183134
trackio.finish()
184-
print(f"Training complete. Best val loss: {best_val_loss:.4f}")
135+
print(f"Training complete. Final step: {step}")
185136

186137

187138
if __name__ == "__main__":

0 commit comments

Comments
 (0)