Skip to content

Commit 373cd1c

Browse files
committed
changes
1 parent 24065b9 commit 373cd1c

6 files changed

Lines changed: 161 additions & 52 deletions

File tree

tests/e2e-spaces/test_buckets_backend.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

tests/e2e-spaces/test_metrics_on_spaces.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
import time
33

44
import huggingface_hub
5+
import pytest
56
from gradio_client import Client
67

78
import trackio
9+
from trackio import utils
810

911

1012
def test_basic_logging(test_space_id):
@@ -92,3 +94,60 @@ def test_runs_data_persisted_after_restart(test_space_id):
9294
lr = cfg.get("learning_rate")
9395
assert lr is not None and abs(float(lr) - 0.001) < 1e-6
9496
assert cfg.get("epochs") == 10
97+
98+
99+
def test_bucket_space_preserves_logged_metrics_after_restart(test_space_id):
100+
_, dataset_id, bucket_id = utils.preprocess_space_and_dataset_ids(
101+
test_space_id, None, None
102+
)
103+
if dataset_id is not None or bucket_id is None:
104+
pytest.skip("Requires a Space deployed with bucket backend (no dataset_id).")
105+
106+
project_name = f"test_bucket_persist_{secrets.token_urlsafe(8)}"
107+
run_name = "metrics_run"
108+
109+
trackio.init(project=project_name, name=run_name, space_id=test_space_id)
110+
trackio.log(metrics={"loss": 0.42, "acc": 0.88})
111+
trackio.finish()
112+
113+
client = Client(test_space_id)
114+
client.predict(api_name="/force_sync")
115+
116+
huggingface_hub.add_space_variable(
117+
test_space_id, "TRACKIO_TEST_RESTART", secrets.token_urlsafe(8)
118+
)
119+
120+
time.sleep(10)
121+
deadline = time.time() + 300
122+
client = None
123+
while time.time() < deadline:
124+
try:
125+
client = Client(test_space_id, verbose=False)
126+
break
127+
except Exception:
128+
time.sleep(10)
129+
assert client is not None, "Space did not come back up after restart"
130+
131+
summary = client.predict(
132+
project=project_name, run=run_name, api_name="/get_run_summary"
133+
)
134+
assert summary["num_logs"] == 1
135+
assert "loss" in summary["metrics"] and "acc" in summary["metrics"]
136+
137+
loss_values = client.predict(
138+
project=project_name,
139+
run=run_name,
140+
metric_name="loss",
141+
api_name="/get_metric_values",
142+
)
143+
assert len(loss_values) == 1
144+
assert abs(float(loss_values[0]["value"]) - 0.42) < 1e-6
145+
146+
acc_values = client.predict(
147+
project=project_name,
148+
run=run_name,
149+
metric_name="acc",
150+
api_name="/get_metric_values",
151+
)
152+
assert len(acc_values) == 1
153+
assert abs(float(acc_values[0]["value"]) - 0.88) < 1e-6

trackio/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ tags:
77
{LINKED_HUB_METADATA}hf_oauth: true
88
hf_oauth_scopes:
99
- write-repos
10-
{BUCKET_MOUNT}---
10+
---

trackio/deploy.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import tomli as tomllib
1717

1818
import gradio
19+
import httpx
1920
import huggingface_hub
2021
from gradio_client import Client, handle_file
2122
from httpx import ReadTimeout
@@ -34,13 +35,17 @@
3435
SPACE_URL = "https://huggingface.co/spaces/{space_id}"
3536

3637

37-
def _readme_linked_hub_yaml(dataset_id: str | None, bucket_id: str | None) -> str:
38-
parts = []
38+
def _readme_linked_hub_yaml(dataset_id: str | None) -> str:
3939
if dataset_id is not None:
40-
parts.append(f"datasets:\n - {dataset_id}\n")
41-
if bucket_id is not None:
42-
parts.append(f"buckets:\n - {bucket_id}\n")
43-
return "".join(parts)
40+
return f"datasets:\n - {dataset_id}\n"
41+
return ""
42+
43+
44+
def _space_app_py_content(bucket_id: str | None) -> str:
45+
if bucket_id is None:
46+
return "import trackio\ntrackio.show()\n"
47+
path = Path(files("trackio")) / "space_bucket_app.py"
48+
return path.read_text(encoding="utf-8")
4449

