Skip to content

Commit a486e9c

Browse files
JihaoXin and claude committed
Add real-time token & cost tracking with live dashboard chart
Parses claude --output-format json on every agent call to capture the 4 token types and total_cost_usd, aggregates them live into state/cost_report.yaml after each agent (atomic write), and surfaces them through the existing SSE status stream so the project detail page shows a USD/Tokens chart that updates within ~2s. Closes the latent telegram_daemon read of total_cost_usd that nothing wrote. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 310caab commit a486e9c

File tree

5 files changed

+504
-20
lines changed

5 files changed

+504
-20
lines changed

ark/agents.py

Lines changed: 131 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""AgentMixin: agent execution, output parsing, rate limit handling."""
22
from __future__ import annotations
33

4+
import json
45
import os
56
import re
67
import signal
@@ -11,6 +12,58 @@
1112
from datetime import datetime, timedelta
1213
from pathlib import Path
1314

15+
16+
def _parse_claude_json(stdout: str) -> dict | None:
17+
"""Parse output of `claude --output-format json`. Returns None on any failure.
18+
19+
Tolerates trailing whitespace and the rare case where stdout has leading
20+
non-JSON debug output by scanning for the final result-shaped object.
21+
Never raises — callers fall back to treating stdout as plain text.
22+
"""
23+
text = (stdout or "").strip()
24+
if not text:
25+
return None
26+
try:
27+
return json.loads(text)
28+
except json.JSONDecodeError:
29+
# Last-resort: locate the final result envelope
30+
marker = '{"type":"result"'
31+
start = text.rfind(marker)
32+
if start == -1:
33+
return None
34+
try:
35+
return json.loads(text[start:])
36+
except json.JSONDecodeError:
37+
return None
38+
39+
40+
def _extract_usage(parsed: dict) -> dict:
41+
"""Pull token/cost fields out of parsed claude JSON. Zero-default so callers
42+
don't need null checks. Always returns a complete dict shape."""
43+
parsed = parsed or {}
44+
u = parsed.get("usage") or {}
45+
model_usage = parsed.get("modelUsage") or {}
46+
model = next(iter(model_usage), "")
47+
return {
48+
"model": model,
49+
"input_tokens": int(u.get("input_tokens") or 0),
50+
"output_tokens": int(u.get("output_tokens") or 0),
51+
"cache_read_tokens": int(u.get("cache_read_input_tokens") or 0),
52+
"cache_creation_tokens": int(u.get("cache_creation_input_tokens") or 0),
53+
"cost_usd": float(parsed.get("total_cost_usd") or 0.0),
54+
"duration_api_ms": int(parsed.get("duration_api_ms") or 0),
55+
}
56+
57+
58+
def _fmt_tok(n: int) -> str:
59+
"""Format a token count as compact human-readable (e.g. 12.3k, 1.2M)."""
60+
n = int(n or 0)
61+
if n >= 1_000_000:
62+
return f"{n / 1_000_000:.1f}M"
63+
if n >= 1_000:
64+
return f"{n / 1_000:.1f}k"
65+
return str(n)
66+
1467
from ark.paths import get_config_dir
1568
from ark.ui import (
1669
ElapsedTimer, RateLimitCountdown, agent_styled, styled, Style, Icons,
@@ -434,7 +487,7 @@ def run_agent(self, agent_type: str, task: str, timeout: int = 1800,
434487
"claude", "-p", full_prompt,
435488
"--permission-mode", "bypassPermissions",
436489
"--no-session-persistence",
437-
"--output-format", "text",
490+
"--output-format", "json",
438491
"--append-system-prompt", self._build_path_boundary(),
439492
]
440493
ark_model = self._get_ark_model()
@@ -469,10 +522,23 @@ def run_agent(self, agent_type: str, task: str, timeout: int = 1800,
469522

470523
timer.start()
471524
result = ""
525+
usage_record = None # populated when claude returns parseable JSON
472526

473527
try:
474528
stdout, stderr = process.communicate(timeout=timeout)
475-
result = stdout
529+
# claude --output-format json: parse the envelope, extract `result`
530+
# field for downstream and `usage` for cost tracking. Fall back to
531+
# raw stdout on parse failure so the existing empty-run / failure
532+
# paths still trigger normally.
533+
if self.model == "claude":
534+
parsed = _parse_claude_json(stdout)
535+
if parsed is not None:
536+
result = parsed.get("result", "") or ""
537+
usage_record = _extract_usage(parsed)
538+
else:
539+
result = stdout
540+
else:
541+
result = stdout
476542

477543
if stderr:
478544
stderr_lower = stderr.lower()
@@ -517,7 +583,17 @@ def run_agent(self, agent_type: str, task: str, timeout: int = 1800,
517583
timer.stop()
518584
self.log(f"Agent {agent_type} timed out ({timeout}s)", "WARN")
519585
stdout, _ = process.communicate()
520-
result = stdout
586+
# JSON envelope is usually missing on timeout (truncated mid-stream).
587+
# Try once; on failure fall back to raw text and let empty-run handle it.
588+
if self.model == "claude":
589+
parsed = _parse_claude_json(stdout)
590+
if parsed is not None:
591+
result = parsed.get("result", "") or ""
592+
usage_record = _extract_usage(parsed)
593+
else:
594+
result = stdout
595+
else:
596+
result = stdout
521597

522598
watchdog.stop()
523599
timer.stop()
@@ -589,31 +665,78 @@ def run_agent(self, agent_type: str, task: str, timeout: int = 1800,
589665
start_time = time.time()
590666
continue
591667
self.send_notification("Agent Error Failed", f"{agent_type}: {e}", priority="critical")
592-
self._agent_stats.append({
668+
err_stat = {
593669
"agent_type": agent_type,
594670
"elapsed_seconds": elapsed,
595671
"prompt_len": 0,
596672
"output_len": 0,
597673
"timestamp": datetime.now().isoformat(),
598674
"error": str(e),
599-
})
675+
# Zero-default cost fields so aggregation never sees missing keys
676+
"model": "",
677+
"input_tokens": 0,
678+
"output_tokens": 0,
679+
"cache_read_tokens": 0,
680+
"cache_creation_tokens": 0,
681+
"cost_usd": 0.0,
682+
"duration_api_ms": 0,
683+
}
684+
self._agent_stats.append(err_stat)
685+
try:
686+
self._write_cost_report()
687+
except Exception:
688+
pass
600689
return ""
601690

602691
timer.stop()
603692
self.log_step(f"{Icons.for_agent(agent_type)} {agent_styled(agent_type, f'[{agent_type}]')} completed ({elapsed}s)", "success")
604693

694+
# One-line cost summary (only when claude returned parseable usage)
695+
if usage_record:
696+
in_tok = usage_record["input_tokens"]
697+
out_tok = usage_record["output_tokens"]
698+
cr = usage_record["cache_read_tokens"]
699+
cc = usage_record["cache_creation_tokens"]
700+
cached_in = cr + cc
701+
total_in = in_tok + cached_in
702+
hit_pct = int(100 * cr / total_in) if total_in else 0
703+
self.log_step(
704+
f" 💰 ${usage_record['cost_usd']:.4f} "
705+
f"in:{_fmt_tok(in_tok)} out:{_fmt_tok(out_tok)} "
706+
f"cache:{_fmt_tok(cached_in)}({hit_pct}% hit)",
707+
"info"
708+
)
709+
605710
# Agent summary
606711
summary_items = self._summarize_agent_output(agent_type, result)
607712
if summary_items:
608713
self.log_summary_box(f"{agent_type.upper()} Summary", summary_items)
609714

610-
# Cost tracking
611-
self._agent_stats.append({
715+
# Cost tracking — extend with real token/cost when claude JSON was parsed
716+
stat = {
612717
"agent_type": agent_type,
613718
"elapsed_seconds": elapsed,
614719
"prompt_len": len(full_prompt),
615720
"output_len": len(result) if result else 0,
616721
"timestamp": datetime.now().isoformat(),
617-
})
722+
# Zero-defaults so cost_report aggregation never sees missing keys
723+
"model": "",
724+
"input_tokens": 0,
725+
"output_tokens": 0,
726+
"cache_read_tokens": 0,
727+
"cache_creation_tokens": 0,
728+
"cost_usd": 0.0,
729+
"duration_api_ms": 0,
730+
}
731+
if usage_record:
732+
stat.update(usage_record)
733+
self._agent_stats.append(stat)
734+
735+
# Live cost report — written after every agent so the webapp SSE stream
736+
# can pick up updates within ~2s. Failures here must never break the run.
737+
try:
738+
self._write_cost_report()
739+
except Exception as exc:
740+
self.log(f" cost report write failed: {exc}", "WARN")
618741

619742
return result

ark/pipeline.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1755,36 +1755,79 @@ def _send_dev_phase_telegram(self, event: str, current: int, total: int):
17551755
pass
17561756

17571757
def _write_cost_report(self):
1758-
"""Write per-agent and total cost/stats to cost_report.yaml."""
1758+
"""Write per-agent and total cost/stats to cost_report.yaml.
1759+
1760+
Called after every agent invocation so the webapp SSE stream can pick
1761+
up live updates within ~2s. Writes atomically (.tmp + os.replace) so
1762+
readers never see a partial file. Aggregates real token & USD fields
1763+
when the claude JSON envelope was parsed; falls back to character
1764+
counts otherwise.
1765+
"""
17591766
if not self._agent_stats:
17601767
return
17611768

1762-
# Aggregate per agent type
1769+
# Aggregate per agent type. Each bucket carries both legacy char-count
1770+
# fields (for backwards compat with telegram_daemon / older tests) and
1771+
# the new token + cost fields populated from claude JSON output.
17631772
by_type = {}
17641773
for stat in self._agent_stats:
17651774
atype = stat["agent_type"]
17661775
if atype not in by_type:
1767-
by_type[atype] = {"calls": 0, "total_seconds": 0, "total_prompt_len": 0, "total_output_len": 0}
1768-
by_type[atype]["calls"] += 1
1769-
by_type[atype]["total_seconds"] += stat.get("elapsed_seconds", 0)
1770-
by_type[atype]["total_prompt_len"] += stat.get("prompt_len", 0)
1771-
by_type[atype]["total_output_len"] += stat.get("output_len", 0)
1776+
by_type[atype] = {
1777+
"calls": 0,
1778+
"total_seconds": 0,
1779+
"total_prompt_len": 0,
1780+
"total_output_len": 0,
1781+
"total_input_tokens": 0,
1782+
"total_output_tokens": 0,
1783+
"total_cache_read_tokens": 0,
1784+
"total_cache_creation_tokens": 0,
1785+
"total_cost_usd": 0.0,
1786+
}
1787+
b = by_type[atype]
1788+
b["calls"] += 1
1789+
b["total_seconds"] += stat.get("elapsed_seconds", 0)
1790+
b["total_prompt_len"] += stat.get("prompt_len", 0)
1791+
b["total_output_len"] += stat.get("output_len", 0)
1792+
b["total_input_tokens"] += stat.get("input_tokens", 0)
1793+
b["total_output_tokens"] += stat.get("output_tokens", 0)
1794+
b["total_cache_read_tokens"] += stat.get("cache_read_tokens", 0)
1795+
b["total_cache_creation_tokens"] += stat.get("cache_creation_tokens", 0)
1796+
b["total_cost_usd"] += float(stat.get("cost_usd", 0.0) or 0.0)
17721797

17731798
total_calls = sum(d["calls"] for d in by_type.values())
17741799
total_time = sum(d["total_seconds"] for d in by_type.values())
1800+
total_cost_usd = sum(d["total_cost_usd"] for d in by_type.values())
1801+
total_input_tokens = sum(d["total_input_tokens"] for d in by_type.values())
1802+
total_output_tokens = sum(d["total_output_tokens"] for d in by_type.values())
1803+
total_cache_read_tokens = sum(d["total_cache_read_tokens"] for d in by_type.values())
1804+
total_cache_creation_tokens = sum(d["total_cache_creation_tokens"] for d in by_type.values())
17751805

17761806
report = {
17771807
"generated_at": datetime.now().isoformat(),
17781808
"total_agent_calls": total_calls,
17791809
"total_agent_seconds": total_time,
1810+
"total_cost_usd": round(total_cost_usd, 6),
1811+
"total_input_tokens": total_input_tokens,
1812+
"total_output_tokens": total_output_tokens,
1813+
"total_cache_read_tokens": total_cache_read_tokens,
1814+
"total_cache_creation_tokens": total_cache_creation_tokens,
17801815
"per_agent": by_type,
17811816
"raw_stats": self._agent_stats[-100:], # Keep last 100 entries
17821817
}
17831818

17841819
report_path = self.state_dir / "cost_report.yaml"
1785-
with open(report_path, "w") as f:
1786-
yaml.dump(report, f, default_flow_style=False, allow_unicode=True)
1787-
self.log(f"Cost report written: {report_path} ({total_calls} calls, {total_time}s total)", "INFO")
1820+
tmp_path = report_path.with_suffix(".yaml.tmp")
1821+
try:
1822+
with open(tmp_path, "w") as f:
1823+
yaml.dump(report, f, default_flow_style=False, allow_unicode=True)
1824+
os.replace(tmp_path, report_path)
1825+
except Exception:
1826+
try:
1827+
tmp_path.unlink(missing_ok=True)
1828+
except Exception:
1829+
pass
1830+
raise
17881831

17891832
def run(self):
17901833
"""Main loop."""

ark/webapp/routes.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,33 @@ def _read_current_iteration(project_dir: Path) -> int:
483483
return 0
484484

485485

486+
def _read_cost_report(project_dir: Path) -> dict:
487+
"""Read cost_report.yaml and return a slim summary suitable for the UI.
488+
489+
Drops `raw_stats` (last-100 entries) to keep the payload small for SSE.
490+
Returns an empty dict if the file is missing or unreadable — the UI
491+
treats that as "no cost data yet".
492+
"""
493+
p = project_dir / "auto_research" / "state" / "cost_report.yaml"
494+
if not p.exists():
495+
return {}
496+
try:
497+
d = yaml.safe_load(p.read_text()) or {}
498+
except Exception:
499+
return {}
500+
return {
501+
"total_cost_usd": d.get("total_cost_usd", 0),
502+
"total_input_tokens": d.get("total_input_tokens", 0),
503+
"total_output_tokens": d.get("total_output_tokens", 0),
504+
"total_cache_read_tokens": d.get("total_cache_read_tokens", 0),
505+
"total_cache_creation_tokens": d.get("total_cache_creation_tokens", 0),
506+
"total_agent_calls": d.get("total_agent_calls", 0),
507+
"total_agent_seconds": d.get("total_agent_seconds", 0),
508+
"per_agent": d.get("per_agent", {}),
509+
"generated_at": d.get("generated_at"),
510+
}
511+
512+
486513
_TEMPLATE_TITLES = {"Paper Title", "Title Text", "Insert Title Here", ""}
487514

488515
def _read_paper_title(project_dir: Path) -> str:
@@ -1146,6 +1173,7 @@ async def api_get_project(project_id: str, request: Request):
11461173
"environment": "ROCS Testbed" if project.slurm_job_id and not project.slurm_job_id.startswith("local") else "Local",
11471174
"created_at": project.created_at.isoformat(),
11481175
"updated_at": project.updated_at.isoformat(),
1176+
"cost_report": _read_cost_report(pdir),
11491177
})
11501178

11511179

@@ -1524,12 +1552,17 @@ async def event_generator():
15241552
except Exception:
15251553
pass
15261554

1527-
# Also emit status
1555+
# Also emit status (includes live cost report so the dashboard
1556+
# cost panel updates within ~2s of every agent completion)
15281557
with get_session(settings.db_path) as session:
15291558
p = get_project(session, project_id)
15301559
if p:
15311560
score = _read_project_score(pdir)
1532-
payload = json.dumps({"status": p.status, "score": score})
1561+
payload = json.dumps({
1562+
"status": p.status,
1563+
"score": score,
1564+
"cost_report": _read_cost_report(pdir),
1565+
})
15331566
yield f"event: status\ndata: {payload}\n\n"
15341567

15351568
await asyncio.sleep(2)

0 commit comments

Comments
 (0)