@@ -41,26 +41,36 @@ class BalatroLLMError(Exception):
4141class TemplateManager :
4242 """Lightweight helper for managing Jinja2 templates."""
4343
44- def __init__ (self , template_dir : Path ):
45- self .jinja_env = Environment (loader = FileSystemLoader (template_dir ))
44+ def __init__ (self , template_dir : Path , strategy : str ):
45+ self .strategy = strategy
46+ self .strategy_dir = template_dir / strategy
47+ self .jinja_env = Environment (loader = FileSystemLoader (self .strategy_dir ))
4648 self .jinja_env .filters ["from_json" ] = json .loads
4749
48- def render_system_prompt (self , template_name : str ) -> str :
49- """Render the system prompt template."""
50- template = self .jinja_env .get_template (f" { template_name } .md.jinja" )
50+ def render_strategy (self ) -> str :
51+ """Render the strategy template."""
52+ template = self .jinja_env .get_template ("STRATEGY .md.jinja" )
5153 return template .render ()
5254
53- def render_game_state (
54- self , state_name : str , game_state : Dict [str , Any ], responses : List [Any ]
55- ) -> str :
55+ def render_gamestate (self , state_name : str , game_state : Dict [str , Any ]) -> str :
5656 """Render the game state template."""
57- template = self .jinja_env .get_template ("game_state .md.jinja" )
57+ template = self .jinja_env .get_template ("GAMESTATE .md.jinja" )
5858 return template .render (
5959 state_name = state_name ,
6060 game_state = game_state ,
61- responses = responses ,
6261 )
6362
63+ def render_memory (self , responses : List [Any ]) -> str :
64+ """Render the memory template."""
65+ template = self .jinja_env .get_template ("MEMORY.md.jinja" )
66+ return template .render (responses = responses )
67+
68+ def load_tools (self ) -> Dict [str , Any ]:
69+ """Load tools from the strategy-specific TOOLS.json file."""
70+ tools_file = self .strategy_dir / "TOOLS.json"
71+ with open (tools_file ) as f :
72+ return json .load (f )
73+
6474
6575@dataclass
6676class Config :
@@ -69,7 +79,7 @@ class Config:
6979 model : str
7080 proxy_url : str = "http://localhost:4000"
7181 api_key : str = "sk-balatrollm-proxy-key"
72- template : str = "system "
82+ template : str = "default "
7383
7484 @classmethod
7585 def from_environment (cls ) -> "Config" :
@@ -78,7 +88,7 @@ def from_environment(cls) -> "Config":
7888 model = os .getenv ("LITELLM_MODEL" , "cerebras-gpt-oss-120b" ),
7989 proxy_url = os .getenv ("LITELLM_PROXY_URL" , "http://localhost:4000" ),
8090 api_key = os .getenv ("LITELLM_API_KEY" , "sk-balatrollm-proxy-key" ),
81- template = os .getenv ("BALATROLLM_TEMPLATE" , "system " ),
91+ template = os .getenv ("BALATROLLM_TEMPLATE" , "default " ),
8292 )
8393
8494
@@ -94,13 +104,11 @@ def __init__(self, config: Config):
94104
95105 # Set up template manager
96106 template_dir = Path (__file__ ).parent / "templates"
97- self .template_manager = TemplateManager (template_dir )
107+ self .template_manager = TemplateManager (template_dir , config . template )
98108 self .responses : list [ChatCompletion ] = []
99109
100- # Load tools from JSON file
101- tools_file = Path (__file__ ).parent / "tools.json"
102- with open (tools_file ) as f :
103- self .tools = json .load (f )
110+ # Load tools from strategy-specific file
111+ self .tools = self .template_manager .load_tools ()
104112
105113 # Get project version from pyproject.toml
106114 self .project_version = self ._get_project_version ()
@@ -252,18 +260,19 @@ async def get_tool_call(self, game_state: dict):
252260
253261 # Generate prompt with error handling
254262 try :
255- system_prompt = self .template_manager .render_system_prompt (
256- self .config .template
257- )
258- user_prompt = self .template_manager .render_game_state (
259- state_name , game_state , self .responses
263+ strategy_content = self .template_manager .render_strategy ()
264+ gamestate_content = self .template_manager .render_gamestate (
265+ state_name , game_state
260266 )
267+ memory_content = self .template_manager .render_memory (self .responses )
261268 except Exception as e :
262269 logger .error (f"Template rendering failed: { e } " )
263270 raise RuntimeError (f"Failed to generate prompts: { e } " ) from e
271+
272+ # Combine all content into user message
273+ user_content = f"{ strategy_content } \n \n { gamestate_content } \n \n { memory_content } "
264274 messages = [
265- {"role" : "system" , "content" : system_prompt },
266- {"role" : "user" , "content" : user_prompt },
275+ {"role" : "user" , "content" : user_content },
267276 ]
268277
269278 # Select tools based on current state
0 commit comments