Skip to content

Commit 1983398

Browse files
stabganabidlabs
andauthored
Optimizations for the exponential backoff and for indexing the sqlite db (#70)
* Refactor utils and improve imports * Optimize database and UI performance * Format run and sqlite_storage * revert utils changes for now * revert slots * revert imports changes * revert ui changes * changes * readme * changes * projects * docstring * changes * changes * rm project name from db * use Path * changes * remove cache * revert * changes * use index * fix tests * version bump --------- Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
1 parent 6670a12 commit 1983398

10 files changed

Lines changed: 95 additions & 82 deletions

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,12 @@ To get started and see basic examples of usage, see these files:
145145
* [Persisting metrics in a Hugging Face Dataset](https://github.com/gradio-app/trackio/blob/main/examples/persist-dataset.py)
146146
* [Deploying the dashboard to Spaces](https://github.com/gradio-app/trackio/blob/main/examples/deploy-on-spaces.py)
147147

148+
## Note: Trackio is in Beta (DB Schema May Change)
149+
150+
Note that Trackio is in pre-release right now and we may release breaking changes. In particular, the schema of the Trackio sqlite database may change, which may require migrating or deleting existing database files (located by default at: `~/.cache/huggingface/trackio`).
151+
152+
Since Trackio is in beta, your feedback is welcome! Please create issues with bug reports or feature requests.
153+
148154
## License
149155

150156
MIT License
67 Bytes
Binary file not shown.
19 Bytes
Binary file not shown.
Binary file not shown.

tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import tempfile
2+
from pathlib import Path
23

34
import pytest
45

@@ -7,5 +8,5 @@
78
def temp_db(monkeypatch):
89
"""Fixture that creates a temporary directory for database storage and patches the TRACKIO_DIR."""
910
with tempfile.TemporaryDirectory() as tmpdir:
10-
monkeypatch.setattr("trackio.sqlite_storage.TRACKIO_DIR", tmpdir)
11+
monkeypatch.setattr("trackio.sqlite_storage.TRACKIO_DIR", Path(tmpdir))
1112
yield tmpdir

trackio/deploy.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -128,16 +128,14 @@ def wait_until_space_exists(
128128
Args:
129129
space_id: The ID of the Space to wait for.
130130
"""
131-
client = None
132-
for _ in range(30):
131+
delay = 1
132+
for _ in range(10):
133133
try:
134-
client = Client(space_id, verbose=False)
135-
if client:
136-
break
137-
except ReadTimeout:
138-
time.sleep(5)
139-
except ValueError:
140-
time.sleep(5)
134+
Client(space_id, verbose=False)
135+
return
136+
except (ReadTimeout, ValueError):
137+
time.sleep(delay)
138+
delay = min(delay * 2, 30)
141139
raise TimeoutError("Waiting for space to exist took longer than expected")
142140

143141

trackio/sqlite_storage.py

Lines changed: 77 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,64 @@
1-
import glob
21
import json
32
import os
43
import sqlite3
54
from datetime import datetime
5+
from pathlib import Path
66

77
from huggingface_hub import CommitScheduler
88

9-
try:
9+
try: # absolute imports when installed
1010
from trackio.context_vars import current_scheduler
1111
from trackio.dummy_commit_scheduler import DummyCommitScheduler
1212
from trackio.utils import TRACKIO_DIR
13-
except: # noqa: E722
13+
except Exception: # relative imports for local execution on Spaces
1414
from context_vars import current_scheduler
1515
from dummy_commit_scheduler import DummyCommitScheduler
1616
from utils import TRACKIO_DIR
1717

1818

1919
class SQLiteStorage:
2020
@staticmethod
21-
def get_project_db_path(project: str) -> str:
21+
def _get_connection(db_path: Path) -> sqlite3.Connection:
22+
conn = sqlite3.connect(str(db_path))
23+
conn.row_factory = sqlite3.Row
24+
return conn
25+
26+
@staticmethod
27+
def get_project_db_path(project: str) -> Path:
2228
"""Get the database path for a specific project."""
2329
safe_project_name = "".join(
2430
c for c in project if c.isalnum() or c in ("-", "_")
2531
).rstrip()
2632
if not safe_project_name:
2733
safe_project_name = "default"
28-
return os.path.join(TRACKIO_DIR, f"{safe_project_name}.db")
34+
return TRACKIO_DIR / f"{safe_project_name}.db"
2935

3036
@staticmethod
31-
def init_db(project: str) -> str:
37+
def init_db(project: str) -> Path:
3238
"""
3339
Initialize the SQLite database with required tables.
3440
Returns the database path.
3541
"""
3642
db_path = SQLiteStorage.get_project_db_path(project)
37-
os.makedirs(os.path.dirname(db_path), exist_ok=True)
43+
db_path.parent.mkdir(parents=True, exist_ok=True)
3844
with SQLiteStorage.get_scheduler().lock:
3945
with sqlite3.connect(db_path) as conn:
4046
cursor = conn.cursor()
4147
cursor.execute("""
4248
CREATE TABLE IF NOT EXISTS metrics (
4349
id INTEGER PRIMARY KEY AUTOINCREMENT,
4450
timestamp TEXT NOT NULL,
45-
project_name TEXT NOT NULL,
4651
run_name TEXT NOT NULL,
4752
step INTEGER NOT NULL,
4853
metrics TEXT NOT NULL
4954
)
5055
""")
56+
cursor.execute(
57+
"""
58+
CREATE INDEX IF NOT EXISTS idx_metrics_run_step
59+
ON metrics(run_name, step)
60+
"""
61+
)
5162
conn.commit()
5263
return db_path
5364

@@ -85,16 +96,16 @@ def log(project: str, run: str, metrics: dict):
8596
db_path = SQLiteStorage.init_db(project)
8697

8798
with SQLiteStorage.get_scheduler().lock:
88-
with sqlite3.connect(db_path) as conn:
99+
with SQLiteStorage._get_connection(db_path) as conn:
89100
cursor = conn.cursor()
90101

91102
cursor.execute(
92103
"""
93104
SELECT MAX(step)
94105
FROM metrics
95-
WHERE project_name = ? AND run_name = ?
106+
WHERE run_name = ?
96107
""",
97-
(project, run),
108+
(run,),
98109
)
99110
last_step = cursor.fetchone()[0]
100111
current_step = 0 if last_step is None else last_step + 1
@@ -103,13 +114,12 @@ def log(project: str, run: str, metrics: dict):
103114

104115
cursor.execute(
105116
"""
106-
INSERT INTO metrics
107-
(timestamp, project_name, run_name, step, metrics)
108-
VALUES (?, ?, ?, ?, ?)
117+
INSERT INTO metrics
118+
(timestamp, run_name, step, metrics)
119+
VALUES (?, ?, ?, ?)
109120
""",
110121
(
111122
current_timestamp,
112-
project,
113123
run,
114124
current_step,
115125
json.dumps(metrics),
@@ -142,15 +152,14 @@ def bulk_log(
142152

143153
db_path = SQLiteStorage.init_db(project)
144154
with SQLiteStorage.get_scheduler().lock:
145-
with sqlite3.connect(db_path) as conn:
155+
with SQLiteStorage._get_connection(db_path) as conn:
146156
cursor = conn.cursor()
147157

148158
data = []
149159
for i, metrics in enumerate(metrics_list):
150160
data.append(
151161
(
152162
timestamps[i],
153-
project,
154163
run,
155164
steps[i],
156165
json.dumps(metrics),
@@ -159,9 +168,9 @@ def bulk_log(
159168

160169
cursor.executemany(
161170
"""
162-
INSERT INTO metrics
163-
(timestamp, project_name, run_name, step, metrics)
164-
VALUES (?, ?, ?, ?, ?)
171+
INSERT INTO metrics
172+
(timestamp, run_name, step, metrics)
173+
VALUES (?, ?, ?, ?)
165174
""",
166175
data,
167176
)
@@ -171,71 +180,85 @@ def bulk_log(
171180
def get_metrics(project: str, run: str) -> list[dict]:
172181
"""Retrieve metrics for a specific run. The metrics also include the step count (int) and the timestamp (datetime object)."""
173182
db_path = SQLiteStorage.get_project_db_path(project)
174-
if not os.path.exists(db_path):
183+
if not db_path.exists():
175184
return []
176185

177-
with sqlite3.connect(db_path) as conn:
186+
with SQLiteStorage._get_connection(db_path) as conn:
178187
cursor = conn.cursor()
179188
cursor.execute(
180189
"""
181190
SELECT timestamp, step, metrics
182191
FROM metrics
183-
WHERE project_name = ? AND run_name = ?
192+
WHERE run_name = ?
184193
ORDER BY timestamp
185194
""",
186-
(project, run),
195+
(run,),
187196
)
188-
rows = cursor.fetchall()
189197

198+
rows = cursor.fetchall()
190199
results = []
191200
for row in rows:
192-
timestamp, step, metrics_json = row
193-
metrics = json.loads(metrics_json)
194-
metrics["timestamp"] = timestamp
195-
metrics["step"] = step
201+
metrics = json.loads(row["metrics"])
202+
metrics["timestamp"] = row["timestamp"]
203+
metrics["step"] = row["step"]
196204
results.append(metrics)
205+
197206
return results
198207

199208
@staticmethod
200209
def get_projects() -> list[str]:
201-
"""Get list of all projects by scanning database files."""
202-
projects = []
203-
if not os.path.exists(TRACKIO_DIR):
204-
return projects
205-
206-
db_files = glob.glob(os.path.join(TRACKIO_DIR, "*.db"))
207-
208-
for db_file in db_files:
209-
try:
210-
with sqlite3.connect(db_file) as conn:
211-
cursor = conn.cursor()
212-
cursor.execute(
213-
"SELECT name FROM sqlite_master WHERE type='table' AND name='metrics'"
214-
)
215-
if cursor.fetchone():
216-
cursor.execute("SELECT DISTINCT project_name FROM metrics")
217-
project_names = [row[0] for row in cursor.fetchall()]
218-
projects.extend(project_names)
219-
except sqlite3.Error:
220-
continue
210+
"""
211+
Get list of all projects by scanning the database files in the trackio directory.
212+
"""
213+
projects: set[str] = set()
214+
if not TRACKIO_DIR.exists():
215+
return []
221216

222-
return list(set(projects))
217+
for db_file in TRACKIO_DIR.glob("*.db"):
218+
project_name = db_file.stem
219+
projects.add(project_name)
220+
return sorted(projects)
223221

224222
@staticmethod
225223
def get_runs(project: str) -> list[str]:
226224
"""Get list of all runs for a project."""
227225
db_path = SQLiteStorage.get_project_db_path(project)
228-
if not os.path.exists(db_path):
226+
if not db_path.exists():
229227
return []
230228

231-
with sqlite3.connect(db_path) as conn:
229+
with SQLiteStorage._get_connection(db_path) as conn:
232230
cursor = conn.cursor()
233231
cursor.execute(
234-
"SELECT DISTINCT run_name FROM metrics WHERE project_name = ?",
235-
(project,),
232+
"SELECT DISTINCT run_name FROM metrics",
236233
)
237234
return [row[0] for row in cursor.fetchall()]
238235

236+
@staticmethod
237+
def get_max_steps_for_runs(project: str, runs: list[str]) -> dict[str, int]:
238+
"""Efficiently get the maximum step for multiple runs in a single query."""
239+
db_path = SQLiteStorage.get_project_db_path(project)
240+
if not db_path.exists():
241+
return {run: 0 for run in runs}
242+
243+
with SQLiteStorage._get_connection(db_path) as conn:
244+
cursor = conn.cursor()
245+
placeholders = ",".join("?" * len(runs))
246+
cursor.execute(
247+
f"""
248+
SELECT run_name, MAX(step) as max_step
249+
FROM metrics
250+
WHERE run_name IN ({placeholders})
251+
GROUP BY run_name
252+
""",
253+
runs,
254+
)
255+
256+
results = {run: 0 for run in runs} # Default to 0 for runs with no data
257+
for row in cursor.fetchall():
258+
results[row["run_name"]] = row["max_step"]
259+
260+
return results
261+
239262
def finish(self):
240263
"""Cleanup when run is finished."""
241264
pass

trackio/ui.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -408,21 +408,7 @@ def update_last_steps(project, runs):
408408
if not project or not runs:
409409
return {}
410410

411-
last_steps = {}
412-
for run in runs:
413-
metrics = SQLiteStorage.get_metrics(project, run)
414-
if metrics:
415-
df = pd.DataFrame(metrics)
416-
if "step" not in df.columns:
417-
df["step"] = range(len(df))
418-
if not df.empty:
419-
last_steps[run] = df["step"].max().item()
420-
else:
421-
last_steps[run] = 0
422-
else:
423-
last_steps[run] = 0
424-
425-
return last_steps
411+
return SQLiteStorage.get_max_steps_for_runs(project, runs)
426412

427413
timer.tick(
428414
fn=update_last_steps,

trackio/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
import random
32
import re
43
import sys
@@ -9,7 +8,7 @@
98
from huggingface_hub.constants import HF_HOME
109

1110
RESERVED_KEYS = ["project", "run", "timestamp", "step", "time"]
12-
TRACKIO_DIR = os.path.join(HF_HOME, "trackio")
11+
TRACKIO_DIR = Path(HF_HOME) / "trackio"
1312

1413
TRACKIO_LOGO_PATH = str(Path(__file__).parent.joinpath("trackio_logo.png"))
1514

trackio/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.0.19
1+
0.0.20

0 commit comments

Comments
 (0)