1- import glob
21import json
32import os
43import sqlite3
54from datetime import datetime
5+ from pathlib import Path
66
77from huggingface_hub import CommitScheduler
88
9- try :
9+ try : # absolute imports when installed
1010 from trackio .context_vars import current_scheduler
1111 from trackio .dummy_commit_scheduler import DummyCommitScheduler
1212 from trackio .utils import TRACKIO_DIR
13- except : # noqa: E722
13+ except Exception : # relative imports for local execution on Spaces
1414 from context_vars import current_scheduler
1515 from dummy_commit_scheduler import DummyCommitScheduler
1616 from utils import TRACKIO_DIR
1717
1818
1919class SQLiteStorage :
2020 @staticmethod
21- def get_project_db_path (project : str ) -> str :
21+ def _get_connection (db_path : Path ) -> sqlite3 .Connection :
22+ conn = sqlite3 .connect (str (db_path ))
23+ conn .row_factory = sqlite3 .Row
24+ return conn
25+
26+ @staticmethod
27+ def get_project_db_path (project : str ) -> Path :
2228 """Get the database path for a specific project."""
2329 safe_project_name = "" .join (
2430 c for c in project if c .isalnum () or c in ("-" , "_" )
2531 ).rstrip ()
2632 if not safe_project_name :
2733 safe_project_name = "default"
28- return os . path . join ( TRACKIO_DIR , f"{ safe_project_name } .db" )
34+ return TRACKIO_DIR / f"{ safe_project_name } .db"
2935
3036 @staticmethod
31- def init_db (project : str ) -> str :
37+ def init_db (project : str ) -> Path :
3238 """
3339 Initialize the SQLite database with required tables.
3440 Returns the database path.
3541 """
3642 db_path = SQLiteStorage .get_project_db_path (project )
37- os . makedirs ( os . path . dirname ( db_path ) , exist_ok = True )
43+ db_path . parent . mkdir ( parents = True , exist_ok = True )
3844 with SQLiteStorage .get_scheduler ().lock :
3945 with sqlite3 .connect (db_path ) as conn :
4046 cursor = conn .cursor ()
4147 cursor .execute ("""
4248 CREATE TABLE IF NOT EXISTS metrics (
4349 id INTEGER PRIMARY KEY AUTOINCREMENT,
4450 timestamp TEXT NOT NULL,
45- project_name TEXT NOT NULL,
4651 run_name TEXT NOT NULL,
4752 step INTEGER NOT NULL,
4853 metrics TEXT NOT NULL
4954 )
5055 """ )
56+ cursor .execute (
57+ """
58+ CREATE INDEX IF NOT EXISTS idx_metrics_run_step
59+ ON metrics(run_name, step)
60+ """
61+ )
5162 conn .commit ()
5263 return db_path
5364
@@ -85,16 +96,16 @@ def log(project: str, run: str, metrics: dict):
8596 db_path = SQLiteStorage .init_db (project )
8697
8798 with SQLiteStorage .get_scheduler ().lock :
88- with sqlite3 . connect (db_path ) as conn :
99+ with SQLiteStorage . _get_connection (db_path ) as conn :
89100 cursor = conn .cursor ()
90101
91102 cursor .execute (
92103 """
93104 SELECT MAX(step)
94105 FROM metrics
95- WHERE project_name = ? AND run_name = ?
106+ WHERE run_name = ?
96107 """ ,
97- (project , run ),
108+ (run , ),
98109 )
99110 last_step = cursor .fetchone ()[0 ]
100111 current_step = 0 if last_step is None else last_step + 1
@@ -103,13 +114,12 @@ def log(project: str, run: str, metrics: dict):
103114
104115 cursor .execute (
105116 """
106- INSERT INTO metrics
107- (timestamp, project_name, run_name, step, metrics)
108- VALUES (?, ?, ?, ?, ? )
117+ INSERT INTO metrics
118+ (timestamp, run_name, step, metrics)
119+ VALUES (?, ?, ?, ?)
109120 """ ,
110121 (
111122 current_timestamp ,
112- project ,
113123 run ,
114124 current_step ,
115125 json .dumps (metrics ),
@@ -142,15 +152,14 @@ def bulk_log(
142152
143153 db_path = SQLiteStorage .init_db (project )
144154 with SQLiteStorage .get_scheduler ().lock :
145- with sqlite3 . connect (db_path ) as conn :
155+ with SQLiteStorage . _get_connection (db_path ) as conn :
146156 cursor = conn .cursor ()
147157
148158 data = []
149159 for i , metrics in enumerate (metrics_list ):
150160 data .append (
151161 (
152162 timestamps [i ],
153- project ,
154163 run ,
155164 steps [i ],
156165 json .dumps (metrics ),
@@ -159,9 +168,9 @@ def bulk_log(
159168
160169 cursor .executemany (
161170 """
162- INSERT INTO metrics
163- (timestamp, project_name, run_name, step, metrics)
164- VALUES (?, ?, ?, ?, ? )
171+ INSERT INTO metrics
172+ (timestamp, run_name, step, metrics)
173+ VALUES (?, ?, ?, ?)
165174 """ ,
166175 data ,
167176 )
@@ -171,71 +180,85 @@ def bulk_log(
171180 def get_metrics (project : str , run : str ) -> list [dict ]:
172181 """Retrieve metrics for a specific run. The metrics also include the step count (int) and the timestamp (datetime object)."""
173182 db_path = SQLiteStorage .get_project_db_path (project )
174- if not os . path . exists (db_path ):
183+ if not db_path . exists ():
175184 return []
176185
177- with sqlite3 . connect (db_path ) as conn :
186+ with SQLiteStorage . _get_connection (db_path ) as conn :
178187 cursor = conn .cursor ()
179188 cursor .execute (
180189 """
181190 SELECT timestamp, step, metrics
182191 FROM metrics
183- WHERE project_name = ? AND run_name = ?
192+ WHERE run_name = ?
184193 ORDER BY timestamp
185194 """ ,
186- (project , run ),
195+ (run , ),
187196 )
188- rows = cursor .fetchall ()
189197
198+ rows = cursor .fetchall ()
190199 results = []
191200 for row in rows :
192- timestamp , step , metrics_json = row
193- metrics = json .loads (metrics_json )
194- metrics ["timestamp" ] = timestamp
195- metrics ["step" ] = step
201+ metrics = json .loads (row ["metrics" ])
202+ metrics ["timestamp" ] = row ["timestamp" ]
203+ metrics ["step" ] = row ["step" ]
196204 results .append (metrics )
205+
197206 return results
198207
199208 @staticmethod
200209 def get_projects () -> list [str ]:
201- """Get list of all projects by scanning database files."""
202- projects = []
203- if not os .path .exists (TRACKIO_DIR ):
204- return projects
205-
206- db_files = glob .glob (os .path .join (TRACKIO_DIR , "*.db" ))
207-
208- for db_file in db_files :
209- try :
210- with sqlite3 .connect (db_file ) as conn :
211- cursor = conn .cursor ()
212- cursor .execute (
213- "SELECT name FROM sqlite_master WHERE type='table' AND name='metrics'"
214- )
215- if cursor .fetchone ():
216- cursor .execute ("SELECT DISTINCT project_name FROM metrics" )
217- project_names = [row [0 ] for row in cursor .fetchall ()]
218- projects .extend (project_names )
219- except sqlite3 .Error :
220- continue
210+ """
211+ Get list of all projects by scanning the database files in the trackio directory.
212+ """
213+ projects : set [str ] = set ()
214+ if not TRACKIO_DIR .exists ():
215+ return []
221216
222- return list (set (projects ))
217+ for db_file in TRACKIO_DIR .glob ("*.db" ):
218+ project_name = db_file .stem
219+ projects .add (project_name )
220+ return sorted (projects )
223221
224222 @staticmethod
225223 def get_runs (project : str ) -> list [str ]:
226224 """Get list of all runs for a project."""
227225 db_path = SQLiteStorage .get_project_db_path (project )
228- if not os . path . exists (db_path ):
226+ if not db_path . exists ():
229227 return []
230228
231- with sqlite3 . connect (db_path ) as conn :
229+ with SQLiteStorage . _get_connection (db_path ) as conn :
232230 cursor = conn .cursor ()
233231 cursor .execute (
234- "SELECT DISTINCT run_name FROM metrics WHERE project_name = ?" ,
235- (project ,),
232+ "SELECT DISTINCT run_name FROM metrics" ,
236233 )
237234 return [row [0 ] for row in cursor .fetchall ()]
238235
236+ @staticmethod
237+ def get_max_steps_for_runs (project : str , runs : list [str ]) -> dict [str , int ]:
238+ """Efficiently get the maximum step for multiple runs in a single query."""
239+ db_path = SQLiteStorage .get_project_db_path (project )
240+ if not db_path .exists ():
241+ return {run : 0 for run in runs }
242+
243+ with SQLiteStorage ._get_connection (db_path ) as conn :
244+ cursor = conn .cursor ()
245+ placeholders = "," .join ("?" * len (runs ))
246+ cursor .execute (
247+ f"""
248+ SELECT run_name, MAX(step) as max_step
249+ FROM metrics
250+ WHERE run_name IN ({ placeholders } )
251+ GROUP BY run_name
252+ """ ,
253+ runs ,
254+ )
255+
256+ results = {run : 0 for run in runs } # Default to 0 for runs with no data
257+ for row in cursor .fetchall ():
258+ results [row ["run_name" ]] = row ["max_step" ]
259+
260+ return results
261+
239262 def finish (self ):
240263 """Cleanup when run is finished."""
241264 pass
0 commit comments