33from __future__ import annotations
44
55from dataclasses import dataclass
6+ import os
67from typing import Any
78
89from loguru import logger
@@ -44,6 +45,23 @@ class ChatResponseWithTools(ChatResponse):
4445class LLMClient :
4546 """LLM client using LiteLLM (supports Mistral, OpenAI, Anthropic, etc.)."""
4647
# Maps a lowercase provider name (the part of the model string before the
# first "/") to the environment variables consulted for that provider's API
# key. Names inside each tuple are listed in the order they are tried.
_PROVIDER_ENV_KEYS: dict[str, tuple[str, ...]] = {
    "mistral": ("MISTRAL_API_KEY",),
    "openai": ("OPENAI_API_KEY",),
    "anthropic": ("ANTHROPIC_API_KEY",),
    "huggingface": ("HF_TOKEN", "HUGGINGFACE_API_KEY"),
    "gemini": ("GEMINI_API_KEY",),
}
55+
# Provider-agnostic environment variables scanned in this order by
# _resolve_api_key when no provider-specific variable yields a value;
# the first non-empty one is used.
_FALLBACK_ENV_CHAIN: tuple[str, ...] = (
    "MISTRAL_API_KEY",
    "OPENAI_API_KEY",
    "ANTHROPIC_API_KEY",
    "HF_TOKEN",
    "HUGGINGFACE_API_KEY",
    "GEMINI_API_KEY",
)
64+
4765 def __init__ (
4866 self ,
4967 model : str = "mistral/mistral-large-latest" ,
@@ -62,30 +80,45 @@ def __init__(
6280 self .temperature = temperature
6381 self .max_tokens = max_tokens
6482
83+ def _provider_from_model (self ) -> str :
84+ if "/" not in self .model :
85+ return ""
86+ return self .model .split ("/" , 1 )[0 ].lower ().strip ()
87+
88+ def _resolve_api_key (self ) -> str | None :
89+ """Resolve API key by explicit value, then provider-matched env vars, then generic fallback."""
90+ if self .api_key :
91+ return self .api_key
92+
93+ provider = self ._provider_from_model ()
94+ if self .api_base and "huggingface.co" in self .api_base :
95+ provider = "huggingface"
96+
97+ provider_keys = self ._PROVIDER_ENV_KEYS .get (provider , ())
98+ for env_name in provider_keys :
99+ value = os .environ .get (env_name )
100+ if value :
101+ return value
102+
103+ for env_name in self ._FALLBACK_ENV_CHAIN :
104+ value = os .environ .get (env_name )
105+ if value :
106+ return value
107+
108+ return None
109+
65110 async def chat (
66111 self ,
67112 messages : list [dict [str , str ]],
68113 temperature : float | None = None ,
69114 max_tokens : int | None = None ,
70115 ) -> ChatResponse :
71116 """Send chat messages and return response."""
72- import os
73-
74117 import litellm
75118
76119 temp = temperature if temperature is not None else self .temperature
77120 max_tok = max_tokens if max_tokens is not None else self .max_tokens
78- # Fallback order: explicit then Mistral/OpenAI/Anthropic/HF/Gemini. Consider
79- # making this model-aware (e.g. prefer OPENAI_API_KEY when model starts with openai/).
80- api_key = (
81- self .api_key
82- or os .environ .get ("MISTRAL_API_KEY" )
83- or os .environ .get ("OPENAI_API_KEY" )
84- or os .environ .get ("ANTHROPIC_API_KEY" )
85- or os .environ .get ("HF_TOKEN" )
86- or os .environ .get ("HUGGINGFACE_API_KEY" )
87- or os .environ .get ("GEMINI_API_KEY" )
88- )
121+ api_key = self ._resolve_api_key ()
89122
90123 kwargs : dict [str , Any ] = {
91124 "model" : self .model ,
@@ -125,22 +158,11 @@ async def chat_with_tools(
125158 Send messages with tool definitions; return content and tool_calls.
126159 Does not loop; caller must execute tools, append results, and call again until no tool_calls.
127160 """
128- import os
129-
130161 import litellm
131162
132163 temp = temperature if temperature is not None else self .temperature
133164 max_tok = max_tokens if max_tokens is not None else self .max_tokens
134- # Same fallback order as chat() (see comment there).
135- api_key = (
136- self .api_key
137- or os .environ .get ("MISTRAL_API_KEY" )
138- or os .environ .get ("OPENAI_API_KEY" )
139- or os .environ .get ("ANTHROPIC_API_KEY" )
140- or os .environ .get ("HF_TOKEN" )
141- or os .environ .get ("HUGGINGFACE_API_KEY" )
142- or os .environ .get ("GEMINI_API_KEY" )
143- )
165+ api_key = self ._resolve_api_key ()
144166
145167 kwargs_tools : dict [str , Any ] = {
146168 "model" : self .model ,
0 commit comments