Skip to content

Commit 06ea885

Browse files
abidlabs, claude, and gradio-pr-bot
authored
Fix SQLite corruption on bucket-mounted Spaces (#501)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
1 parent af23d74 commit 06ea885

6 files changed

Lines changed: 120 additions & 43 deletions

File tree

.changeset/rare-olives-find.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"trackio": patch
3+
---
4+
5+
feat:Fix SQLite corruption on bucket-mounted Spaces

tests/unit/test_utils.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -159,25 +159,13 @@ def test_trackio_dir_env_var(monkeypatch):
159159
test_path = str(tmpdir)
160160

161161
monkeypatch.setenv("TRACKIO_DIR", test_path)
162-
monkeypatch.delenv("PERSISTANT_STORAGE_ENABLED", raising=False)
163162
result_dir = utils._get_trackio_dir()
164163
assert str(result_dir) == test_path
165164

166165
monkeypatch.delenv("TRACKIO_DIR", raising=False)
167-
monkeypatch.delenv("PERSISTANT_STORAGE_ENABLED", raising=False)
168166
result_dir = utils._get_trackio_dir()
169167
assert "huggingface/trackio" in Path(result_dir).as_posix()
170168

171-
monkeypatch.delenv("TRACKIO_DIR", raising=False)
172-
monkeypatch.setenv("PERSISTANT_STORAGE_ENABLED", "true")
173-
result_dir = utils._get_trackio_dir()
174-
assert Path(result_dir).as_posix() == "/data/trackio"
175-
176-
monkeypatch.setenv("TRACKIO_DIR", test_path)
177-
monkeypatch.setenv("PERSISTANT_STORAGE_ENABLED", "true")
178-
result_dir = utils._get_trackio_dir()
179-
assert Path(result_dir).as_posix() == "/data/trackio"
180-
181169

182170
def test_plot_ordering():
183171
"""Test that TRACKIO_PLOT_ORDER environment variable correctly orders metrics."""

trackio/deploy.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from trackio.utils import (
3636
MEDIA_DIR,
3737
get_or_create_project_hash,
38+
on_spaces,
3839
preprocess_space_and_dataset_ids,
3940
)
4041

@@ -157,9 +158,7 @@ def deploy_as_space(
157158
bucket_id: str | None = None,
158159
private: bool | None = None,
159160
):
160-
if (
161-
os.getenv("SYSTEM") == "spaces"
162-
): # in case a repo with this function is uploaded to spaces
161+
if on_spaces(): # in case a repo with this function is uploaded to spaces
163162
return
164163

165164
if dataset_id is not None and bucket_id is not None:
@@ -674,7 +673,7 @@ def deploy_as_static_space(
674673
private: bool | None = None,
675674
hf_token: str | None = None,
676675
) -> None:
677-
if os.getenv("SYSTEM") == "spaces":
676+
if on_spaces():
678677
return
679678

680679
hf_api = huggingface_hub.HfApi()

trackio/server.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from trackio.media import get_project_media_path
2525
from trackio.sqlite_storage import SQLiteStorage
2626
from trackio.typehints import AlertEntry, LogEntry, SystemLogEntry, UploadEntry
27+
from trackio.utils import on_spaces
2728

2829
HfApi = hf.HfApi()
2930

@@ -207,28 +208,28 @@ def oauth_hf_callback(request: Request):
207208
return RedirectResponse(url=err, status_code=302)
208209
session_id = secrets.token_urlsafe(32)
209210
_oauth_sessions[session_id] = (access_token, time.monotonic())
210-
on_spaces = os.getenv("SYSTEM") == "spaces"
211+
_on_spaces = on_spaces()
211212
resp = RedirectResponse(url=f"/?oauth_session={session_id}", status_code=302)
212213
resp.set_cookie(
213214
key="trackio_hf_access_token",
214215
value=access_token,
215216
httponly=True,
216-
samesite="none" if on_spaces else "lax",
217+
samesite="none" if _on_spaces else "lax",
217218
max_age=86400 * 30,
218219
path="/",
219-
secure=on_spaces,
220+
secure=_on_spaces,
220221
)
221222
return resp
222223

223224

224225
def oauth_logout(request: Request):
225-
on_spaces = os.getenv("SYSTEM") == "spaces"
226+
_on_spaces = on_spaces()
226227
resp = RedirectResponse(url="/", status_code=302)
227228
resp.delete_cookie(
228229
"trackio_hf_access_token",
229230
path="/",
230-
samesite="none" if on_spaces else "lax",
231-
secure=on_spaces,
231+
samesite="none" if _on_spaces else "lax",
232+
secure=_on_spaces,
232233
)
233234
return resp
234235

@@ -243,7 +244,7 @@ def check_hf_token_has_write_access(hf_token: str | None) -> None:
243244
- A cache of the whoami response for the hf_token using .whoami(token=hf_token, cache=True).
244245
- This entire function is cached using @lru_cache(maxsize=32).
245246
"""
246-
if os.getenv("SYSTEM") == "spaces":
247+
if on_spaces():
247248
if hf_token is None:
248249
raise PermissionError(
249250
"Expected a HF_TOKEN to be provided when logging to a Space"
@@ -297,7 +298,7 @@ def check_hf_token_has_write_access(hf_token: str | None) -> None:
297298

298299

299300
def check_oauth_token_has_write_access(oauth_token: str | None) -> None:
300-
if not os.getenv("SYSTEM") == "spaces":
301+
if not on_spaces():
301302
return
302303
if oauth_token is None:
303304
raise PermissionError(
@@ -343,7 +344,7 @@ def check_write_access(request: gr.Request, token: str) -> bool:
343344

344345

345346
def assert_can_mutate_runs(request: gr.Request) -> None:
346-
if os.getenv("SYSTEM") != "spaces":
347+
if not on_spaces():
347348
if check_write_access(request, write_token):
348349
return
349350
raise gr.Error(
@@ -366,7 +367,7 @@ def assert_can_mutate_runs(request: gr.Request) -> None:
366367

367368

368369
def get_run_mutation_status(request: gr.Request) -> dict[str, Any]:
369-
if os.getenv("SYSTEM") != "spaces":
370+
if not on_spaces():
370371
if check_write_access(request, write_token):
371372
return {"spaces": False, "allowed": True, "auth": "local"}
372373
return {"spaces": False, "allowed": False, "auth": "none"}

trackio/sqlite_storage.py

Lines changed: 99 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import atexit
12
import json as json_mod
23
import os
34
import shutil
@@ -30,6 +31,7 @@
3031
TRACKIO_DIR,
3132
deserialize_values,
3233
get_color_palette,
34+
on_spaces,
3335
serialize_values,
3436
)
3537

@@ -44,24 +46,92 @@ def _configure_sqlite_pragmas(conn: sqlite3.Connection) -> None:
4446
override = os.environ.get("TRACKIO_SQLITE_JOURNAL_MODE", "").strip().lower()
4547
if override in _JOURNAL_MODE_WHITELIST:
4648
journal = override.upper()
47-
elif os.environ.get("SYSTEM") == "spaces":
49+
elif on_spaces():
4850
journal = "DELETE"
4951
else:
5052
journal = "WAL"
5153
conn.execute(f"PRAGMA journal_mode = {journal}")
5254
conn.execute("PRAGMA synchronous = NORMAL")
5355
conn.execute("PRAGMA temp_store = MEMORY")
5456
conn.execute("PRAGMA cache_size = -20000")
57+
if on_spaces():
58+
conn.execute("PRAGMA locking_mode = EXCLUSIVE")
59+
60+
61+
# Long-lived SQLite connections, keyed by the string form of the database path.
_persistent_connections: dict[str, sqlite3.Connection] = {}
# Guards mutation of both registries in this section (connections and locks).
_persistent_lock = Lock()
# One access lock per database path, created lazily on first request.
_db_access_locks: dict[str, Lock] = {}


def _get_db_access_lock(db_path: Path) -> Lock:
    """Return the access lock registered for ``db_path``, creating it if needed.

    The registry is keyed by ``str(db_path)``, so repeated calls with the same
    path always hand back the same lock object. The lookup and the insertion
    happen while holding ``_persistent_lock``.
    """
    lock_key = str(db_path)
    with _persistent_lock:
        found = _db_access_locks.get(lock_key)
        if found is None:
            found = Lock()
            _db_access_locks[lock_key] = found
        return found
72+
73+
74+
def _get_or_create_persistent_conn(
    db_path: Path, timeout: float = 30.0
) -> sqlite3.Connection:
    """Return the cached connection for ``db_path``, (re)creating it if needed.

    A cached connection is probed with ``SELECT 1`` before being handed out.
    If the probe fails, the stale connection is closed (best effort), evicted
    from the cache, and a fresh one is opened in its place.

    Args:
        db_path: Path of the SQLite database file; ``str(db_path)`` is the cache key.
        timeout: Busy timeout in seconds passed to ``sqlite3.connect``.

    Returns:
        A pragma-configured connection opened with ``check_same_thread=False``.

    Raises:
        sqlite3.Error: If a new connection cannot be opened, configured or probed.
            In that case the half-initialised connection is closed, not cached.
    """
    key = str(db_path)
    with _persistent_lock:
        conn = _persistent_connections.get(key)
        if conn is not None:
            try:
                conn.execute("SELECT 1")
                return conn
            except sqlite3.Error:
                # The cached handle is unusable: evict it and fall through to
                # open a replacement. Closing is best effort on a broken handle.
                _persistent_connections.pop(key, None)
                try:
                    conn.close()
                except sqlite3.Error:
                    pass
        conn = sqlite3.connect(str(db_path), timeout=timeout, check_same_thread=False)
        try:
            _configure_sqlite_pragmas(conn)
            conn.execute("SELECT 1")
        except Exception:
            # Do not leak the handle when setup fails: it was never cached, so
            # nothing else would ever close it.
            try:
                conn.close()
            except sqlite3.Error:
                pass
            raise
        _persistent_connections[key] = conn
        return conn
95+
96+
97+
def _close_all_persistent_connections() -> None:
    """Close every cached persistent connection, then empty the cache.

    ``sqlite3.Error`` raised while closing an individual connection is
    swallowed so the remaining connections still get closed. The whole
    operation runs while holding ``_persistent_lock``.
    """
    with _persistent_lock:
        for handle in _persistent_connections.values():
            try:
                handle.close()
            except sqlite3.Error:
                # Best effort: keep closing the rest.
                continue
        _persistent_connections.clear()


# Make sure cached connections are closed when the interpreter shuts down.
atexit.register(_close_all_persistent_connections)
55108

56109

57110
class ProcessLock:
58-
"""A file-based lock that works across processes using fcntl (Unix) or msvcrt (Windows)."""
111+
"""Lock used to coordinate database access.
112+
113+
Normally uses file-based locking for cross-process coordination. When running
114+
on a bucket-mounted filesystem where file locks are unreliable,
115+
falls back to an in-memory threading Lock (single-process only)."""
116+
117+
_thread_locks: dict[str, Lock] = {}
118+
_meta_lock = Lock()
59119

60120
def __init__(self, lockfile_path: Path):
61121
self.lockfile_path = lockfile_path
62122
self.lockfile = None
123+
self._use_thread_lock = on_spaces()
124+
if self._use_thread_lock:
125+
key = str(lockfile_path)
126+
with ProcessLock._meta_lock:
127+
if key not in ProcessLock._thread_locks:
128+
ProcessLock._thread_locks[key] = Lock()
129+
self._thread_lock = ProcessLock._thread_locks[key]
63130

64131
def __enter__(self):
132+
if self._use_thread_lock:
133+
self._thread_lock.acquire()
134+
return self
65135
if fcntl is None and _msvcrt is None:
66136
return self
67137
self.lockfile_path.parent.mkdir(parents=True, exist_ok=True)
@@ -82,6 +152,9 @@ def __enter__(self):
82152
raise IOError("Could not acquire database lock after 10 seconds")
83153

84154
def __exit__(self, exc_type, exc_val, exc_tb):
155+
if self._use_thread_lock:
156+
self._thread_lock.release()
157+
return
85158
if self.lockfile:
86159
try:
87160
if fcntl is not None:
@@ -107,16 +180,31 @@ def _get_connection(
107180
configure_pragmas: bool = True,
108181
row_factory=sqlite3.Row,
109182
) -> Iterator[sqlite3.Connection]:
110-
conn = sqlite3.connect(str(db_path), timeout=timeout)
111-
try:
112-
if configure_pragmas:
113-
_configure_sqlite_pragmas(conn)
114-
if row_factory is not None:
183+
if on_spaces():
184+
# On Spaces, all callers share a single persistent connection
185+
# that is pragma-configured at creation time. The `configure_pragmas`
186+
# flag is intentionally ignored here — the pragmas (journal mode,
187+
# synchronous, locking mode) don't affect query semantics.
188+
access_lock = _get_db_access_lock(db_path)
189+
access_lock.acquire()
190+
try:
191+
conn = _get_or_create_persistent_conn(db_path, timeout=timeout)
115192
conn.row_factory = row_factory
116-
with conn:
117-
yield conn
118-
finally:
119-
conn.close()
193+
with conn:
194+
yield conn
195+
finally:
196+
access_lock.release()
197+
else:
198+
conn = sqlite3.connect(str(db_path), timeout=timeout)
199+
try:
200+
if configure_pragmas:
201+
_configure_sqlite_pragmas(conn)
202+
if row_factory is not None:
203+
conn.row_factory = row_factory
204+
with conn:
205+
yield conn
206+
finally:
207+
conn.close()
120208

121209
@staticmethod
122210
def _get_process_lock(project: str) -> ProcessLock:

trackio/utils.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,11 @@ def get_group_priority(group_name: str) -> tuple[int, str]:
130130
return ordered_groups, result
131131

132132

133-
def persistent_storage_enabled() -> bool:
134-
return (
135-
os.environ.get("PERSISTANT_STORAGE_ENABLED") == "true"
136-
) # typo in the name of the environment variable
133+
def on_spaces() -> bool:
    """Return True when the ``SYSTEM`` environment variable equals ``"spaces"``.

    The comparison is exact (case-sensitive); an unset variable yields False.
    """
    system = os.environ.get("SYSTEM")
    return system == "spaces"
137135

138136

139137
def _get_trackio_dir() -> Path:
140-
if persistent_storage_enabled():
141-
return Path("/data/trackio")
142138
if os.environ.get("TRACKIO_DIR"):
143139
return Path(os.environ.get("TRACKIO_DIR"))
144140
return Path(HF_HOME) / "trackio"

0 commit comments

Comments
 (0)