Skip to content

Commit 2540725

Browse files
committed
feat(collector): flatten the structure of the stats
1 parent 075f1d5 commit 2540725

1 file changed

Lines changed: 89 additions & 113 deletions

File tree

src/balatrollm/collector.py

Lines changed: 89 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import re
55
import statistics
6+
from collections import Counter
67
from dataclasses import asdict, dataclass, field
78
from datetime import datetime
89
from pathlib import Path
@@ -39,54 +40,41 @@ def _generate_run_dir(task: Task, base_dir: Path) -> Path:
3940

4041

4142
@dataclass
42-
class AggregatedStats:
43-
"""Aggregated statistics for token usage, costs, and timing."""
44-
45-
input_tokens: int | float
46-
output_tokens: int | float
47-
input_cost: float
48-
output_cost: float
49-
total_cost: float
50-
time_ms: float
51-
43+
class Stats:
44+
"""Complete statistics for a game run (flat structure)."""
5245

53-
@dataclass
54-
class CallStats:
55-
"""Statistics for tool call outcomes."""
46+
# Outcome
47+
run_won: bool
48+
run_completed: bool
49+
final_ante: int
50+
final_round: int
5651

57-
successful: int = 0
58-
error: int = 0
59-
failed: int = 0
60-
total: int = 0
52+
# Provider distribution
53+
providers: dict[str, int]
6154

55+
# Call statistics
56+
calls_total: int
57+
calls_success: int
58+
calls_error: int
59+
calls_failed: int
6260

63-
@dataclass
64-
class Stats:
65-
"""Complete statistics for a game run."""
61+
# Token statistics
62+
tokens_in_total: int
63+
tokens_out_total: int
64+
tokens_in_avg: float
65+
tokens_out_avg: float
66+
tokens_in_std: float
67+
tokens_out_std: float
6668

67-
won: bool
68-
completed: bool
69-
ante_reached: int
70-
final_round: int
71-
providers: list[str]
72-
calls: CallStats
73-
total: AggregatedStats
74-
average: AggregatedStats
75-
std_dev: AggregatedStats
69+
# Timing statistics
70+
time_total_ms: int
71+
time_avg_ms: float
72+
time_std_ms: float
7673

77-
@classmethod
78-
def from_dict(cls, data: dict[str, Any]) -> "Stats":
79-
return cls(
80-
won=data["won"],
81-
completed=data["completed"],
82-
ante_reached=data["ante_reached"],
83-
final_round=data["final_round"],
84-
providers=data["providers"],
85-
calls=CallStats(**data["calls"]),
86-
total=AggregatedStats(**data["total"]),
87-
average=AggregatedStats(**data["average"]),
88-
std_dev=AggregatedStats(**data["std_dev"]),
89-
)
74+
# Cost statistics
75+
cost_total: float
76+
cost_avg: float
77+
cost_std: float
9078

9179

9280
@dataclass
@@ -155,27 +143,40 @@ def __init__(self, task: Task, base_dir: Path) -> None:
155143

156144
self.task = task
157145
self._request_count = 0
158-
self.call_stats = CallStats()
159146

160-
# Write task and strategy for benchmark analysis
147+
# Call tracking
148+
self._calls_success = 0
149+
self._calls_error = 0
150+
self._calls_failed = 0
151+
self._calls_total = 0
152+
153+
# Write task with structured model for benchmark analysis
154+
vendor, model_name = task.model.split("/", 1)
155+
task_data = {
156+
"model": {"vendor": vendor, "name": model_name},
157+
"seed": task.seed,
158+
"deck": task.deck,
159+
"stake": task.stake,
160+
"strategy": task.strategy,
161+
}
161162
manifest = StrategyManifest.from_file(task.strategy)
162163
with (self.run_dir / "task.json").open("w") as f:
163-
json.dump(asdict(self.task), f, indent=2)
164+
json.dump(task_data, f, indent=2)
164165
with (self.run_dir / "strategy.json").open("w") as f:
165166
json.dump(asdict(manifest), f, indent=2)
166167

167168
def record_call(self, outcome: Literal["successful", "error", "failed"]) -> None:
168169
"""Record a call outcome."""
169170
match outcome:
170171
case "successful":
171-
self.call_stats.successful += 1
172+
self._calls_success += 1
172173
case "error":
173-
self.call_stats.error += 1
174+
self._calls_error += 1
174175
case "failed":
175-
self.call_stats.failed += 1
176+
self._calls_failed += 1
176177
case _:
177178
raise ValueError(f"Invalid call outcome: {outcome}")
178-
self.call_stats.total += 1
179+
self._calls_total += 1
179180

180181
def write_request(self, body: dict[str, Any]) -> str:
181182
"""Write request to requests.jsonl. Returns custom_id."""
@@ -233,85 +234,60 @@ def _calculate_stats(self) -> Stats:
233234
assert len(responses) >= 2, "Expected at least two responses"
234235

