Skip to content

Commit 5b37948

Browse files
authored
Merge pull request #3 from kaust-ark/gemini-support
Add Gemini support as alternative to Claude
2 parents 98f2b48 + d1376bd commit 5b37948

File tree

13 files changed

+509
-31
lines changed

13 files changed

+509
-31
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,9 @@ What you get:
271271

272272
```bash
273273
# Set up the conda base environment
274-
conda env create -f environment.yml # creates "ark-base"
274+
conda env create -f environment.yml # Linux (creates "ark-base")
275+
# OR for macOS:
276+
conda env create -f environment-macos.yml # macOS (creates "ark-base")
275277

276278
pip install -e . # Core
277279
pip install -e ".[research]" # + Gemini Deep Research & Nano Banana

README_ar.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,9 @@ ark setup-bot # مرة واحدة: الصق رمز BotFather، كشف تلق
271271

272272
```bash
273273
# إنشاء بيئة conda الأساسية
274-
conda env create -f environment.yml # ينشئ "ark-base"
274+
conda env create -f environment.yml # Linux (ينشئ "ark-base")
275+
# أو لنظام macOS:
276+
conda env create -f environment-macos.yml # macOS (ينشئ "ark-base")
275277

276278
pip install -e . # الأساسي
277279
pip install -e ".[research]" # + Gemini Deep Research و Nano Banana

README_zh.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,9 @@ ark setup-bot # 一次性配置:粘贴 BotFather token,自动检测 chat
271271

