|
3 | 3 | import json |
4 | 4 | import re |
5 | 5 | import statistics |
| 6 | +from collections import Counter |
6 | 7 | from dataclasses import asdict, dataclass, field |
7 | 8 | from datetime import datetime |
8 | 9 | from pathlib import Path |
@@ -39,54 +40,41 @@ def _generate_run_dir(task: Task, base_dir: Path) -> Path: |
39 | 40 |
|
40 | 41 |
|
@dataclass
class Stats:
    """Complete statistics for a game run (flat structure).

    All values live directly on this object rather than in nested
    sub-records, so ``dataclasses.asdict`` produces a single-level mapping
    (``providers`` being the only dict-valued field).

    The per-field notes describe how ``_calculate_stats`` fills each value.
    """

    # Outcome (taken from the last recorded gamestate)
    run_won: bool  # the gamestate's "won" flag
    run_completed: bool  # state is "GAME_OVER", or the run was won
    final_ante: int  # the gamestate's "ante_num"
    final_round: int  # the gamestate's "round_num"

    # Provider distribution
    # Maps provider name -> number of HTTP-200 responses that reported it.
    providers: dict[str, int]

    # Call statistics (counters maintained through record_call)
    calls_total: int
    calls_success: int
    calls_error: int
    calls_failed: int

    # Token statistics over HTTP-200 responses
    # "in" is usage["prompt_tokens"], "out" is usage["completion_tokens"].
    tokens_in_total: int
    tokens_out_total: int
    tokens_in_avg: float
    tokens_out_avg: float
    tokens_in_std: float  # statistics.stdev, i.e. sample standard deviation
    tokens_out_std: float

    # Timing statistics
    # Per-response value is int(res.id) - int(res.response.request_id).
    time_total_ms: int
    time_avg_ms: float
    time_std_ms: float

    # Cost statistics (usage["cost"] per response)
    cost_total: float
    cost_avg: float
    cost_std: float
