Skip to content

Commit 18e7341

Browse files
committed
feat: update data collection to include invalid responses
1 parent 1c8bc14 commit 18e7341

1 file changed

Lines changed: 8 additions & 14 deletions

File tree

src/balatrollm/data_collection.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ class RunStats:
7272
money_spent: Total money spent during the run.
7373
hands_played: Dictionary mapping hand types to play counts.
7474
successful_calls: Number of successful LLM API calls.
75-
error_calls: List of error messages from failed LLM calls.
76-
failed_calls: List of failure messages from LLM calls.
75+
invalid_calls: Number of invalid tool call responses from LLM.
76+
failed_calls: List of failure tool call from LLM calls.
7777
avg_input_tokens: Average input tokens per LLM call.
7878
avg_output_tokens: Average output tokens per LLM call.
7979
avg_reasoning_tokens: Average reasoning tokens per LLM call.
@@ -102,7 +102,7 @@ class RunStats:
102102

103103
# LLM Performance
104104
successful_calls: int = 0
105-
error_calls: list[str] = field(default_factory=list)
105+
invalid_responses: int = 0
106106
failed_calls: list[str] = field(default_factory=list)
107107
avg_input_tokens: float = 0.0
108108
avg_output_tokens: float = 0.0
@@ -132,6 +132,7 @@ class RunStatsCollector:
132132
run_dir: Path
133133
config: Config
134134
request_count: int = 0
135+
failed_calls: list[str] = field(default_factory=list)
135136

136137
def __post_init__(self):
137138
"""Create directory structure and write config."""
@@ -262,6 +263,7 @@ def calculate_stats(self) -> RunStats:
262263
stats.ante_reached = (
263264
max(1, (stats.final_round // 3) + 1) if stats.final_round > 0 else 1
264265
)
266+
stats.failed_calls = self.failed_calls
265267

266268
# Strategy Metrics
267269
for state in game_states:
@@ -370,6 +372,7 @@ def calculate_stats(self) -> RunStats:
370372
):
371373
stats.successful_calls += 1
372374
body = response.get("response", {}).get("body", {})
375+
message = body.get("choices", [{}])[0].get("message", {})
373376
usage = body.get("usage", {})
374377

375378
if "prompt_tokens" in usage:
@@ -380,17 +383,8 @@ def calculate_stats(self) -> RunStats:
380383
reasoning_tokens.append(usage["reasoning_tokens"])
381384
if "total_tokens" in usage:
382385
total_tokens.append(usage["total_tokens"])
383-
384-
elif response.get("error") is not None:
385-
error = response.get("error", {})
386-
error_msg = f"{error.get('code', 'UnknownError')}: {error.get('message', '')}"
387-
stats.error_calls.append(error_msg)
388-
389-
elif response.get("response", {}).get("status_code", 200) != 200:
390-
status_code = response.get("response", {}).get("status_code")
391-
body = response.get("response", {}).get("body", {})
392-
error_msg = body.get("error", f"HTTP {status_code}")
393-
stats.failed_calls.append(f"Status {status_code}: {error_msg}")
386+
if message.get("tool_calls") is None:
387+
stats.invalid_responses += 1
394388

395389
# Calculate totals
396390
stats.total_input_tokens = sum(input_tokens)

0 commit comments

Comments
 (0)