Skip to content

Commit c47a3a3

Browse files
Add wandb-compatible API for trackio (#394)
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
1 parent b02046a commit c47a3a3

9 files changed

Lines changed: 501 additions & 1 deletion

File tree

.changeset/six-crabs-type.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"trackio": minor
3+
---
4+
5+
feat:Add wandb-compatible API for trackio

docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
title: Deploy and Embed Dashboards
1818
- local: manage
1919
title: Manage Projects
20+
- local: python_api
21+
title: Python API for Managing Runs
2022
- local: cli_commands
2123
title: CLI Commands
2224
- local: api_mcp_server

docs/source/python_api.md

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Python API for Managing Runs
2+
3+
Trackio provides a Python API class (`trackio.Api()`) that allows you to programmatically manage runs in your projects. This API is similar to `wandb.Api()` and provides methods to delete runs, move runs between projects, and access run information.
4+
5+
**Note:** This is different from [Trackio as an API Server](api_mcp_server.md), which runs the Trackio dashboard as a web server with API endpoints. The Python API (`trackio.Api()`) is a client-side interface for managing runs in your local Trackio database.
6+
7+
## Basic Usage
8+
9+
```python
10+
import trackio
11+
12+
# Initialize the API
13+
api = trackio.Api()
14+
15+
# Get all runs in a project
16+
runs = api.runs("my_project")
17+
18+
# Access individual runs
19+
for run in runs:
20+
print(f"Run: {run.name}, Project: {run.project}")
21+
print(f"Config: {run.config}")
22+
23+
# Or access by index
24+
first_run = runs[0]
25+
```
26+
27+
## Deleting Runs
28+
29+
```python
30+
api = trackio.Api()
31+
runs = api.runs("my_project")
32+
33+
# Delete a specific run
34+
run = runs[0]
35+
success = run.delete() # Returns True if successful
36+
```
37+
38+
## Moving Runs Between Projects
39+
40+
```python
41+
api = trackio.Api()
42+
runs = api.runs("source_project")
43+
44+
# Move a run to a different project
45+
run = runs[0]
46+
success = run.move("target_project") # Returns True if successful
47+
48+
# After moving, the run object's project is updated
49+
print(run.project) # "target_project"
50+
```
51+
52+
When you move a run, all associated data is transferred:
53+
- All metrics and logs
54+
- Run configuration
55+
- System metrics
56+
- Media files (images, videos, audio)
57+
58+
The run is completely removed from the source project and added to the target project.
59+
60+
## API Reference
61+
62+
### Api
63+
64+
Main entry point for the Trackio Python API.
65+
66+
```python
67+
api = trackio.Api()
68+
```
69+
70+
#### Methods
71+
72+
- **`runs(project: str) -> Runs`**: Returns a collection of runs for the specified project. Raises `ValueError` if the project doesn't exist.
73+
74+
### Runs
75+
76+
A collection of runs that supports iteration and indexing.
77+
78+
```python
79+
runs = api.runs("my_project")
80+
len(runs) # Number of runs
81+
runs[0] # First run
82+
for run in runs: # Iterate over runs
83+
...
84+
```
85+
86+
### Run
87+
88+
Represents a single run in a project.
89+
90+
#### Properties
91+
92+
- **`id`**: The run name (same as `name`)
93+
- **`name`**: The run name
94+
- **`project`**: The project this run belongs to
95+
- **`config`**: The run's configuration dictionary (lazy-loaded)
96+
97+
#### Methods
98+
99+
- **`delete() -> bool`**: Deletes the run from its project. Returns `True` if successful, `False` otherwise.
100+
- **`move(new_project: str) -> bool`**: Moves the run to a different project. Returns `True` if successful, `False` otherwise. Updates the run's `project` property after a successful move.
101+
102+
## Examples
103+
104+
### List all runs across projects
105+
106+
```python
107+
import trackio
108+
from trackio.sqlite_storage import SQLiteStorage
109+
110+
api = trackio.Api()
111+
112+
# Get all projects
113+
projects = SQLiteStorage.get_projects()
114+
115+
# List runs in each project
116+
for project in projects:
117+
print(f"\nProject: {project}")
118+
runs = api.runs(project)
119+
for run in runs:
120+
print(f" - {run.name}")
121+
```
122+
123+
### Clean up old runs
124+
125+
```python
126+
api = trackio.Api()
127+
runs = api.runs("my_project")
128+
129+
# Delete runs older than a certain date
130+
from datetime import datetime
131+
cutoff_date = datetime(2024, 1, 1)
132+
133+
for run in runs:
134+
if run.config and "_Created" in run.config:
135+
created = datetime.fromisoformat(run.config["_Created"])
136+
if created < cutoff_date:
137+
run.delete()
138+
print(f"Deleted old run: {run.name}")
139+
```
140+
141+
### Organize runs by moving them
142+
143+
```python
144+
api = trackio.Api()
145+
146+
# Move all runs from "experiments" to "archive"
147+
source_runs = api.runs("experiments")
148+
for run in source_runs:
149+
run.move("archive")
150+
print(f"Moved {run.name} to archive")
151+
```
152+

examples/api-example.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import trackio
2+
from trackio import Api
3+
4+
project = "api_example_project"
5+
6+
for i in range(3):
7+
run_name = f"training_run_{i}"
8+
trackio.init(project=project, name=run_name)
9+
10+
for step in range(5):
11+
trackio.log(
12+
{
13+
"loss": 1.0 / (step + 1),
14+
"accuracy": 0.5 + step * 0.1,
15+
}
16+
)
17+
18+
trackio.finish()
19+
20+
api = Api()
21+
runs = api.runs(project)
22+
runs[0].delete()

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def temp_dir(monkeypatch):
2323
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
2424
for name in ["trackio.sqlite_storage"]:
2525
monkeypatch.setattr(f"{name}.TRACKIO_DIR", Path(tmpdir))
26-
for name in ["trackio.media.media", "trackio.media.utils"]:
26+
for name in ["trackio.media.media", "trackio.media.utils", "trackio.utils"]:
2727
monkeypatch.setattr(f"{name}.MEDIA_DIR", Path(tmpdir) / "media")
2828
yield tmpdir
2929

tests/e2e-local/test_api.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import trackio
2+
from trackio import Api
3+
from trackio.sqlite_storage import SQLiteStorage
4+
5+
6+
def test_delete_run(temp_dir):
7+
project = "test_delete_project"
8+
run_name = "test_delete_run"
9+
10+
trackio.init(project=project, name=run_name)
11+
trackio.log(metrics={"loss": 0.1, "accuracy": 0.9})
12+
trackio.log(metrics={"loss": 0.2, "accuracy": 0.95})
13+
trackio.finish()
14+
15+
logs = SQLiteStorage.get_logs(project=project, run=run_name)
16+
assert len(logs) == 2
17+
assert logs[0]["loss"] == 0.1
18+
assert logs[1]["loss"] == 0.2
19+
20+
api = Api()
21+
runs = api.runs(project)
22+
run = runs[0]
23+
assert run.name == run_name
24+
25+
success = run.delete()
26+
assert success is True
27+
28+
logs_after = SQLiteStorage.get_logs(project=project, run=run_name)
29+
assert len(logs_after) == 0
30+
31+
config_after = SQLiteStorage.get_run_config(project=project, run=run_name)
32+
assert config_after is None
33+
34+
runs_after = SQLiteStorage.get_runs(project=project)
35+
assert run_name not in runs_after
36+
37+
38+
def test_move_run(temp_dir, image_ndarray):
39+
source_project = "test_move_source"
40+
target_project = "test_move_target"
41+
run_name = "test_move_run"
42+
43+
trackio.init(project=source_project, name=run_name)
44+
45+
image1 = trackio.Image(image_ndarray, caption="test_image_1")
46+
image2 = trackio.Image(image_ndarray, caption="test_image_2")
47+
48+
trackio.log(metrics={"loss": 0.1, "acc": 0.9, "img1": image1})
49+
trackio.log(metrics={"loss": 0.2, "acc": 0.95, "img2": image2})
50+
trackio.finish()
51+
52+
source_logs = SQLiteStorage.get_logs(project=source_project, run=run_name)
53+
assert len(source_logs) == 2
54+
assert source_logs[0]["loss"] == 0.1
55+
assert source_logs[1]["loss"] == 0.2
56+
57+
image1_path = source_logs[0]["img1"].get("file_path")
58+
assert image1_path is not None
59+
normalized_path = str(image1_path).replace("\\", "/")
60+
assert normalized_path.startswith(f"{source_project}/{run_name}/")
61+
62+
api = Api()
63+
runs = api.runs(source_project)
64+
run = runs[0]
65+
assert run.name == run_name
66+
assert run.project == source_project
67+
68+
success = run.move(target_project)
69+
assert success is True
70+
assert run.project == target_project
71+
72+
target_logs = SQLiteStorage.get_logs(project=target_project, run=run_name)
73+
assert len(target_logs) == 2
74+
assert target_logs[0]["loss"] == 0.1
75+
assert target_logs[1]["loss"] == 0.2
76+
77+
target_image1_path = target_logs[0]["img1"].get("file_path")
78+
assert target_image1_path is not None
79+
normalized_path1 = str(target_image1_path).replace("\\", "/")
80+
assert normalized_path1.startswith(f"{target_project}/{run_name}/")
81+
82+
target_image2_path = target_logs[1]["img2"].get("file_path")
83+
assert target_image2_path is not None
84+
normalized_path2 = str(target_image2_path).replace("\\", "/")
85+
assert normalized_path2.startswith(f"{target_project}/{run_name}/")
86+
87+
source_logs_after = SQLiteStorage.get_logs(project=source_project, run=run_name)
88+
assert len(source_logs_after) == 0
89+
90+
source_runs_after = SQLiteStorage.get_runs(project=source_project)
91+
assert run_name not in source_runs_after
92+
assert len(source_runs_after) == 0
93+
94+
target_runs = SQLiteStorage.get_runs(project=target_project)
95+
assert run_name in target_runs
96+
97+
source_config_after = SQLiteStorage.get_run_config(
98+
project=source_project, run=run_name
99+
)
100+
assert source_config_after is None
101+
102+
target_config = SQLiteStorage.get_run_config(project=target_project, run=run_name)
103+
assert target_config is not None

trackio/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from huggingface_hub.errors import LocalTokenNotFoundError
1616

1717
from trackio import context_vars, deploy, utils
18+
from trackio.api import Api
1819
from trackio.deploy import sync
1920
from trackio.gpu import gpu_available, log_gpu
2021
from trackio.histogram import Histogram
@@ -57,6 +58,7 @@
5758
"Audio",
5859
"Table",
5960
"Histogram",
61+
"Api",
6062
]
6163

