Skip to content

Commit 3f80e3f

Browse files
committed
feat: add executor to start/stop balatro instances and tasks
1 parent bc05dd8 commit 3f80e3f

1 file changed

Lines changed: 89 additions & 0 deletions

File tree

src/balatrollm/executor.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""Task execution for BalatroLLM runs."""
2+
3+
import asyncio
4+
from dataclasses import dataclass, field
5+
from pathlib import Path
6+
7+
from balatrobot import BalatroInstance
8+
from balatrobot import Config as BalatrobotConfig
9+
10+
from .bot import Bot
11+
from .config import Config, Task
12+
13+
14+
@dataclass
15+
class Executor:
16+
"""Executes tasks with parallelism."""
17+
18+
config: Config
19+
tasks: list[Task]
20+
runs_dir: Path = field(default_factory=Path.cwd)
21+
22+
_instances: dict[int, BalatroInstance] = field(
23+
default_factory=dict, init=False, repr=False
24+
)
25+
_port_pool: asyncio.Queue[int] = field(
26+
default_factory=asyncio.Queue, init=False, repr=False
27+
)
28+
_shutdown: asyncio.Event = field(
29+
default_factory=asyncio.Event, init=False, repr=False
30+
)
31+
32+
async def run(self) -> None:
33+
"""Execute all tasks."""
34+
ports = range(self.config.port, self.config.port + self.config.parallel)
35+
try:
36+
await self._start_instances(ports)
37+
await self._execute_tasks()
38+
except asyncio.CancelledError:
39+
print("\nInterrupted! Cleaning up...")
40+
raise
41+
finally:
42+
await self._stop_instances()
43+
print("Done.")
44+
45+
async def _start_instances(self, ports: range) -> None:
46+
"""Start Balatro instances."""
47+
cfg = BalatrobotConfig.from_env()
48+
for port in ports:
49+
instance = BalatroInstance(cfg, port=port)
50+
await instance.start()
51+
self._instances[port] = instance
52+
await self._port_pool.put(port)
53+
54+
async def _stop_instances(self) -> None:
55+
"""Stop all instances."""
56+
await asyncio.gather(
57+
*(i.stop() for i in self._instances.values()),
58+
return_exceptions=True,
59+
)
60+
self._instances.clear()
61+
62+
async def _execute_tasks(self) -> None:
63+
"""Execute tasks with port pool."""
64+
total = len(self.tasks)
65+
count = 0
66+
67+
async def run_task(task: Task) -> None:
68+
nonlocal count
69+
if self._shutdown.is_set():
70+
return
71+
port = await self._port_pool.get()
72+
try:
73+
count += 1
74+
print(f"Running {task} ({count}/{total})")
75+
bot = Bot(task=task, config=self.config, port=port)
76+
async with bot:
77+
await bot.play(self.runs_dir)
78+
finally:
79+
await self._port_pool.put(port)
80+
81+
pending = [asyncio.create_task(run_task(t)) for t in self.tasks]
82+
try:
83+
await asyncio.gather(*pending)
84+
except asyncio.CancelledError:
85+
self._shutdown.set()
86+
for t in pending:
87+
t.cancel()
88+
await asyncio.gather(*pending, return_exceptions=True)
89+
raise

0 commit comments

Comments
 (0)