Skip to content

Commit c8a384d

Browse files
Fix pytests that were failling locally on MacOS (#407)
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
1 parent c7aa38d commit c8a384d

5 files changed

Lines changed: 55 additions & 22 deletions

File tree

.changeset/small-camels-stare.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"trackio": patch
3+
---
4+
5+
feat:Fix pytests that were failling locally on MacOS

tests/unit/test_run.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import time
2+
from datetime import datetime
23
from unittest.mock import MagicMock, patch
34

45
import pytest
56

67
from trackio import Run, init
7-
from trackio.utils import _cached_whoami
8+
from trackio.sqlite_storage import SQLiteStorage
89

910

1011
class DummyClient:
@@ -39,6 +40,7 @@ def test_init_resume_modes(temp_dir):
3940
assert run.name == "new-run"
4041

4142
run.log({"x": 1})
43+
SQLiteStorage.bulk_log("test-project", "new-run", [{"x": 1}])
4244
run.finish()
4345

4446
run = init(
@@ -84,10 +86,10 @@ def test_init_resume_modes(temp_dir):
8486
assert run.name == "nonexistent-run"
8587

8688

87-
@patch("huggingface_hub.whoami")
89+
@patch("trackio.utils._cached_whoami")
8890
@patch("time.time")
89-
def test_run_name_generation_with_space_id(mock_time, mock_whoami, temp_dir):
90-
mock_whoami.return_value = {"name": "testuser"}
91+
def test_run_name_generation_with_space_id(mock_time, mock_cached_whoami, temp_dir):
92+
mock_cached_whoami.return_value = {"name": "testuser"}
9193
mock_time.return_value = 1234567890
9294

9395
client = DummyClient()
@@ -100,8 +102,6 @@ def test_run_name_generation_with_space_id(mock_time, mock_whoami, temp_dir):
100102
)
101103
assert run.name == "testuser-1234567890"
102104

103-
_cached_whoami.cache_clear()
104-
105105

106106
def test_reserved_config_keys_rejected(temp_dir):
107107
with pytest.raises(ValueError, match="Config key '_test' is reserved"):
@@ -113,9 +113,9 @@ def test_reserved_config_keys_rejected(temp_dir):
113113
)
114114

115115

116-
@patch("huggingface_hub.whoami")
117-
def test_automatic_username_and_timestamp_added(mock_whoami, temp_dir):
118-
mock_whoami.return_value = {"name": "testuser"}
116+
@patch("trackio.utils._cached_whoami")
117+
def test_automatic_username_and_timestamp_added(mock_cached_whoami, temp_dir):
118+
mock_cached_whoami.return_value = {"name": "testuser"}
119119

