|
| 1 | +"""Task execution for BalatroLLM runs.""" |
| 2 | + |
| 3 | +import asyncio |
| 4 | +from dataclasses import dataclass, field |
| 5 | +from pathlib import Path |
| 6 | + |
| 7 | +from balatrobot import BalatroInstance |
| 8 | +from balatrobot import Config as BalatrobotConfig |
| 9 | + |
| 10 | +from .bot import Bot |
| 11 | +from .config import Config, Task |
| 12 | + |
| 13 | + |
| 14 | +@dataclass |
| 15 | +class Executor: |
| 16 | + """Executes tasks with parallelism.""" |
| 17 | + |
| 18 | + config: Config |
| 19 | + tasks: list[Task] |
| 20 | + runs_dir: Path = field(default_factory=Path.cwd) |
| 21 | + |
| 22 | + _instances: dict[int, BalatroInstance] = field( |
| 23 | + default_factory=dict, init=False, repr=False |
| 24 | + ) |
| 25 | + _port_pool: asyncio.Queue[int] = field( |
| 26 | + default_factory=asyncio.Queue, init=False, repr=False |
| 27 | + ) |
| 28 | + _shutdown: asyncio.Event = field( |
| 29 | + default_factory=asyncio.Event, init=False, repr=False |
| 30 | + ) |
| 31 | + |
| 32 | + async def run(self) -> None: |
| 33 | + """Execute all tasks.""" |
| 34 | + ports = range(self.config.port, self.config.port + self.config.parallel) |
| 35 | + try: |
| 36 | + await self._start_instances(ports) |
| 37 | + await self._execute_tasks() |
| 38 | + except asyncio.CancelledError: |
| 39 | + print("\nInterrupted! Cleaning up...") |
| 40 | + raise |
| 41 | + finally: |
| 42 | + await self._stop_instances() |
| 43 | + print("Done.") |
| 44 | + |
| 45 | + async def _start_instances(self, ports: range) -> None: |
| 46 | + """Start Balatro instances.""" |
| 47 | + cfg = BalatrobotConfig.from_env() |
| 48 | + for port in ports: |
| 49 | + instance = BalatroInstance(cfg, port=port) |
| 50 | + await instance.start() |
| 51 | + self._instances[port] = instance |
| 52 | + await self._port_pool.put(port) |
| 53 | + |
| 54 | + async def _stop_instances(self) -> None: |
| 55 | + """Stop all instances.""" |
| 56 | + await asyncio.gather( |
| 57 | + *(i.stop() for i in self._instances.values()), |
| 58 | + return_exceptions=True, |
| 59 | + ) |
| 60 | + self._instances.clear() |
| 61 | + |
| 62 | + async def _execute_tasks(self) -> None: |
| 63 | + """Execute tasks with port pool.""" |
| 64 | + total = len(self.tasks) |
| 65 | + count = 0 |
| 66 | + |
| 67 | + async def run_task(task: Task) -> None: |
| 68 | + nonlocal count |
| 69 | + if self._shutdown.is_set(): |
| 70 | + return |
| 71 | + port = await self._port_pool.get() |
| 72 | + try: |
| 73 | + count += 1 |
| 74 | + print(f"Running {task} ({count}/{total})") |
| 75 | + bot = Bot(task=task, config=self.config, port=port) |
| 76 | + async with bot: |
| 77 | + await bot.play(self.runs_dir) |
| 78 | + finally: |
| 79 | + await self._port_pool.put(port) |
| 80 | + |
| 81 | + pending = [asyncio.create_task(run_task(t)) for t in self.tasks] |
| 82 | + try: |
| 83 | + await asyncio.gather(*pending) |
| 84 | + except asyncio.CancelledError: |
| 85 | + self._shutdown.set() |
| 86 | + for t in pending: |
| 87 | + t.cancel() |
| 88 | + await asyncio.gather(*pending, return_exceptions=True) |
| 89 | + raise |
0 commit comments