Skip to content

Commit 34791b4

Browse files
committed
format
1 parent 6122070 commit 34791b4

2 files changed

Lines changed: 90 additions & 15 deletions

File tree

autonomous-experiments/test_harness/agent_runner.py

Lines changed: 27 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -87,18 +87,35 @@ def experiment_failure_recovery(project):
8787
error_alerts = [a for a in alerts if a.get("level") == "error"]
8888

8989
if returncode != 0 or error_alerts:
90-
error_msg = error_alerts[0]["title"] if error_alerts else "non-zero exit code"
90+
error_msg = (
91+
error_alerts[0]["title"] if error_alerts else "non-zero exit code"
92+
)
9193
print(f" [AGENT] Attempt {attempt} failed: {error_msg}")
9294
lr *= 0.1
9395
print(f" [AGENT] Reducing LR to {lr}")
9496
attempts.append({"run": run_name, "status": "failed", "lr": lr * 10})
9597
else:
9698
result = run_cli(
97-
["get", "metric", "--project", project, "--run", run_name, "--metric", "val/loss"]
99+
[
100+
"get",
101+
"metric",
102+
"--project",
103+
project,
104+
"--run",
105+
run_name,
106+
"--metric",
107+
"val/loss",
108+
]
109+
)
110+
val_loss = (
111+
result["values"][-1]["value"]
112+
if result and result.get("values")
113+
else None
98114
)
99-
val_loss = result["values"][-1]["value"] if result and result.get("values") else None
100115
print(f" [AGENT] Attempt {attempt} succeeded! val_loss={val_loss}")
101-
attempts.append({"run": run_name, "status": "success", "val_loss": val_loss})
116+
attempts.append(
117+
{"run": run_name, "status": "success", "val_loss": val_loss}
118+
)
102119
break
103120

104121
print("\n[AGENT] Recovery history:")
@@ -136,7 +153,9 @@ def experiment_long_monitoring(project):
136153
]
137154

138155
print(" [AGENT] Starting long training run in background...")
139-
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
156+
proc = subprocess.Popen(
157+
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
158+
)
140159

141160
all_alerts = []
142161

@@ -147,7 +166,9 @@ def experiment_long_monitoring(project):
147166
new_alerts = [a for a in alerts if a not in all_alerts]
148167
if new_alerts:
149168
for alert in new_alerts:
150-
print(f" [AGENT] New alert: [{alert.get('level', '?')}] {alert.get('title', '?')}")
169+
print(
170+
f" [AGENT] New alert: [{alert.get('level', '?')}] {alert.get('title', '?')}"
171+
)
151172
all_alerts.append(alert)
152173
since = datetime.now(timezone.utc).isoformat()
153174

tests/e2e-local/test_cli_agent_commands.py

Lines changed: 63 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,16 @@ def test_best(temp_dir):
7070
assert {"value", "step", "config", "run"} <= entry.keys()
7171

7272
r2 = _cli(
73-
["best", "--project", PROJECT, "--metric", "accuracy", "--direction", "max", "--json"],
73+
[
74+
"best",
75+
"--project",
76+
PROJECT,
77+
"--metric",
78+
"accuracy",
79+
"--direction",
80+
"max",
81+
"--json",
82+
],
7483
temp_dir,
7584
)
7685
assert r2.returncode == 0
@@ -80,15 +89,24 @@ def test_best(temp_dir):
8089
def test_best_finished_filter(temp_dir):
8190
_seed(temp_dir)
8291
r = _cli(
83-
["best", "--project", FILTER_PROJECT, "--metric", "val/loss", "--json"], temp_dir
92+
["best", "--project", FILTER_PROJECT, "--metric", "val/loss", "--json"],
93+
temp_dir,
8494
)
8595
assert r.returncode == 0
8696
run_names = [e["run"] for e in json.loads(r.stdout)["ranking"]]
8797
assert "still-running" not in run_names
8898
assert len(run_names) == 2
8999

90100
r2 = _cli(
91-
["best", "--project", FILTER_PROJECT, "--metric", "val/loss", "--include-all", "--json"],
101+
[
102+
"best",
103+
"--project",
104+
FILTER_PROJECT,
105+
"--metric",
106+
"val/loss",
107+
"--include-all",
108+
"--json",
109+
],
92110
temp_dir,
93111
)
94112
assert r2.returncode == 0
@@ -110,7 +128,16 @@ def test_compare(temp_dir):
110128
assert {"val/loss", "accuracy"} <= run_entry["metrics"].keys()
111129

112130
r2 = _cli(
113-
["compare", "--project", PROJECT, "--runs", "run-lr0.01,run-lr0.1", "--metrics", "val/loss", "--json"],
131+
[
132+
"compare",
133+
"--project",
134+
PROJECT,
135+
"--runs",
136+
"run-lr0.01,run-lr0.1",
137+
"--metrics",
138+
"val/loss",
139+
"--json",
140+
],
114141
temp_dir,
115142
)
116143
assert r2.returncode == 0
@@ -120,15 +147,24 @@ def test_compare(temp_dir):
120147
def test_compare_finished_filter(temp_dir):
121148
_seed(temp_dir)
122149
r = _cli(
123-
["compare", "--project", FILTER_PROJECT, "--metrics", "val/loss", "--json"], temp_dir
150+
["compare", "--project", FILTER_PROJECT, "--metrics", "val/loss", "--json"],
151+
temp_dir,
124152
)
125153
assert r.returncode == 0
126154
run_names = [e["run"] for e in json.loads(r.stdout)["runs"]]
127155
assert "still-running" not in run_names
128156
assert len(run_names) == 2
129157

130158
r2 = _cli(
131-
["compare", "--project", FILTER_PROJECT, "--metrics", "val/loss", "--include-all", "--json"],
159+
[
160+
"compare",
161+
"--project",
162+
FILTER_PROJECT,
163+
"--metrics",
164+
"val/loss",
165+
"--include-all",
166+
"--json",
167+
],
132168
temp_dir,
133169
)
134170
assert r2.returncode == 0
@@ -147,10 +183,28 @@ def test_summary(temp_dir):
147183
assert data["num_runs"] == 3
148184
assert data["total_alerts"] >= 1
149185
for run_entry in data["runs"]:
150-
assert {"run", "status", "last_step", "num_logs", "config", "metric_value"} <= run_entry.keys()
186+
assert {
187+
"run",
188+
"status",
189+
"last_step",
190+
"num_logs",
191+
"config",
192+
"metric_value",
193+
} <= run_entry.keys()
151194

152195

153196
def test_best_error_cases(temp_dir):
154197
_seed(temp_dir)
155-
assert _cli(["best", "--project", "nope", "--metric", "loss", "--json"], temp_dir).returncode != 0
156-
assert _cli(["best", "--project", PROJECT, "--metric", "nonexistent", "--json"], temp_dir).returncode != 0
198+
assert (
199+
_cli(
200+
["best", "--project", "nope", "--metric", "loss", "--json"], temp_dir
201+
).returncode
202+
!= 0
203+
)
204+
assert (
205+
_cli(
206+
["best", "--project", PROJECT, "--metric", "nonexistent", "--json"],
207+
temp_dir,
208+
).returncode
209+
!= 0
210+
)

0 commit comments

Comments
 (0)