90 | 78 |
|
91 | 79 |
|
92 | 80 | @dataclass |
@@ -155,27 +143,40 @@ def __init__(self, task: Task, base_dir: Path) -> None: |
155 | 143 |
|
156 | 144 | self.task = task |
157 | 145 | self._request_count = 0 |
158 | | - self.call_stats = CallStats() |
159 | 146 |
|
160 | | - # Write task and strategy for benchmark analysis |
| 147 | + # Call tracking |
| 148 | + self._calls_success = 0 |
| 149 | + self._calls_error = 0 |
| 150 | + self._calls_failed = 0 |
| 151 | + self._calls_total = 0 |
| 152 | + |
| 153 | + # Write task with structured model for benchmark analysis |
| 154 | + vendor, model_name = task.model.split("/", 1) |
| 155 | + task_data = { |
| 156 | + "model": {"vendor": vendor, "name": model_name}, |
| 157 | + "seed": task.seed, |
| 158 | + "deck": task.deck, |
| 159 | + "stake": task.stake, |
| 160 | + "strategy": task.strategy, |
| 161 | + } |
161 | 162 | manifest = StrategyManifest.from_file(task.strategy) |
162 | 163 | with (self.run_dir / "task.json").open("w") as f: |
163 | | - json.dump(asdict(self.task), f, indent=2) |
| 164 | + json.dump(task_data, f, indent=2) |
164 | 165 | with (self.run_dir / "strategy.json").open("w") as f: |
165 | 166 | json.dump(asdict(manifest), f, indent=2) |
166 | 167 |
|
167 | 168 | def record_call(self, outcome: Literal["successful", "error", "failed"]) -> None: |
168 | 169 | """Record a call outcome.""" |
169 | 170 | match outcome: |
170 | 171 | case "successful": |
171 | | - self.call_stats.successful += 1 |
| 172 | + self._calls_success += 1 |
172 | 173 | case "error": |
173 | | - self.call_stats.error += 1 |
| 174 | + self._calls_error += 1 |
174 | 175 | case "failed": |
175 | | - self.call_stats.failed += 1 |
| 176 | + self._calls_failed += 1 |
176 | 177 | case _: |
177 | 178 | raise ValueError(f"Invalid call outcome: {outcome}") |
178 | | - self.call_stats.total += 1 |
| 179 | + self._calls_total += 1 |
179 | 180 |
|
180 | 181 | def write_request(self, body: dict[str, Any]) -> str: |
181 | 182 | """Write request to requests.jsonl. Returns custom_id.""" |
@@ -233,85 +234,60 @@ def _calculate_stats(self) -> Stats: |
233 | 234 | assert len(responses) >= 2, "Expected at least two responses" |
234 | 235 |
|
235 | 236 | ################################################################################ |
236 | | - # Populate list for each stat type |
| 237 | + # Populate lists for each stat type and count providers |
237 | 238 | ################################################################################ |
238 | 239 |
|
239 | | - stats: dict[str, list] = { |
240 | | - "providers": [], |
241 | | - "input_tokens": [], |
242 | | - "output_tokens": [], |
243 | | - "input_cost": [], |
244 | | - "output_cost": [], |
245 | | - "total_cost": [], |
246 | | - "time_ms": [], |
247 | | - } |
| 240 | + provider_counts: Counter[str] = Counter() |
| 241 | + input_tokens: list[int] = [] |
| 242 | + output_tokens: list[int] = [] |
| 243 | + total_costs: list[float] = [] |
| 244 | + time_ms_list: list[int] = [] |
| 245 | + |
248 | 246 | for res in responses: |
249 | 247 | if res.response is not None and res.response.status_code == 200: |
250 | 248 | body = res.response.body |
251 | 249 | if "provider" in body: |
252 | | - stats["providers"].append(body["provider"]) |
| 250 | + provider_counts[body["provider"]] += 1 |
253 | 251 |
|
254 | 252 | usage = body.get("usage", {}) |
255 | | - stats["input_tokens"].append(usage.get("prompt_tokens", 0)) |
256 | | - stats["output_tokens"].append(usage.get("completion_tokens", 0)) |
257 | | - |
258 | | - cost_details = usage.get("cost_details", {}) |
259 | | - stats["input_cost"].append( |
260 | | - cost_details.get("upstream_inference_prompt_cost", 0) |
261 | | - ) |
262 | | - stats["output_cost"].append( |
263 | | - cost_details.get("upstream_inference_completions_cost", 0) |
264 | | - ) |
265 | | - stats["total_cost"].append(usage.get("cost", 0)) |
266 | | - stats["time_ms"].append(int(res.id) - int(res.response.request_id)) |
| 253 | + input_tokens.append(usage.get("prompt_tokens", 0)) |
| 254 | + output_tokens.append(usage.get("completion_tokens", 0)) |
| 255 | + total_costs.append(usage.get("cost", 0)) |
| 256 | + time_ms_list.append(int(res.id) - int(res.response.request_id)) |
267 | 257 |
|
268 | 258 | ################################################################################ |
269 | 259 | # Compute aggregated stats |
270 | 260 | ################################################################################ |
271 | 261 |
|
272 | | - n = len(stats["input_tokens"]) |
273 | | - |
274 | | - total = AggregatedStats( |
275 | | - input_tokens=sum(stats["input_tokens"]), |
276 | | - output_tokens=sum(stats["output_tokens"]), |
277 | | - input_cost=sum(stats["input_cost"]), |
278 | | - output_cost=sum(stats["output_cost"]), |
279 | | - total_cost=sum(stats["total_cost"]), |
280 | | - time_ms=sum(stats["time_ms"]), |
281 | | - ) |
282 | | - |
283 | | - average = AggregatedStats( |
284 | | - input_tokens=total.input_tokens / n, |
285 | | - output_tokens=total.output_tokens / n, |
286 | | - input_cost=total.input_cost / n, |
287 | | - output_cost=total.output_cost / n, |
288 | | - total_cost=total.total_cost / n, |
289 | | - time_ms=total.time_ms / n, |
290 | | - ) |
291 | | - |
292 | | - std_dev = AggregatedStats( |
293 | | - input_tokens=statistics.stdev(stats["input_tokens"]), |
294 | | - output_tokens=statistics.stdev(stats["output_tokens"]), |
295 | | - input_cost=statistics.stdev(stats["input_cost"]), |
296 | | - output_cost=statistics.stdev(stats["output_cost"]), |
297 | | - total_cost=statistics.stdev(stats["total_cost"]), |
298 | | - time_ms=statistics.stdev(stats["time_ms"]), |
299 | | - ) |
300 | | - |
301 | | - ################################################################################ |
302 | | - # Compute Stats from the final gamestate |
303 | | - ################################################################################ |
304 | | - |
| 262 | + n = len(input_tokens) |
305 | 263 | gamestate = gamestates[-1] |
306 | 264 |
|
307 | 265 | return Stats( |
308 | | - won=gamestate["won"], |
309 | | - completed=gamestate["state"] == "GAME_OVER" or gamestate["won"], |
310 | | - ante_reached=gamestate["ante_num"], |
| 266 | + # Outcome |
| 267 | + run_won=gamestate["won"], |
| 268 | + run_completed=gamestate["state"] == "GAME_OVER" or gamestate["won"], |
| 269 | + final_ante=gamestate["ante_num"], |
311 | 270 | final_round=gamestate["round_num"], |
312 | | - providers=stats["providers"], |
313 | | - calls=self.call_stats, |
314 | | - total=total, |
315 | | - average=average, |
316 | | - std_dev=std_dev, |
| 271 | + # Provider distribution |
| 272 | + providers=dict(provider_counts), |
| 273 | + # Call statistics |
| 274 | + calls_total=self._calls_total, |
| 275 | + calls_success=self._calls_success, |
| 276 | + calls_error=self._calls_error, |
| 277 | + calls_failed=self._calls_failed, |
| 278 | + # Token statistics |
| 279 | + tokens_in_total=sum(input_tokens), |
| 280 | + tokens_out_total=sum(output_tokens), |
| 281 | + tokens_in_avg=sum(input_tokens) / n, |
| 282 | + tokens_out_avg=sum(output_tokens) / n, |
| 283 | + tokens_in_std=statistics.stdev(input_tokens), |
| 284 | + tokens_out_std=statistics.stdev(output_tokens), |
| 285 | + # Timing statistics |
| 286 | + time_total_ms=sum(time_ms_list), |
| 287 | + time_avg_ms=sum(time_ms_list) / n, |
| 288 | + time_std_ms=statistics.stdev(time_ms_list), |
| 289 | + # Cost statistics |
| 290 | + cost_total=sum(total_costs), |
| 291 | + cost_avg=sum(total_costs) / n, |
| 292 | + cost_std=statistics.stdev(total_costs), |
317 | 293 | ) |
0 commit comments