Skip to content

Commit 17c3f6f

Browse files
committed
simplify tests
1 parent 8f03ab3 commit 17c3f6f

3 files changed

Lines changed: 137 additions & 269 deletions

File tree

autonomous-experiments/test_harness/agent_runner.py

Lines changed: 14 additions & 251 deletions
Original file line number | Diff line number | Diff line change
@@ -4,9 +4,8 @@
44
Acts as an autonomous agent that:
55
1. Launches simulated training via subprocess
66
2. Polls alerts via trackio CLI
7-
3. Reads results via trackio CLI
8-
4. Decides next hyperparameters based on results
9-
5. Iterates for N rounds
7+
3. Decides next hyperparameters based on results
8+
4. Iterates for N rounds
109
"""
1110

1211
import argparse
@@ -70,202 +69,52 @@ def get_alerts(project, run_name=None, since=None):
7069
return []
7170

7271

73-
def get_run_metric(project, run_name, metric_name):
74-
result = run_cli(
75-
[
76-
"get",
77-
"metric",
78-
"--project",
79-
project,
80-
"--run",
81-
run_name,
82-
"--metric",
83-
metric_name,
84-
]
85-
)
86-
if result and "values" in result:
87-
return result["values"]
88-
return []
89-
90-
91-
def get_runs(project):
92-
result = run_cli(["list", "runs", "--project", project])
93-
if result and "runs" in result:
94-
return result["runs"]
95-
return []
96-
97-
98-
def get_final_metric(project, run_name, metric_name):
99-
values = get_run_metric(project, run_name, metric_name)
100-
if values:
101-
return values[-1]["value"]
102-
return None
103-
104-
105-
def find_best_run(project, metric_name, minimize=True):
106-
runs = get_runs(project)
107-
best_run = None
108-
best_value = float("inf") if minimize else float("-inf")
109-
110-
for run_name in runs:
111-
val = get_final_metric(project, run_name, metric_name)
112-
if val is None:
113-
continue
114-
if minimize and val < best_value:
115-
best_value = val
116-
best_run = run_name
117-
elif not minimize and val > best_value:
118-
best_value = val
119-
best_run = run_name
120-
121-
return best_run, best_value
122-
123-
124-
def experiment_lr_search(project):
125-
print("\n" + "=" * 60)
126-
print("EXPERIMENT 1: Learning Rate Search")
127-
print("Goal: Find the best learning rate from a sequence")
128-
print("=" * 60)
129-
130-
learning_rates = [1.0, 0.5, 0.1, 0.05, 0.01, 0.005, 0.001]
131-
commands_issued = 0
132-
133-
for lr in learning_rates:
134-
run_name = f"lr-{lr}"
135-
run_training(project, run_name, steps=300, lr=lr, seed=42)
136-
137-
alerts = get_alerts(project, run_name)
138-
commands_issued += 1
139-
error_alerts = [a for a in alerts if a.get("level") == "error"]
140-
141-
if error_alerts:
142-
print(f" [AGENT] LR {lr} caused errors: {error_alerts[0]['title']}")
143-
continue
144-
145-
val_loss = get_final_metric(project, run_name, "val/loss")
146-
commands_issued += 1
147-
print(f" [AGENT] LR {lr} -> final val_loss: {val_loss}")
148-
149-
best_run, best_val = find_best_run(project, "val/loss", minimize=True)
150-
commands_issued += len(get_runs(project)) + 1
151-
print(f"\n[AGENT DECISION] Best run: {best_run} with val_loss={best_val:.4f}")
152-
print(f"[METRICS] Total CLI commands to find best: {commands_issued}")
153-
print("[METRICS] (With trackio best, this would be 1 command)")
154-
return {"best_run": best_run, "best_value": best_val, "commands": commands_issued}
155-
156-
157-
def experiment_architecture_search(project):
158-
print("\n" + "=" * 60)
159-
print("EXPERIMENT 2: Architecture Search")
160-
print("Goal: Compare different model depths and find best architecture")
161-
print("=" * 60)
162-
163-
configs = [
164-
{"depth": 2, "lr": 0.01, "batch_size": 32},
165-
{"depth": 4, "lr": 0.01, "batch_size": 32},
166-
{"depth": 6, "lr": 0.01, "batch_size": 32},
167-
{"depth": 8, "lr": 0.01, "batch_size": 32},
168-
{"depth": 12, "lr": 0.01, "batch_size": 32},
169-
{"depth": 4, "lr": 0.01, "batch_size": 64},
170-
{"depth": 6, "lr": 0.005, "batch_size": 64},
171-
]
172-
commands_issued = 0
173-
174-
for cfg in configs:
175-
run_name = f"arch-d{cfg['depth']}-bs{cfg['batch_size']}-lr{cfg['lr']}"
176-
run_training(project, run_name, steps=300, seed=42, **cfg)
177-
178-
runs = get_runs(project)
179-
commands_issued += 1
180-
181-
comparison = []
182-
for run_name in runs:
183-
val_loss = get_final_metric(project, run_name, "val/loss")
184-
accuracy = get_final_metric(project, run_name, "accuracy")
185-
commands_issued += 2
186-
comparison.append(
187-
{
188-
"run": run_name,
189-
"val_loss": val_loss,
190-
"accuracy": accuracy,
191-
}
192-
)
193-
194-
comparison.sort(
195-
key=lambda x: x["val_loss"] if x["val_loss"] is not None else float("inf")
196-
)
197-
198-
print("\n[AGENT] Run comparison (sorted by val_loss):")
199-
for entry in comparison:
200-
print(
201-
f" {entry['run']}: val_loss={entry['val_loss']}, accuracy={entry['accuracy']}"
202-
)
203-
204-
best = comparison[0]
205-
print(f"\n[AGENT DECISION] Best architecture: {best['run']}")
206-
print(f"[METRICS] Total CLI commands for comparison: {commands_issued}")
207-
print("[METRICS] (With trackio compare, this would be 1 command)")
208-
return {"best_run": best["run"], "commands": commands_issued}
209-
210-
21172
def experiment_failure_recovery(project):
21273
print("\n" + "=" * 60)
213-
print("EXPERIMENT 3: Failure Recovery")
74+
print("EXPERIMENT: Failure Recovery")
21475
print("Goal: Detect crashes and restart with adjusted parameters")
21576
print("=" * 60)
21677

21778
attempts = []
21879
lr = 1.0
21980
max_attempts = 5
220-
commands_issued = 0
22181

22282
for attempt in range(max_attempts):
22383
run_name = f"attempt-{attempt}-lr{lr}"
22484
returncode = run_training(project, run_name, steps=500, lr=lr, seed=42)
22585

22686
alerts = get_alerts(project, run_name)
227-
commands_issued += 1
228-
22987
error_alerts = [a for a in alerts if a.get("level") == "error"]
23088

23189
if returncode != 0 or error_alerts:
232-
if error_alerts:
233-
error_msg = error_alerts[0]["title"]
234-
else:
235-
error_msg = "non-zero exit code"
90+
error_msg = error_alerts[0]["title"] if error_alerts else "non-zero exit code"
23691
print(f" [AGENT] Attempt {attempt} failed: {error_msg}")
237-
print(
238-
" [AGENT] NOTE: Cannot determine run status (running vs crashed) from CLI"
239-
)
24092
lr *= 0.1
24193
print(f" [AGENT] Reducing LR to {lr}")
24294
attempts.append({"run": run_name, "status": "failed", "lr": lr * 10})
24395
else:
244-
val_loss = get_final_metric(project, run_name, "val/loss")
245-
commands_issued += 1
246-
print(f" [AGENT] Attempt {attempt} succeeded! val_loss={val_loss}")
247-
attempts.append(
248-
{"run": run_name, "status": "success", "val_loss": val_loss}
96+
result = run_cli(
97+
["get", "metric", "--project", project, "--run", run_name, "--metric", "val/loss"]
24998
)
99+
val_loss = result["values"][-1]["value"] if result and result.get("values") else None
100+
print(f" [AGENT] Attempt {attempt} succeeded! val_loss={val_loss}")
101+
attempts.append({"run": run_name, "status": "success", "val_loss": val_loss})
250102
break
251103

252104
print("\n[AGENT] Recovery history:")
253105
for a in attempts:
254106
print(f" {a}")
255-
print(f"[METRICS] Total CLI commands: {commands_issued}")
256-
print("[METRICS] Gap: No run status tracking - must infer from alerts + exit code")
257-
return {"attempts": len(attempts), "commands": commands_issued}
107+
return {"attempts": len(attempts)}
258108

259109

260110
def experiment_long_monitoring(project):
261111
print("\n" + "=" * 60)
262-
print("EXPERIMENT 4: Long-Running Monitoring")
112+
print("EXPERIMENT: Long-Running Monitoring")
263113
print("Goal: Test alert polling with --since during active training")
264114
print("=" * 60)
265115

266116
run_name = "long-run"
267117
since = datetime.now(timezone.utc).isoformat()
268-
commands_issued = 0
269118

270119
cmd = [
271120
sys.executable,
@@ -287,25 +136,18 @@ def experiment_long_monitoring(project):
287136
]
288137

289138
print(" [AGENT] Starting long training run in background...")
290-
proc = subprocess.Popen(
291-
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
292-
)
139+
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
293140

294-
poll_count = 0
295141
all_alerts = []
296142

297143
while proc.poll() is None:
298144
time.sleep(0.5)
299145
alerts = get_alerts(project, run_name, since=since)
300-
commands_issued += 1
301-
poll_count += 1
302146

303147
new_alerts = [a for a in alerts if a not in all_alerts]
304148
if new_alerts:
305149
for alert in new_alerts:
306-
print(
307-
f" [AGENT] New alert: [{alert.get('level', '?')}] {alert.get('title', '?')}"
308-
)
150+
print(f" [AGENT] New alert: [{alert.get('level', '?')}] {alert.get('title', '?')}")
309151
all_alerts.append(alert)
310152
since = datetime.now(timezone.utc).isoformat()
311153

@@ -314,89 +156,13 @@ def experiment_long_monitoring(project):
314156
print(f" [AGENT] stdout: {stdout.strip()}")
315157

316158
final_alerts = get_alerts(project, run_name)
317-
commands_issued += 1
318-
319159
print(f"\n[AGENT] Total alerts captured: {len(final_alerts)}")
320-
print(f"[METRICS] Poll count: {poll_count}")
321-
print(f"[METRICS] Total CLI commands: {commands_issued}")
322-
return {
323-
"poll_count": poll_count,
324-
"alerts": len(final_alerts),
325-
"commands": commands_issued,
326-
}
327-
328-
329-
def experiment_multi_objective(project):
330-
print("\n" + "=" * 60)
331-
print("EXPERIMENT 5: Multi-Objective Optimization")
332-
print("Goal: Optimize for both val_loss AND accuracy simultaneously")
333-
print("=" * 60)
334-
335-
configs = [
336-
{"lr": 0.001, "depth": 4, "batch_size": 16},
337-
{"lr": 0.005, "depth": 6, "batch_size": 32},
338-
{"lr": 0.01, "depth": 6, "batch_size": 32},
339-
{"lr": 0.01, "depth": 8, "batch_size": 64},
340-
{"lr": 0.05, "depth": 4, "batch_size": 32},
341-
]
342-
commands_issued = 0
343-
344-
for cfg in configs:
345-
run_name = f"multi-d{cfg['depth']}-lr{cfg['lr']}-bs{cfg['batch_size']}"
346-
run_training(project, run_name, steps=300, seed=42, **cfg)
347-
348-
runs = get_runs(project)
349-
commands_issued += 1
350-
351-
results = []
352-
for run_name in runs:
353-
val_loss = get_final_metric(project, run_name, "val/loss")
354-
accuracy = get_final_metric(project, run_name, "accuracy")
355-
commands_issued += 2
356-
results.append({"run": run_name, "val_loss": val_loss, "accuracy": accuracy})
357-
358-
print("\n[AGENT] Multi-objective results:")
359-
for r in sorted(results, key=lambda x: (x["val_loss"] or float("inf"))):
360-
print(f" {r['run']}: val_loss={r['val_loss']}, accuracy={r['accuracy']}")
361-
362-
pareto_front = []
363-
for r in results:
364-
if r["val_loss"] is None or r["accuracy"] is None:
365-
continue
366-
dominated = False
367-
for other in results:
368-
if other["val_loss"] is None or other["accuracy"] is None:
369-
continue
370-
if (
371-
other["val_loss"] <= r["val_loss"]
372-
and other["accuracy"] >= r["accuracy"]
373-
):
374-
if (
375-
other["val_loss"] < r["val_loss"]
376-
or other["accuracy"] > r["accuracy"]
377-
):
378-
dominated = True
379-
break
380-
if not dominated:
381-
pareto_front.append(r)
382-
383-
print("\n[AGENT] Pareto-optimal runs:")
384-
for r in pareto_front:
385-
print(f" {r['run']}: val_loss={r['val_loss']}, accuracy={r['accuracy']}")
386-
387-
print(f"\n[METRICS] Total CLI commands: {commands_issued}")
388-
print(
389-
"[METRICS] (With trackio compare --metrics val/loss,accuracy, this would be 1)"
390-
)
391-
return {"pareto_front": pareto_front, "commands": commands_issued}
160+
return {"alerts": len(final_alerts)}
392161

393162

394163
EXPERIMENTS = {
395-
"lr_search": experiment_lr_search,
396-
"architecture_search": experiment_architecture_search,
397164
"failure_recovery": experiment_failure_recovery,
398165
"long_monitoring": experiment_long_monitoring,
399-
"multi_objective": experiment_multi_objective,
400166
"all": None,
401167
}
402168

@@ -422,14 +188,12 @@ def main():
422188
experiments = [args.experiment]
423189

424190
results = {}
425-
total_commands = 0
426191

427192
for exp_name in experiments:
428193
project = f"{args.project_prefix}-{exp_name}"
429194
try:
430195
result = EXPERIMENTS[exp_name](project)
431196
results[exp_name] = result
432-
total_commands += result.get("commands", 0)
433197
except Exception as e:
434198
print(f"\n[ERROR] Experiment {exp_name} failed: {e}")
435199
results[exp_name] = {"error": str(e)}
@@ -441,7 +205,6 @@ def main():
441205
print(f"\n{name}:")
442206
for k, v in result.items():
443207
print(f" {k}: {v}")
444-
print(f"\nTotal CLI commands across all experiments: {total_commands}")
445208

446209

447210
if __name__ == "__main__":

0 commit comments

Comments
 (0)