4550

4651
def _retry_hf_write(op_name: str, fn, retries: int = 4, initial_delay: float = 1.5):
@@ -157,15 +162,8 @@ def deploy_as_space(
157162
readme_content = readme_content.replace("{GRADIO_VERSION}", gradio.__version__)
158163
readme_content = readme_content.replace("{APP_FILE}", "app.py")
159164
readme_content = readme_content.replace(
160-
"{LINKED_HUB_METADATA}", _readme_linked_hub_yaml(dataset_id, bucket_id)
165+
"{LINKED_HUB_METADATA}", _readme_linked_hub_yaml(dataset_id)
161166
)
162-
if bucket_id is not None:
163-
bucket_mount = (
164-
f"hf_mount:\n - src: hf://buckets/{bucket_id}\n dst: /data/trackio\n"
165-
)
166-
else:
167-
bucket_mount = ""
168-
readme_content = readme_content.replace("{BUCKET_MOUNT}", bucket_mount)
169167
readme_buffer = io.BytesIO(readme_content.encode("utf-8"))
170168
hf_api.upload_file(
171169
path_or_fileobj=readme_buffer,
@@ -217,8 +215,7 @@ def deploy_as_space(
217215
],
218216
)
219217

220-
app_file_content = """import trackio
221-
trackio.show()"""
218+
app_file_content = _space_app_py_content(bucket_id)
222219
app_file_buffer = io.BytesIO(app_file_content.encode("utf-8"))
223220
hf_api.upload_file(
224221
path_or_fileobj=app_file_buffer,
@@ -231,6 +228,7 @@ def deploy_as_space(
231228
huggingface_hub.add_space_secret(space_id, "HF_TOKEN", hf_token)
232229
if bucket_id is not None:
233230
huggingface_hub.add_space_variable(space_id, "TRACKIO_BUCKET_ID", bucket_id)
231+
huggingface_hub.add_space_variable(space_id, "TRACKIO_DIR", "/data/trackio")
234232
elif dataset_id is not None:
235233
huggingface_hub.add_space_variable(space_id, "TRACKIO_DATASET_ID", dataset_id)
236234
if logo_light_url := os.environ.get("TRACKIO_LOGO_LIGHT_URL"):
@@ -307,15 +305,32 @@ def _wait_until_space_running(space_id: str, timeout: int = 300) -> None:
307305
hf_api = huggingface_hub.HfApi()
308306
start = time.time()
309307
delay = 2
308+
request_timeout = 45.0
309+
failure_stages = frozenset(
310+
("NO_APP_FILE", "CONFIG_ERROR", "BUILD_ERROR", "RUNTIME_ERROR")
311+
)
310312
while time.time() - start < timeout:
311313
try:
312-
info = hf_api.space_info(space_id)
313-
if info.runtime and info.runtime.stage == "RUNNING":
314-
return
315-
except (huggingface_hub.utils.HfHubHTTPError, ReadTimeout):
314+
info = hf_api.space_info(space_id, timeout=request_timeout)
315+
if info.runtime:
316+
stage = str(info.runtime.stage)
317+
if stage in failure_stages:
318+
raise RuntimeError(
319+
f"Space {space_id} entered terminal stage {stage}. "
320+
"Fix README.md or app files; see build logs on the Hub."
321+
)
322+
if stage == "RUNNING":
323+
return
324+
except RuntimeError:
325+
raise
326+
except (huggingface_hub.utils.HfHubHTTPError, httpx.RequestError):
316327
pass
317328
time.sleep(delay)
318329
delay = min(delay * 1.5, 15)
330+
raise TimeoutError(
331+
f"Space {space_id} did not reach RUNNING within {timeout}s. "
332+
"Check status and build logs on the Hub."
333+
)
319334

320335

321336
def wait_until_space_exists(
@@ -337,7 +352,7 @@ def wait_until_space_exists(
337352
try:
338353
hf_api.space_info(space_id)
339354
return
340-
except (huggingface_hub.utils.HfHubHTTPError, ReadTimeout):
355+
except (huggingface_hub.utils.HfHubHTTPError, httpx.RequestError):
341356
time.sleep(delay)
342357
delay = min(delay * 2, 60)
343358
raise TimeoutError("Waiting for space to exist took longer than expected")
@@ -569,7 +584,7 @@ def deploy_as_static_space(
569584
else:
570585
raise ValueError(f"Failed to create Space: {e}")
571586

572-
linked = _readme_linked_hub_yaml(dataset_id, bucket_id)
587+
linked = _readme_linked_hub_yaml(dataset_id)
573588
readme_content = (
574589
"---\nsdk: static\npinned: false\ntags:\n - trackio\n"
575590
f"{linked}---\n"

trackio/space_bucket_app.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import importlib
2+
import os
3+
import platform
4+
import subprocess
5+
import sys
6+
import urllib.request
7+
8+
_DEFAULT_MOUNT = "/data/trackio"
9+
_BIN_DIR = "/tmp/trackio-hf-mount-bin"
10+
11+
12+
def _platform_release_asset():
13+
if sys.platform != "linux":
14+
raise OSError(f"hf-mount on Spaces requires Linux, got {sys.platform}")
15+
m = platform.machine().lower()
16+
if m in ("x86_64", "amd64"):
17+
return "x86_64", "linux"
18+
if m in ("aarch64", "arm64"):
19+
return "aarch64", "linux"
20+
raise OSError(f"Unsupported machine for hf-mount: {m}")
21+
22+
23+
def _download_hf_mount_binaries():
24+
arch, plat = _platform_release_asset()
25+
base = "https://github.com/huggingface/hf-mount/releases/latest/download"
26+
os.makedirs(_BIN_DIR, exist_ok=True)
27+
for name in ("hf-mount", "hf-mount-nfs", "hf-mount-fuse"):
28+
binary = f"{name}-{arch}-{plat}"
29+
url = f"{base}/{binary}"
30+
dest = os.path.join(_BIN_DIR, name)
31+
with urllib.request.urlopen(url) as response:
32+
with open(dest, "wb") as out:
33+
out.write(response.read())
34+
os.chmod(dest, 0o755)
35+
36+
37+
def start_hf_mount_for_trackio_bucket():
38+
if os.environ.get("SYSTEM") != "spaces":
39+
return
40+
bucket_id = os.environ.get("TRACKIO_BUCKET_ID")
41+
if not bucket_id:
42+
return
43+
mount_path = os.environ.get("TRACKIO_DIR", _DEFAULT_MOUNT)
44+
parent = os.path.dirname(mount_path.rstrip("/")) or "/"
45+
os.makedirs(parent, exist_ok=True)
46+
os.makedirs(mount_path, exist_ok=True)
47+
hf_mount = os.path.join(_BIN_DIR, "hf-mount")
48+
if not os.path.isfile(hf_mount):
49+
_download_hf_mount_binaries()
50+
env = {**os.environ, "PATH": _BIN_DIR + os.pathsep + os.environ.get("PATH", "")}
51+
subprocess.run(
52+
[hf_mount, "start", "bucket", bucket_id, mount_path],
53+
check=True,
54+
env=env,
55+
timeout=600,
56+
)
57+
58+
59+
start_hf_mount_for_trackio_bucket()
60+
61+
trackio = importlib.import_module("trackio")
62+
trackio.show()

trackio/utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,8 @@ def _get_trackio_dir() -> Path:
134134
return Path("/data/trackio")
135135
if os.environ.get("TRACKIO_DIR"):
136136
return Path(os.environ.get("TRACKIO_DIR"))
137-
bucket_mount = Path("/data/trackio")
138-
if os.environ.get("TRACKIO_BUCKET_ID") and bucket_mount.exists():
139-
return bucket_mount
137+
if os.environ.get("TRACKIO_BUCKET_ID"):
138+
return Path("/data/trackio")
140139
return Path(HF_HOME) / "trackio"
141140

142141

0 commit comments

Comments
 (0)