Skip to content

Commit 5b0d95c

Browse files
committed
feat(bot): add finish reason tracking and separate error/failed counters
Add FinishReason type alias with 8 distinct reasons (won, lost, llm_abort, connection_abort, consecutive_error_calls, consecutive_failed_calls, unexpected_error). Track error calls (invalid LLM responses) and failed calls (valid tool call but execution failed) with separate consecutive counters. Add batch.json tracking for best run and previous.json for last completed run. Include tokens/cost tracking and finish reason in stats output.
1 parent 01f18c6 commit 5b0d95c

3 files changed

Lines changed: 201 additions & 24 deletions

File tree

src/balatrollm/bot.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
ChatCompletionError,
1616
ChatCompletionResponse,
1717
Collector,
18+
FinishReason,
1819
Stats,
1920
)
2021
from .config import Config, Task, get_model_config
@@ -44,12 +45,16 @@ def __init__(self, task: Task, config: Config, port: int | None = None) -> None:
4445
self._llm: LLMClient | None = None
4546
self._collector: Collector | None = None
4647

47-
self._consecutive_failures: int = 0
48-
self._max_consecutive_failures: int = 3
4948
self._last_error_msg: str | None = None
5049
self._last_failed_msg: str | None = None
5150
self._history: list[dict[str, Any]] = []
5251

52+
# Finish reason tracking
53+
self._finish_reason: FinishReason | None = None
54+
# Separate counters for error calls vs failed calls
55+
self._consecutive_errors: int = 0
56+
self._consecutive_faileds: int = 0
57+
5358
async def __aenter__(self) -> "Bot":
5459
"""Initialize all clients."""
5560
self._balatro = BalatroClient(
@@ -111,6 +116,7 @@ async def _wait_for_menu(self, timeout: float = 10.0) -> None:
111116
logger.debug(f"Gamestate check failed: {e}")
112117
await asyncio.sleep(0.5)
113118

119+
self._finish_reason = "connection_abort"
114120
raise BotError(f"Timeout waiting for MENU state after {timeout}s")
115121

116122
async def play(self, runs_dir: Path = Path.cwd()) -> Stats:
@@ -124,11 +130,13 @@ async def play(self, runs_dir: Path = Path.cwd()) -> Stats:
124130
try:
125131
await self._balatro.call("gamestate")
126132
except (httpx.ConnectError, httpx.TimeoutException) as e:
133+
self._finish_reason = "connection_abort"
127134
raise BotError(
128135
f"Cannot connect to Balatro on {self.config.host}:{self.port}. "
129136
"Make sure Balatro instance started correctly."
130137
) from e
131138
except Exception as e:
139+
self._finish_reason = "connection_abort"
132140
raise BotError(f"Failed to connect to Balatro: {e}") from e
133141

134142
self._collector = Collector(self.task, runs_dir)
@@ -153,19 +161,23 @@ async def play(self, runs_dir: Path = Path.cwd()) -> Stats:
153161
logger.error("Game ended due to bot error")
154162
raise
155163
except Exception as e:
164+
self._finish_reason = "unexpected_error"
156165
logger.exception("Unexpected error occurred during gameplay")
157166
raise BotError(f"Unexpected error: {e}") from e
158167
finally:
159168
if self._collector:
160169
try:
161-
self._collector.write_stats()
170+
reason: FinishReason = self._finish_reason or "unexpected_error"
171+
self._collector.write_stats(reason)
162172
logger.info("Stats written")
163173
except Exception as e:
164174
logger.debug(
165175
f"Could not write stats (normal if run failed early): {e}"
166176
)
167177

168-
return self._collector._calculate_stats()
178+
return self._collector._calculate_stats(
179+
self._finish_reason or "unexpected_error"
180+
)
169181

170182
async def _run_game_loop(self, gamestate: dict[str, Any]) -> None:
171183
"""Main game loop."""
@@ -175,6 +187,7 @@ async def _run_game_loop(self, gamestate: dict[str, Any]) -> None:
175187

176188
while True:
177189
if gamestate.get("won", False):
190+
self._finish_reason = "won"
178191
logger.info("Game won! Waiting for GAME_OVER state...")
179192
break
180193

@@ -194,6 +207,7 @@ async def _run_game_loop(self, gamestate: dict[str, Any]) -> None:
194207
# NOTE: This bot always selects and never skips blinds
195208
gamestate = await self._balatro.call("select")
196209
case "GAME_OVER":
210+
self._finish_reason = "lost"
197211
logger.info("Game over!")
198212
break
199213
case _:
@@ -275,6 +289,7 @@ async def _get_llm_response(self, gamestate: dict[str, Any]) -> ChatCompletion:
275289
custom_id=custom_id,
276290
error=ChatCompletionError(code="timeout", message=str(e)),
277291
)
292+
self._finish_reason = "llm_abort"
278293
raise BotError("3 consecutive LLM timeouts") from e
279294

