Skip to content

Commit f1937a4

Browse files
committed
feat: update Bot to use new template system
1 parent 4a84cc8 commit f1937a4

1 file changed

Lines changed: 33 additions & 24 deletions

File tree

src/balatrollm/llm.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,26 +41,36 @@ class BalatroLLMError(Exception):
4141
class TemplateManager:
4242
"""Lightweight helper for managing Jinja2 templates."""
4343

44-
def __init__(self, template_dir: Path):
45-
self.jinja_env = Environment(loader=FileSystemLoader(template_dir))
44+
def __init__(self, template_dir: Path, strategy: str):
45+
self.strategy = strategy
46+
self.strategy_dir = template_dir / strategy
47+
self.jinja_env = Environment(loader=FileSystemLoader(self.strategy_dir))
4648
self.jinja_env.filters["from_json"] = json.loads
4749

48-
def render_system_prompt(self, template_name: str) -> str:
49-
"""Render the system prompt template."""
50-
template = self.jinja_env.get_template(f"{template_name}.md.jinja")
50+
def render_strategy(self) -> str:
51+
"""Render the strategy template."""
52+
template = self.jinja_env.get_template("STRATEGY.md.jinja")
5153
return template.render()
5254

53-
def render_game_state(
54-
self, state_name: str, game_state: Dict[str, Any], responses: List[Any]
55-
) -> str:
55+
def render_gamestate(self, state_name: str, game_state: Dict[str, Any]) -> str:
5656
"""Render the game state template."""
57-
template = self.jinja_env.get_template("game_state.md.jinja")
57+
template = self.jinja_env.get_template("GAMESTATE.md.jinja")
5858
return template.render(
5959
state_name=state_name,
6060
game_state=game_state,
61-
responses=responses,
6261
)
6362

63+
def render_memory(self, responses: List[Any]) -> str:
64+
"""Render the memory template."""
65+
template = self.jinja_env.get_template("MEMORY.md.jinja")
66+
return template.render(responses=responses)
67+
68+
def load_tools(self) -> Dict[str, Any]:
69+
"""Load tools from the strategy-specific TOOLS.json file."""
70+
tools_file = self.strategy_dir / "TOOLS.json"
71+
with open(tools_file) as f:
72+
return json.load(f)
73+
6474

6575
@dataclass
6676
class Config:
@@ -69,7 +79,7 @@ class Config:
6979
model: str
7080
proxy_url: str = "http://localhost:4000"
7181
api_key: str = "sk-balatrollm-proxy-key"
72-
template: str = "system"
82+
template: str = "default"
7383

7484
@classmethod
7585
def from_environment(cls) -> "Config":
@@ -78,7 +88,7 @@ def from_environment(cls) -> "Config":
7888
model=os.getenv("LITELLM_MODEL", "cerebras-gpt-oss-120b"),
7989
proxy_url=os.getenv("LITELLM_PROXY_URL", "http://localhost:4000"),
8090
api_key=os.getenv("LITELLM_API_KEY", "sk-balatrollm-proxy-key"),
81-
template=os.getenv("BALATROLLM_TEMPLATE", "system"),
91+
template=os.getenv("BALATROLLM_TEMPLATE", "default"),
8292
)
8393

8494

@@ -94,13 +104,11 @@ def __init__(self, config: Config):
94104

95105
# Set up template manager
96106
template_dir = Path(__file__).parent / "templates"
97-
self.template_manager = TemplateManager(template_dir)
107+
self.template_manager = TemplateManager(template_dir, config.template)
98108
self.responses: list[ChatCompletion] = []
99109

100-
# Load tools from JSON file
101-
tools_file = Path(__file__).parent / "tools.json"
102-
with open(tools_file) as f:
103-
self.tools = json.load(f)
110+
# Load tools from strategy-specific file
111+
self.tools = self.template_manager.load_tools()
104112

105113
# Get project version from pyproject.toml
106114
self.project_version = self._get_project_version()
@@ -252,18 +260,19 @@ async def get_tool_call(self, game_state: dict):
252260

253261
# Generate prompt with error handling
254262
try:
255-
system_prompt = self.template_manager.render_system_prompt(
256-
self.config.template
257-
)
258-
user_prompt = self.template_manager.render_game_state(
259-
state_name, game_state, self.responses
263+
strategy_content = self.template_manager.render_strategy()
264+
gamestate_content = self.template_manager.render_gamestate(
265+
state_name, game_state
260266
)
267+
memory_content = self.template_manager.render_memory(self.responses)
261268
except Exception as e:
262269
logger.error(f"Template rendering failed: {e}")
263270
raise RuntimeError(f"Failed to generate prompts: {e}") from e
271+
272+
# Combine all content into user message
273+
user_content = f"{strategy_content}\n\n{gamestate_content}\n\n{memory_content}"
264274
messages = [
265-
{"role": "system", "content": system_prompt},
266-
{"role": "user", "content": user_prompt},
275+
{"role": "user", "content": user_content},
267276
]
268277

269278
# Select tools based on current state

0 commit comments

Comments
 (0)