6264
Image = TrackioImage

trackio/api.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from typing import Iterator
2+
3+
from trackio.sqlite_storage import SQLiteStorage
4+
5+
6+
class Run:
7+
def __init__(self, project: str, name: str):
8+
self.project = project
9+
self.name = name
10+
self._config = None
11+
12+
@property
13+
def id(self) -> str:
14+
return self.name
15+
16+
@property
17+
def config(self) -> dict | None:
18+
if self._config is None:
19+
self._config = SQLiteStorage.get_run_config(self.project, self.name)
20+
return self._config
21+
22+
def delete(self) -> bool:
23+
return SQLiteStorage.delete_run(self.project, self.name)
24+
25+
def move(self, new_project: str) -> bool:
26+
success = SQLiteStorage.move_run(self.project, self.name, new_project)
27+
if success:
28+
self.project = new_project
29+
return success
30+
31+
def __repr__(self) -> str:
32+
return f"<Run {self.name} in project {self.project}>"
33+
34+
35+
class Runs:
36+
def __init__(self, project: str):
37+
self.project = project
38+
self._runs = None
39+
40+
def _load_runs(self):
41+
if self._runs is None:
42+
run_names = SQLiteStorage.get_runs(self.project)
43+
self._runs = [Run(self.project, name) for name in run_names]
44+
45+
def __iter__(self) -> Iterator[Run]:
46+
self._load_runs()
47+
return iter(self._runs)
48+
49+
def __getitem__(self, index: int) -> Run:
50+
self._load_runs()
51+
return self._runs[index]
52+
53+
def __len__(self) -> int:
54+
self._load_runs()
55+
return len(self._runs)
56+
57+
def __repr__(self) -> str:
58+
self._load_runs()
59+
return f"<Runs project={self.project} count={len(self._runs)}>"
60+
61+
62+
class Api:
63+
def runs(self, project: str) -> Runs:
64+
if not SQLiteStorage.get_project_db_path(project).exists():
65+
raise ValueError(f"Project '{project}' does not exist")
66+
return Runs(project)

0 commit comments

Comments
 (0)