Skip to content

Commit 29e1034

Browse files
abidlabsgradio-pr-botclaude
authored
Fix static exports, Space bucket handling, and other misc issues (#517)
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent afe2959 commit 29e1034

32 files changed

Lines changed: 1225 additions & 175 deletions

.changeset/great-spiders-dance.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"trackio": minor
3+
---
4+
5+
feat:Fix static exports, Space bucket handling, and other misc issues

.github/workflows/test.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ jobs:
6868
npm ci
6969
npm run build
7070
71+
- name: Run frontend unit tests
72+
if: matrix.os == 'ubuntu-latest'
73+
run: |
74+
cd trackio/frontend
75+
npm test
76+
7177
- name: Install Playwright
7278
if: matrix.os == 'ubuntu-latest'
7379
run: |

tests/e2e-local/test_api.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,20 @@ def test_local_dashboard_upload_api_accepts_only_server_uploaded_paths(temp_dir)
291291
write_headers = {"x-trackio-write-token": write_token}
292292

293293
try:
294+
blocked_upload_resp = httpx.post(
295+
f"{url.rstrip('/')}/api/upload",
296+
files={"files": (source_path.name, source_text.encode())},
297+
timeout=5,
298+
)
299+
assert blocked_upload_resp.status_code == 400
300+
assert blocked_upload_resp.json() == {
301+
"error": "A write_token is required to upload files to this server. Use the write-access URL from trackio.show(), set TRACKIO_WRITE_TOKEN, or send header X-Trackio-Write-Token."
302+
}
303+
294304
with source_path.open("rb") as handle:
295305
upload_resp = httpx.post(
296306
f"{url.rstrip('/')}/api/upload",
307+
headers=write_headers,
297308
files={"files": (source_path.name, handle)},
298309
timeout=5,
299310
)
@@ -345,6 +356,7 @@ def test_local_dashboard_upload_api_accepts_only_server_uploaded_paths(temp_dir)
345356
assert allowed_resp.status_code == 200
346357
assert allowed_target is not None
347358
assert allowed_target.read_text() == source_text
359+
assert not Path(uploaded_path).exists()
348360
assert blocked_resp.status_code == 400
349361
assert blocked_resp.json() == {
350362
"error": "Uploaded file was not created by this Trackio server."
@@ -356,6 +368,68 @@ def test_local_dashboard_upload_api_accepts_only_server_uploaded_paths(temp_dir)
356368
app.close()
357369

358370

371+
def test_local_dashboard_get_metric_values_honors_run_id(temp_dir):
372+
project = "test_metric_values_run_id"
373+
run_name = "duplicate-run"
374+
375+
first = trackio.init(project=project, name=run_name, resume="never")
376+
trackio.log(metrics={"loss": 1.0})
377+
trackio.finish()
378+
379+
second = trackio.init(project=project, name=run_name, resume="never")
380+
trackio.log(metrics={"loss": 2.0})
381+
trackio.finish()
382+
383+
app, url, _, _ = trackio.show(block_thread=False, open_browser=False)
384+
385+
try:
386+
client = Client(url, verbose=False)
387+
runs = client.predict(project, api_name="/get_runs_for_project")
388+
first_run_id = first.id
389+
second_run_id = second.id
390+
assert [run["id"] for run in runs] == [first_run_id, second_run_id]
391+
392+
latest_resp = httpx.post(
393+
f"{url.rstrip('/')}/api/get_metric_values",
394+
json={
395+
"project": project,
396+
"run": run_name,
397+
"metric_name": "loss",
398+
},
399+
timeout=5,
400+
)
401+
first_resp = httpx.post(
402+
f"{url.rstrip('/')}/api/get_metric_values",
403+
json={
404+
"project": project,
405+
"run": run_name,
406+
"run_id": first_run_id,
407+
"metric_name": "loss",
408+
},
409+
timeout=5,
410+
)
411+
second_resp = httpx.post(
412+
f"{url.rstrip('/')}/api/get_metric_values",
413+
json={
414+
"project": project,
415+
"run": run_name,
416+
"run_id": second_run_id,
417+
"metric_name": "loss",
418+
},
419+
timeout=5,
420+
)
421+
422+
assert latest_resp.status_code == 200
423+
assert first_resp.status_code == 200
424+
assert second_resp.status_code == 200
425+
assert [row["value"] for row in latest_resp.json()["data"]] == [2.0]
426+
assert [row["value"] for row in first_resp.json()["data"]] == [1.0]
427+
assert [row["value"] for row in second_resp.json()["data"]] == [2.0]
428+
finally:
429+
trackio.delete_project(project, force=True)
430+
app.close()
431+
432+
359433
def test_local_dashboard_supports_mcp(temp_dir):
360434
pytest.importorskip("mcp")
361435
from mcp import ClientSession

tests/ui/test_ui_display.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,9 @@ def test_latest_only_selects_last_run(temp_dir):
8585
latest_toggle = page.locator(".latest-toggle input[type='checkbox']")
8686
latest_toggle.check()
8787

88-
expect(checkboxes.nth(0)).not_to_be_checked()
88+
expect(checkboxes.nth(0)).to_be_checked()
8989
expect(checkboxes.nth(1)).not_to_be_checked()
90-
expect(checkboxes.nth(2)).to_be_checked()
90+
expect(checkboxes.nth(2)).not_to_be_checked()
9191

9292
browser.close()
9393
finally:

tests/unit/test_deploy.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from types import SimpleNamespace
22
from unittest.mock import patch
33

4-
from huggingface_hub import Volume
5-
64
from trackio import deploy
75
from trackio.bucket_storage import _list_bucket_file_paths
86

@@ -20,27 +18,6 @@ def test_get_space_install_requirement_includes_mcp_extra():
2018
assert requirement == f"trackio[spaces,mcp]=={deploy.trackio.__version__}"
2119

2220

23-
@patch("trackio.deploy.huggingface_hub.HfApi")
24-
def test_get_source_bucket_falls_back_to_space_info_runtime(mock_hf_api):
25-
api = mock_hf_api.return_value
26-
api.get_space_runtime.return_value = SimpleNamespace(volumes=None)
27-
api.space_info.return_value = SimpleNamespace(
28-
runtime=SimpleNamespace(
29-
volumes=[
30-
Volume(
31-
type="bucket",
32-
source="abidlabs/example-bucket",
33-
mount_path="/data",
34-
)
35-
]
36-
)
37-
)
38-
39-
bucket_id = deploy._get_source_bucket("abidlabs/example-space")
40-
41-
assert bucket_id == "abidlabs/example-bucket"
42-
43-
4421
@patch("trackio.bucket_storage.huggingface_hub.list_bucket_tree")
4522
def test_list_bucket_file_paths_uses_list_bucket_tree(mock_list_bucket_tree):
4623
mock_list_bucket_tree.return_value = [

tests/unit/test_sqlite_storage.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,38 @@ def test_get_runs_returns_chronological_order(temp_dir):
270270
assert runs == ["run-z", "run-a", "run-m"]
271271

272272

273+
def test_get_metric_values_respects_run_id_and_name_resolves_latest_run(temp_dir):
274+
project = "proj_metric_values"
275+
run_name = "dup-run"
276+
277+
SQLiteStorage.bulk_log(
278+
project,
279+
run_name,
280+
[{"loss": 1.0}],
281+
run_id="run-id-1",
282+
timestamps=["2024-01-01T00:00:00+00:00"],
283+
)
284+
SQLiteStorage.bulk_log(
285+
project,
286+
run_name,
287+
[{"loss": 2.0}],
288+
run_id="run-id-2",
289+
timestamps=["2024-01-02T00:00:00+00:00"],
290+
)
291+
292+
latest_by_name = SQLiteStorage.get_metric_values(project, run_name, "loss")
293+
first_by_id = SQLiteStorage.get_metric_values(
294+
project, run_name, "loss", run_id="run-id-1"
295+
)
296+
second_by_id = SQLiteStorage.get_metric_values(
297+
project, run_name, "loss", run_id="run-id-2"
298+
)
299+
300+
assert [row["value"] for row in latest_by_name] == [2.0]
301+
assert [row["value"] for row in first_by_id] == [1.0]
302+
assert [row["value"] for row in second_by_id] == [2.0]
303+
304+
273305
def test_rename_run(temp_dir):
274306
project = "test_project"
275307
old_name = "old_run"

trackio/__init__.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -390,8 +390,10 @@ def init(
390390
"* Warning: settings is not used. Provided for compatibility with wandb.init(). Please create an issue at: https://github.com/gradio-app/trackio/issues if you need a specific feature implemented."
391391
)
392392

393+
bucket_id_was_explicit = bucket_id is not None
393394
space_id, server_url = utils.resolve_space_id_and_server_url(space_id, server_url)
394-
bucket_id = bucket_id or os.environ.get("TRACKIO_BUCKET_ID")
395+
if bucket_id is None and utils.on_spaces():
396+
bucket_id = os.environ.get("TRACKIO_BUCKET_ID")
395397
if server_url is not None and not server_url.startswith(("http://", "https://")):
396398
raise ValueError(
397399
f"`server_url` must be a full URL starting with http:// or https://, got: {server_url!r}"
@@ -419,6 +421,14 @@ def init(
419421
space_id, dataset_id, bucket_id = utils.preprocess_space_and_dataset_ids(
420422
space_id, dataset_id, bucket_id
421423
)
424+
if (
425+
space_id is not None
426+
and dataset_id is None
427+
and bucket_id is not None
428+
and not bucket_id_was_explicit
429+
and not utils.on_spaces()
430+
):
431+
bucket_id = deploy.resolve_auto_bucket_id(space_id, bucket_id)
422432
except LocalTokenNotFoundError as e:
423433
raise LocalTokenNotFoundError(
424434
f"You must be logged in to Hugging Face locally when `space_id` is provided to deploy to a Space. {e}"
@@ -438,17 +448,20 @@ def init(
438448

439449
remote_source = space_id or server_base_url
440450

441-
url = context_vars.current_server.get()
442-
443451
if remote_source is not None:
444-
if url is None:
445-
url = remote_source
446-
context_vars.current_server.set(url)
447-
if space_id is not None:
448-
context_vars.current_space_id.set(space_id)
449-
context_vars.current_server_write_token.set(None)
450-
elif server_base_url is not None:
451-
context_vars.current_server_write_token.set(write_token_resolved)
452+
url = remote_source
453+
context_vars.current_server.set(url)
454+
if space_id is not None:
455+
context_vars.current_space_id.set(space_id)
456+
context_vars.current_server_write_token.set(None)
457+
else:
458+
context_vars.current_space_id.set(None)
459+
context_vars.current_server_write_token.set(write_token_resolved)
460+
else:
461+
url = None
462+
context_vars.current_server.set(None)
463+
context_vars.current_space_id.set(None)
464+
context_vars.current_server_write_token.set(None)
452465

453466
_should_embed_local = False
454467

@@ -459,13 +472,15 @@ def init(
459472
print(f"* Trackio project initialized: {project}")
460473

461474
if bucket_id is not None:
462-
os.environ["TRACKIO_BUCKET_ID"] = bucket_id
475+
if utils.on_spaces():
476+
os.environ["TRACKIO_BUCKET_ID"] = bucket_id
463477
bucket_url = f"https://huggingface.co/buckets/{bucket_id}"
464478
print(
465479
f"* Trackio metrics will be synced to Hugging Face Bucket: {bucket_url}"
466480
)
467481
elif dataset_id is not None:
468-
os.environ["TRACKIO_DATASET_ID"] = dataset_id
482+
if utils.on_spaces():
483+
os.environ["TRACKIO_DATASET_ID"] = dataset_id
469484
print(
470485
f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}"
471486
)

trackio/api.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,15 @@ def delete(self) -> bool:
3131
return SQLiteStorage.delete_run(self.project, self.name, run_id=self.id)
3232

3333
def move(self, new_project: str) -> bool:
34-
success = SQLiteStorage.move_run(self.project, self.name, new_project)
34+
success = SQLiteStorage.move_run(
35+
self.project, self.name, new_project, run_id=self.id
36+
)
3537
if success:
3638
self.project = new_project
3739
return success
3840

3941
def rename(self, new_name: str) -> "Run":
40-
SQLiteStorage.rename_run(self.project, self.name, new_name)
42+
SQLiteStorage.rename_run(self.project, self.name, new_name, run_id=self.id)
4143
self.name = new_name
4244
return self
4345

trackio/asgi_app.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import secrets
88
import tempfile
99
import threading
10+
from collections.abc import Callable
1011
from pathlib import Path
1112
from typing import Any, get_args, get_origin
1213
from urllib.parse import unquote
@@ -88,6 +89,13 @@ def consume_uploaded_temp_file(request: Request, file_data: Any) -> Path:
8889
return resolved_path
8990

9091

92+
def cleanup_uploaded_temp_file(file_path: str | Path) -> None:
93+
try:
94+
Path(file_path).unlink(missing_ok=True)
95+
except Exception:
96+
pass
97+
98+
9199
def _invoke_handler(
92100
fn: Any,
93101
request: Request,
@@ -377,6 +385,13 @@ async def sse() -> Any:
377385

378386

379387
async def upload_handler(request: Request) -> Response:
388+
upload_authorizer = getattr(request.app.state, "upload_authorizer", None)
389+
if callable(upload_authorizer):
390+
try:
391+
upload_authorizer(request)
392+
except TrackioAPIError as e:
393+
return JSONResponse({"error": str(e)}, status_code=400)
394+
380395
form = await request.form()
381396
uploads = form.getlist("files")
382397
saved_paths = []
@@ -397,11 +412,18 @@ async def gradio_upload_alias_handler(request: Request) -> Response:
397412
return await upload_handler(request)
398413

399414

415+
_DISALLOWED_FILE_SUFFIXES = frozenset(
416+
{".db", ".db-journal", ".db-wal", ".db-shm", ".sqlite", ".sqlite3"}
417+
)
418+
419+
400420
async def file_handler(request: Request) -> Response:
401421
fs_path = request.query_params.get("path")
402422
if fs_path is None:
403423
return Response("Missing path", status_code=400)
404424
fp = Path(unquote(fs_path)).resolve(strict=False)
425+
if fp.suffix.lower() in _DISALLOWED_FILE_SUFFIXES:
426+
return Response("Not found", status_code=404)
405427
allowed_roots = getattr(request.app.state, "allowed_file_roots", ())
406428
if fp.is_file() and _is_allowed_file_path(fp, allowed_roots):
407429
return FileResponse(str(fp))
@@ -415,6 +437,7 @@ def create_trackio_starlette_app(
415437
mcp_lifespan: Any = None,
416438
mcp_enabled: bool = False,
417439
allowed_file_roots: list[str | Path] | None = None,
440+
upload_authorizer: Callable[[Request], None] | None = None,
418441
) -> Starlette:
419442
routes: list[Any] = list(oauth_routes)
420443
routes.extend(
@@ -475,6 +498,7 @@ def create_trackio_starlette_app(
475498
app.state.api_registry = api_registry
476499
app.state.mcp_enabled = mcp_enabled
477500
app.state.allowed_file_roots = _normalize_allowed_file_roots(allowed_file_roots)
501+
app.state.upload_authorizer = upload_authorizer
478502
app.state.uploaded_temp_files = set()
479503
app.state.uploaded_temp_files_lock = threading.Lock()
480504
if on_spaces():

0 commit comments

Comments
 (0)