Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/few-bars-find.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"trackio": patch
---

feat:Scope bucket sync to trackio/ subtree to avoid walking the HF cache
96 changes: 94 additions & 2 deletions trackio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,29 @@ def _cleanup_current_run():
pass


def _safe_get_runs_for_init(project: str) -> list[str]:
def _safe_get_runs_for_init(
project: str,
space_id: str | None,
resume: str,
remote_client: RemoteClient | None = None,
check_existing_for_never: bool = False,
) -> list[str]:
if space_id is not None:
if resume == "never" and not check_existing_for_never:
return []
try:
client = remote_client or RemoteClient(
space_id,
hf_token=huggingface_hub.utils.get_token(),
verbose=False,
)
runs = client.predict(project=project, api_name="/get_runs_for_project")
return runs if isinstance(runs, list) else []
except Exception as e:
_emit_nonfatal_warning(
f"trackio.init() could not inspect existing runs for project '{project}' on Space '{space_id}': {e}. Continuing without resume metadata."
)
return []
try:
return SQLiteStorage.get_runs(project)
except Exception as e:
Expand All @@ -107,6 +129,43 @@ def _safe_get_runs_for_init(project: str) -> list[str]:
return []


def _safe_get_last_step_for_init(
project: str,
run_name: str,
space_id: str | None,
resumed: bool,
remote_client: RemoteClient | None = None,
) -> int | None:
if not resumed:
return None
if space_id is not None:
try:
client = remote_client or RemoteClient(
space_id,
hf_token=huggingface_hub.utils.get_token(),
verbose=False,
)
summary = client.predict(
project=project, run=run_name, api_name="/get_run_summary"
)
if isinstance(summary, dict):
last_step = summary.get("last_step")
return last_step if isinstance(last_step, int) else None
return None
except Exception as e:
_emit_nonfatal_warning(
f"trackio.init() could not recover the previous step for run '{run_name}' on Space '{space_id}': {e}. Continuing from step 0."
)
return None
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Space-backed last step lookup falls through to SQLite

Medium Severity

In _safe_get_last_step_for_init, when space_id is not None and the Space API call succeeds but returns a non-dict summary, execution falls through to the local SQLiteStorage.get_max_step_for_run path instead of returning None. The sibling function _safe_get_runs_for_init correctly handles this by always returning within the space_id is not None block (via return runs if isinstance(runs, list) else []), but here the isinstance(summary, dict) guard only returns inside the if, leaving the false branch without a return statement. This defeats the PR's goal of avoiding local SQLite/bucket access for Space-backed runs.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 8c360da. Configure here.

try:
return SQLiteStorage.get_max_step_for_run(project, run_name)
except Exception as e:
_emit_nonfatal_warning(
f"trackio.init() could not recover the previous step for run '{run_name}': {e}. Continuing from step 0."
)
return None


def init(
project: str,
name: str | None = None,
Expand Down Expand Up @@ -288,7 +347,26 @@ def init(
)
context_vars.current_project.set(project)

existing_runs = _safe_get_runs_for_init(project)
remote_client = None
if space_id is not None:
try:
remote_client = RemoteClient(
space_id,
hf_token=huggingface_hub.utils.get_token(),
verbose=False,
)
except Exception as e:
_emit_nonfatal_warning(
f"trackio.init() could not create a Space client for '{space_id}': {e}. Continuing with local fallback metadata lookups."
)

existing_runs = _safe_get_runs_for_init(
project,
space_id,
resume,
remote_client=remote_client,
check_existing_for_never=name is not None,
)

if resume == "must":
if name is None:
Expand All @@ -310,6 +388,18 @@ def init(
else:
raise ValueError("resume must be one of: 'must', 'allow', or 'never'")

initial_last_step = (
_safe_get_last_step_for_init(
project,
name,
space_id,
resumed,
remote_client=remote_client,
)
if name is not None
else None
)

if auto_log_gpu is None:
nvidia_available = gpu_available()
apple_available = apple_gpu_available()
Expand All @@ -332,6 +422,8 @@ def init(
group=group,
config=config,
space_id=space_id,
existing_runs=existing_runs,
initial_last_step=initial_last_step,
auto_log_gpu=auto_log_gpu,
gpu_log_interval=gpu_log_interval,
webhook_url=webhook_url,
Expand Down
4 changes: 2 additions & 2 deletions trackio/bucket_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def _list_bucket_file_paths(bucket_id: str, prefix: str | None = None) -> list[s
def download_bucket_to_trackio_dir(bucket_id: str) -> None:
TRACKIO_DIR.mkdir(parents=True, exist_ok=True)
sync_bucket(
source=f"hf://buckets/{bucket_id}",
dest=str(TRACKIO_DIR.parent),
source=f"hf://buckets/{bucket_id}/trackio",
dest=str(TRACKIO_DIR),
quiet=True,
)

Expand Down
11 changes: 11 additions & 0 deletions trackio/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def __init__(
group: str | None = None,
config: dict | None = None,
space_id: str | None = None,
existing_runs: list[str] | None = None,
initial_last_step: int | None = None,
auto_log_gpu: bool = False,
gpu_log_interval: float = 10.0,
webhook_url: str | None = None,
Expand All @@ -65,6 +67,9 @@ def __init__(
Keys starting with '_' are reserved for internal use.
space_id: The HF Space ID if logging to a Space (e.g., "user/space").
If provided, media files will be uploaded to the Space.
existing_runs: Optional pre-fetched run names for this project. Used to
avoid redundant storage or remote lookups during init.
initial_last_step: Optional pre-fetched last step for a resumed run.
auto_log_gpu: Whether to automatically log GPU metrics (utilization,
memory, temperature) at regular intervals.
gpu_log_interval: The interval in seconds between GPU metric logs.
Expand All @@ -86,6 +91,8 @@ def __init__(
self._client_thread = None
self._client = client
self._space_id = space_id
self._existing_runs = existing_runs
self._initial_last_step = initial_last_step
if name is not None:
self.name = name
else:
Expand Down Expand Up @@ -180,6 +187,8 @@ def _warn_once(self, key: str, message: str) -> None:
_emit_nonfatal_warning(message)

def _safe_get_existing_runs(self) -> list[str]:
if self._existing_runs is not None:
return self._existing_runs
try:
return SQLiteStorage.get_runs(self.project)
except Exception as e:
Expand All @@ -190,6 +199,8 @@ def _safe_get_existing_runs(self) -> list[str]:
return []

def _safe_get_max_step_for_run(self) -> int | None:
if self._initial_last_step is not None:
return self._initial_last_step
try:
return SQLiteStorage.get_max_step_for_run(self.project, self.name)
except Exception as e:
Expand Down
Loading