235236
################################################################################
236-
# Populate list for each stat type
237+
# Populate lists for each stat type and count providers
237238
################################################################################
238239

239-
stats: dict[str, list] = {
240-
"providers": [],
241-
"input_tokens": [],
242-
"output_tokens": [],
243-
"input_cost": [],
244-
"output_cost": [],
245-
"total_cost": [],
246-
"time_ms": [],
247-
}
240+
provider_counts: Counter[str] = Counter()
241+
input_tokens: list[int] = []
242+
output_tokens: list[int] = []
243+
total_costs: list[float] = []
244+
time_ms_list: list[int] = []
245+
248246
for res in responses:
249247
if res.response is not None and res.response.status_code == 200:
250248
body = res.response.body
251249
if "provider" in body:
252-
stats["providers"].append(body["provider"])
250+
provider_counts[body["provider"]] += 1
253251

254252
usage = body.get("usage", {})
255-
stats["input_tokens"].append(usage.get("prompt_tokens", 0))
256-
stats["output_tokens"].append(usage.get("completion_tokens", 0))
257-
258-
cost_details = usage.get("cost_details", {})
259-
stats["input_cost"].append(
260-
cost_details.get("upstream_inference_prompt_cost", 0)
261-
)
262-
stats["output_cost"].append(
263-
cost_details.get("upstream_inference_completions_cost", 0)
264-
)
265-
stats["total_cost"].append(usage.get("cost", 0))
266-
stats["time_ms"].append(int(res.id) - int(res.response.request_id))
253+
input_tokens.append(usage.get("prompt_tokens", 0))
254+
output_tokens.append(usage.get("completion_tokens", 0))
255+
total_costs.append(usage.get("cost", 0))
256+
time_ms_list.append(int(res.id) - int(res.response.request_id))
267257

268258
################################################################################
269259
# Compute aggregated stats
270260
################################################################################
271261

272-
n = len(stats["input_tokens"])
273-
274-
total = AggregatedStats(
275-
input_tokens=sum(stats["input_tokens"]),
276-
output_tokens=sum(stats["output_tokens"]),
277-
input_cost=sum(stats["input_cost"]),
278-
output_cost=sum(stats["output_cost"]),
279-
total_cost=sum(stats["total_cost"]),
280-
time_ms=sum(stats["time_ms"]),
281-
)
282-
283-
average = AggregatedStats(
284-
input_tokens=total.input_tokens / n,
285-
output_tokens=total.output_tokens / n,
286-
input_cost=total.input_cost / n,
287-
output_cost=total.output_cost / n,
288-
total_cost=total.total_cost / n,
289-
time_ms=total.time_ms / n,
290-
)
291-
292-
std_dev = AggregatedStats(
293-
input_tokens=statistics.stdev(stats["input_tokens"]),
294-
output_tokens=statistics.stdev(stats["output_tokens"]),
295-
input_cost=statistics.stdev(stats["input_cost"]),
296-
output_cost=statistics.stdev(stats["output_cost"]),
297-
total_cost=statistics.stdev(stats["total_cost"]),
298-
time_ms=statistics.stdev(stats["time_ms"]),
299-
)
300-
301-
################################################################################
302-
# Compute Stats from the final gamestate
303-
################################################################################
304-
262+
n = len(input_tokens)
305263
gamestate = gamestates[-1]
306264

307265
return Stats(
308-
won=gamestate["won"],
309-
completed=gamestate["state"] == "GAME_OVER" or gamestate["won"],
310-
ante_reached=gamestate["ante_num"],
266+
# Outcome
267+
run_won=gamestate["won"],
268+
run_completed=gamestate["state"] == "GAME_OVER" or gamestate["won"],
269+
final_ante=gamestate["ante_num"],
311270
final_round=gamestate["round_num"],
312-
providers=stats["providers"],
313-
calls=self.call_stats,
314-
total=total,
315-
average=average,
316-
std_dev=std_dev,
271+
# Provider distribution
272+
providers=dict(provider_counts),
273+
# Call statistics
274+
calls_total=self._calls_total,
275+
calls_success=self._calls_success,
276+
calls_error=self._calls_error,
277+
calls_failed=self._calls_failed,
278+
# Token statistics
279+
tokens_in_total=sum(input_tokens),
280+
tokens_out_total=sum(output_tokens),
281+
tokens_in_avg=sum(input_tokens) / n,
282+
tokens_out_avg=sum(output_tokens) / n,
283+
tokens_in_std=statistics.stdev(input_tokens),
284+
tokens_out_std=statistics.stdev(output_tokens),
285+
# Timing statistics
286+
time_total_ms=sum(time_ms_list),
287+
time_avg_ms=sum(time_ms_list) / n,
288+
time_std_ms=statistics.stdev(time_ms_list),
289+
# Cost statistics
290+
cost_total=sum(total_costs),
291+
cost_avg=sum(total_costs) / n,
292+
cost_std=statistics.stdev(total_costs),
317293
)

0 commit comments

Comments
 (0)