120120
run = Run(
121121
url="http://test",
@@ -128,13 +128,9 @@ def test_automatic_username_and_timestamp_added(mock_whoami, temp_dir):
128128
assert "_Created" in run.config
129129
assert run.config["learning_rate"] == 0.01
130130

131-
from datetime import datetime
132-
133131
created_time = datetime.fromisoformat(run.config["_Created"])
134132
assert created_time.tzinfo is not None
135133

136-
_cached_whoami.cache_clear()
137-
138134

139135
def test_run_group_added(temp_dir):
140136
run = Run(

tests/unit/test_sqlite_storage.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,13 @@
66
import sqlite3
77
import tempfile
88
import time
9+
from pathlib import Path
910

1011
import orjson
1112
import pytest
1213

14+
import trackio.sqlite_storage
15+
import trackio.utils
1316
from trackio.sqlite_storage import SQLiteStorage
1417

1518

@@ -94,12 +97,21 @@ def test_import_export(temp_dir):
9497

9598

9699
def _worker_using_sqlite_storage(
97-
project, worker_id, duration_seconds=2, sync_start_time=None
100+
project, worker_id, duration_seconds=2, sync_start_time=None, temp_dir=None
98101
):
99102
"""
100103
Worker that uses SQLiteStorage methods for database access.
101104
This will be protected by ProcessLock when available.
102105
"""
106+
if temp_dir:
107+
os.environ["TRACKIO_DIR"] = temp_dir
108+
from pathlib import Path
109+
110+
import trackio.sqlite_storage
111+
import trackio.utils
112+
113+
trackio.utils.TRACKIO_DIR = Path(temp_dir)
114+
trackio.sqlite_storage.TRACKIO_DIR = Path(temp_dir)
103115

104116
def aggressive_get_connection(db_path):
105117
conn = sqlite3.connect(str(db_path), timeout=0.01)
@@ -144,23 +156,24 @@ def aggressive_get_connection(db_path):
144156
def test_concurrent_database_access_without_errors():
145157
"""
146158
Test that concurrent database access doesn't produce 'database is locked' errors.
147-
This test should fail on main (without ProcessLock) and pass with ProcessLock fix.
148159
"""
149160
with tempfile.TemporaryDirectory() as temp_dir:
150161
os.environ["TRACKIO_DIR"] = str(temp_dir)
162+
trackio.utils.TRACKIO_DIR = Path(temp_dir)
163+
trackio.sqlite_storage.TRACKIO_DIR = Path(temp_dir)
164+
151165
project = "concurrent_test"
152166

153167
num_processes = 8
154168
duration = 2
155169

156-
# Synchronized start time (0.5s from now) to make all processes hit db simultaneously
157170
sync_start_time = time.time() + 0.5
158171

159172
with multiprocessing.Pool(processes=num_processes) as pool:
160173
results = [
161174
pool.apply_async(
162175
_worker_using_sqlite_storage,
163-
(project, i, duration, sync_start_time),
176+
(project, i, duration, sync_start_time, temp_dir),
164177
)
165178
for i in range(num_processes)
166179
]

trackio/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,6 @@ def init(
218218
utils.embed_url_in_notebook(space_url)
219219
context_vars.current_project.set(project)
220220

221-
client = None
222-
if not space_id:
223-
client = Client(url, verbose=False)
224-
225221
if resume == "must":
226222
if name is None:
227223
raise ValueError("Must provide a run name when resume='must'")
@@ -250,7 +246,7 @@ def init(
250246
run = Run(
251247
url=url,
252248
project=project,
253-
client=client,
249+
client=None,
254250
name=name,
255251
group=group,
256252
config=config,

trackio/run.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,29 @@ def __init__(
3131
auto_log_gpu: bool = False,
3232
gpu_log_interval: float = 10.0,
3333
):
34+
"""
35+
Initialize a Run for logging metrics to Trackio.
36+
37+
Args:
38+
url: The URL of the Trackio server (local Gradio app or HF Space).
39+
project: The name of the project to log metrics to.
40+
client: A pre-configured gradio_client.Client instance, or None to
41+
create one automatically in a background thread with retry logic.
42+
Passing None is recommended for normal usage. Passing a client
43+
is useful for testing (e.g., injecting a mock client).
44+
name: The name of this run. If None, a readable name like
45+
"brave-sunset-0" is auto-generated. If space_id is provided,
46+
generates a "username-timestamp" format instead.
47+
group: Optional group name to organize related runs together.
48+
config: A dictionary of configuration/hyperparameters for this run.
49+
Keys starting with '_' are reserved for internal use.
50+
space_id: The HF Space ID if logging to a Space (e.g., "user/space").
51+
If provided, media files will be uploaded to the Space.
52+
auto_log_gpu: Whether to automatically log GPU metrics (utilization,
53+
memory, temperature) at regular intervals.
54+
gpu_log_interval: The interval in seconds between GPU metric logs.
55+
Only used when auto_log_gpu is True.
56+
"""
3457
self.url = url
3558
self.project = project
3659
self._client_lock = threading.Lock()

0 commit comments

Comments
 (0)