272272
```bash
273273
# 创建 conda 基础环境
274-
conda env create -f environment.yml # 创建 "ark-base"
274+
conda env create -f environment.yml # Linux (创建 "ark-base")
275+
# 或者对于 macOS:
276+
conda env create -f environment-macos.yml # macOS (创建 "ark-base")
275277

276278
pip install -e . # 核心
277279
pip install -e ".[research]" # + Gemini Deep Research 和 Nano Banana

ark/agents.py

Lines changed: 112 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,87 @@ def _extract_usage(parsed: dict) -> dict:
5555
}
5656

5757

58+
def _parse_gemini_json(stdout: str) -> dict | None:
59+
"""Parse output of `gemini -o json`. Returns None on failure."""
60+
text = (stdout or "").strip()
61+
if not text:
62+
return None
63+
try:
64+
# gemini -o json usually outputs the JSON object directly,
65+
# but may have leading "Loaded cached credentials" etc.
66+
if "{" in text:
67+
text = text[text.find("{"):]
68+
return json.loads(text)
69+
except json.JSONDecodeError:
70+
return None
71+
72+
73+
def _calculate_gemini_cost(model_id: str, input_tok: int, output_tok: int) -> float:
74+
"""
75+
Calculate estimated cost for Gemini models (April 2026 pricing).
76+
"""
77+
input_tok = int(input_tok or 0)
78+
output_tok = int(output_tok or 0)
79+
model_lower = (model_id or "").lower()
80+
81+
# Pricing per 1M tokens
82+
if "3.1-pro" in model_lower:
83+
in_rate = 2.00
84+
out_rate = 12.00
85+
elif "3.1-flash" in model_lower:
86+
in_rate = 0.50
87+
out_rate = 3.00
88+
else:
89+
# Default to pro
90+
in_rate = 2.00
91+
out_rate = 12.00
92+
93+
return (input_tok / 1_000_000 * in_rate) + (output_tok / 1_000_000 * out_rate)
94+
95+
96+
def _extract_gemini_usage(parsed: dict) -> dict:
97+
"""Aggregate token usage info from Gemini CLI's nested stats schema."""
98+
parsed = parsed or {}
99+
stats = parsed.get("stats", {})
100+
models = stats.get("models") or stats.get("model") or {}
101+
102+
total_in = 0
103+
total_out = 0
104+
total_cached = 0
105+
total_thoughts = 0
106+
total_latency = 0
107+
main_model = ""
108+
109+
for mid, info in models.items():
110+
t = info.get("tokens", {})
111+
total_in += int(t.get("input") or 0)
112+
total_out += int(t.get("candidates") or 0)
113+
total_cached += int(t.get("cached") or 0)
114+
total_thoughts += int(t.get("thoughts") or 0)
115+
116+
api = info.get("api", {})
117+
total_latency += int(api.get("totalLatencyMs") or 0)
118+
119+
# Heuristic for the "main" model being used for the response
120+
if "roles" in info and "main" in info["roles"]:
121+
main_model = mid
122+
123+
if not main_model and models:
124+
main_model = next(iter(models))
125+
126+
cost_usd = _calculate_gemini_cost(main_model, total_in, total_out)
127+
128+
return {
129+
"model": main_model,
130+
"input_tokens": total_in,
131+
"output_tokens": total_out,
132+
"cache_read_tokens": total_cached,
133+
"cache_creation_tokens": 0, # Gemini schema doesn't distinguish creation
134+
"cost_usd": cost_usd,
135+
"duration_api_ms": total_latency,
136+
}
137+
138+
58139
def _fmt_tok(n: int) -> str:
59140
"""Format a token count as compact human-readable (e.g. 12.3k, 1.2M)."""
60141
n = int(n or 0)
@@ -485,16 +566,26 @@ def run_agent(self, agent_type: str, task: str, timeout: int = 1800,
485566
for attempt in range(1, MAX_RETRIES + 1):
486567
try:
487568
cmd = []
488-
# Strip CLAUDECODE env var to prevent nested-session detection
489-
env = {k: v for k, v in os.environ.items() if k != "CLAUDECODE"}
569+
# Strip CLAUDECODE to prevent nested-session detection.
570+
# Strip GEMINI_API_KEY / GOOGLE_API_KEY so the Gemini CLI uses
571+
# OAuth credentials from ~/.gemini/oauth_creds.json rather than
572+
# the API key (which is only for Deep Research via Python API).
573+
_strip = {"CLAUDECODE", "GEMINI_API_KEY", "GOOGLE_API_KEY"}
574+
env = {k: v for k, v in os.environ.items() if k not in _strip}
490575
if self.model == "gemini":
491576
boundary = self._build_path_boundary()
492577
cmd = [
493578
"gemini",
494-
"-m", "auto",
579+
"-p", f"[SYSTEM RULE] {boundary}\n\n{full_prompt}",
495580
"--approval-mode", "auto_edit",
496-
f"[SYSTEM RULE] {boundary}\n\n{full_prompt}",
581+
"-o", "json",
497582
]
583+
# Respect model_variant if set
584+
ark_model = self._get_ark_model()
585+
if ark_model:
586+
cmd.extend(["-m", ark_model])
587+
else:
588+
cmd.extend(["-m", "auto"])
498589
elif self.model == "claude":
499590
cmd = [
500591
"claude", "-p", full_prompt,
@@ -518,6 +609,9 @@ def run_agent(self, agent_type: str, task: str, timeout: int = 1800,
518609
self.log(f"Unsupported model backend: {self.model}", "ERROR")
519610
return ""
520611

612+
ark_model = self._get_ark_model()
613+
self.log(f"Backend model: {self.model} | Model: {ark_model or 'default'}", "INFO")
614+
521615
process = subprocess.Popen(
522616
cmd,
523617
stdin=subprocess.DEVNULL, # Don't hold terminal pty fd
@@ -552,6 +646,13 @@ def run_agent(self, agent_type: str, task: str, timeout: int = 1800,
552646
usage_record = _extract_usage(parsed)
553647
else:
554648
result = stdout
649+
elif self.model == "gemini":
650+
parsed = _parse_gemini_json(stdout)
651+
if parsed is not None:
652+
result = parsed.get("response", "") or ""
653+
usage_record = _extract_gemini_usage(parsed)
654+
else:
655+
result = stdout
555656
else:
556657
result = stdout
557658

@@ -614,6 +715,13 @@ def run_agent(self, agent_type: str, task: str, timeout: int = 1800,
614715
usage_record = _extract_usage(parsed)
615716
else:
616717
result = stdout
718+
elif self.model == "gemini":
719+
parsed = _parse_gemini_json(stdout)
720+
if parsed is not None:
721+
result = parsed.get("response", "") or ""
722+
usage_record = _extract_gemini_usage(parsed)
723+
else:
724+
result = stdout
617725
else:
618726
result = stdout
619727

ark/orchestrator.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,15 @@ class Orchestrator(AgentMixin, CompilerMixin, ExecutionMixin, PipelineMixin, Dev
4343
"""Main orchestrator class composing all mixins."""
4444

4545
def __init__(self, project: str, max_days: float = 3, max_iterations: int = 100,
46-
mode: str = "research", model: str = "claude", code_dir: str = None,
46+
mode: str = "research", model: str = None, code_dir: str = None,
4747
project_dir: str = None, db_path: str = None, project_id: str = None):
4848
global PROJECT_DIR
4949

5050
self.max_end_time = datetime.now() + timedelta(days=max_days)
5151
self.max_iterations = max_iterations
5252
self.iteration = 0
5353
self.mode = mode
54-
self.model = model
54+
self._model_arg = model # Store the CLI/constructor argument
5555
self.project_name = project
5656

5757
# ── DB awareness ──
@@ -73,10 +73,13 @@ def __init__(self, project: str, max_days: float = 3, max_iterations: int = 100,
7373
config_file = self.project_path / "config.yaml"
7474
if config_file.exists():
7575
with open(config_file) as f:
76-
self.config = yaml.safe_load(f)
76+
self.config = yaml.safe_load(f) or {}
7777
else:
7878
self.config = {}
7979

80+
# Resolve model: Argument > config.yaml > fallback to "claude"
81+
self.model = self._model_arg or self.config.get("model") or "claude"
82+
8083
# Set code_dir and legacy global PROJECT_DIR
8184
if code_dir:
8285
PROJECT_DIR = Path(code_dir).absolute()
@@ -232,6 +235,11 @@ def _sync_db(self, **kwargs):
232235
"""Update project record in the webapp DB. Fail-soft: errors are logged, never raised."""
233236
if not self._db_path or not self._project_id:
234237
return
238+
try:
239+
import sqlalchemy # noqa: F401 — availability check
240+
except ImportError:
241+
self._db_path = None # disable future sync attempts silently
242+
return
235243
try:
236244
from ark.webapp.db import get_session, get_project, update_project
237245
with get_session(self._db_path) as session:
@@ -2088,7 +2096,7 @@ def main():
20882096
parser.add_argument("--mode", type=str, default="research", choices=["research", "paper", "dev"],
20892097
help="Mode: 'research' for experiments, 'paper' for review iterations, 'dev' for development iterations")
20902098
parser.add_argument("--project", type=str, required=True, help="Project name (e.g., prouter)")
2091-
parser.add_argument("--model", type=str, default="claude", choices=["claude", "gemini", "codex"],
2099+
parser.add_argument("--model", type=str, default=None, choices=["claude", "gemini", "codex"],
20922100
help="Model backend: 'claude', 'gemini', or 'codex'")
20932101
parser.add_argument("--max-days", type=float, default=3, help="Maximum runtime in days")
20942102
parser.add_argument("--iterations", type=int, default=100, help="Number of iterations to run")

ark/webapp/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
from .app import create_app
44

5-
__all__ = ["create_app"]
5+
__all__ = ["create_app"]

0 commit comments

Comments
 (0)