280295
except LLMClientError as e:
@@ -283,6 +298,7 @@ async def _get_llm_response(self, gamestate: dict[str, Any]) -> ChatCompletion:
283298
custom_id=custom_id,
284299
error=ChatCompletionError(code="error", message=str(e)),
285300
)
301+
self._finish_reason = "llm_abort"
286302
raise BotError(f"LLM error: {e}") from e
287303

288304
async def _execute_tool_call(self, response: ChatCompletion) -> dict[str, Any]:
@@ -327,7 +343,10 @@ async def _execute_tool_call(self, response: ChatCompletion) -> dict[str, Any]:
327343
logger.info(f"Executing: {fn_name}({fn_args})")
328344
gamestate = await self._balatro.call(fn_name, fn_args)
329345

330-
self._consecutive_failures = 0
346+
self._collector.reset_failures()
347+
# Reset both consecutive counters on success
348+
self._consecutive_errors = 0
349+
self._consecutive_faileds = 0
331350
self._last_error_msg = None
332351
self._last_failed_msg = None
333352
self._collector.record_call("successful")
@@ -345,6 +364,7 @@ async def _execute_tool_call(self, response: ChatCompletion) -> dict[str, Any]:
345364
try:
346365
return await self._balatro.call("gamestate")
347366
except Exception:
367+
self._finish_reason = "connection_abort"
348368
raise BotError(f"Game unresponsive after transport error: {e}") from e
349369

350370
async def _handle_error_call(self, msg: str) -> dict[str, Any]:
@@ -354,11 +374,16 @@ async def _handle_error_call(self, msg: str) -> dict[str, Any]:
354374

355375
logger.warning(f"Error call: {msg}")
356376
self._last_error_msg = msg
357-
self._consecutive_failures += 1
377+
self._collector.record_failure()
358378
self._collector.record_call("error")
359379

360-
if self._consecutive_failures >= self._max_consecutive_failures:
361-
raise BotError("Too many consecutive error/failed calls")
380+
# Track consecutive error calls separately
381+
self._consecutive_errors += 1
382+
self._consecutive_faileds = 0
383+
384+
if self._consecutive_errors >= Collector.MAX_CONSECUTIVE_FAILURES:
385+
self._finish_reason = "consecutive_error_calls"
386+
raise BotError("Too many consecutive error calls")
362387

363388
return await self._balatro.call("gamestate")
364389

@@ -369,10 +394,15 @@ async def _handle_failed_call(self, msg: str) -> dict[str, Any]:
369394

370395
logger.warning(f"Failed call: {msg}")
371396
self._last_failed_msg = msg
372-
self._consecutive_failures += 1
397+
self._collector.record_failure()
373398
self._collector.record_call("failed")
374399

375-
if self._consecutive_failures >= self._max_consecutive_failures:
376-
raise BotError("Too many consecutive error/failed calls")
400+
# Track consecutive failed calls separately
401+
self._consecutive_faileds += 1
402+
self._consecutive_errors = 0
403+
404+
if self._consecutive_faileds >= Collector.MAX_CONSECUTIVE_FAILURES:
405+
self._finish_reason = "consecutive_failed_calls"
406+
raise BotError("Too many consecutive failed calls")
377407

378408
return await self._balatro.call("gamestate")

0 commit comments

Comments
 (0)