Skip to content

Commit 6857cbb

Browse files
abidlabsclaudegradio-pr-bot
authored
Support trackio.Table with trackio.Image columns (#328)
Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
1 parent 82f6fd4 commit 6857cbb

8 files changed

Lines changed: 337 additions & 42 deletions

File tree

.changeset/smooth-deer-tie.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"trackio": patch
3+
---
4+
5+
feat:Support trackio.Table with trackio.Image columns

docs/source/track.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ trackio.log({
6464

6565
### Logging tables
6666

67-
You can log tabular data using the [`Table`] class. This is useful for tracking results like predictions, or any structured data.
67+
You can log tabular data using the [`Table`] class. This is useful for tracking results like predictions, or any structured data. Tables can include image columns using the [`Image`] class.
6868

6969
```python
7070
import pandas as pd
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Example: Logging Tables with Images
4+
5+
This example demonstrates the capability to include trackio.Image objects
6+
in trackio.Table columns. The images will be displayed as thumbnails in the
7+
dashboard with captions as alt text.
8+
9+
Run with: python examples/table/table-with-images.py
10+
"""
11+
12+
import random
13+
14+
import numpy as np
15+
import pandas as pd
16+
17+
import trackio
18+
19+
20+
def create_sample_images():
21+
"""Create some sample images for demonstration."""
22+
images = []
23+
24+
red_square = np.full((100, 100, 3), [255, 0, 0], dtype=np.uint8)
25+
images.append(trackio.Image(red_square, caption="Red Square"))
26+
27+
blue_data = np.zeros((100, 100, 3), dtype=np.uint8)
28+
center = 50
29+
radius = 40
30+
y, x = np.ogrid[:100, :100]
31+
mask = (x - center) ** 2 + (y - center) ** 2 <= radius**2
32+
blue_data[mask] = [0, 0, 255]
33+
images.append(trackio.Image(blue_data, caption="Blue Circle"))
34+
35+
gradient = np.zeros((100, 100, 3), dtype=np.uint8)
36+
for i in range(100):
37+
gradient[i, :, 1] = int(255 * i / 100)
38+
images.append(trackio.Image(gradient, caption="Green Gradient"))
39+
40+
checkerboard = np.zeros((100, 100, 3), dtype=np.uint8)
41+
for i in range(0, 100, 20):
42+
for j in range(0, 100, 20):
43+
if (i // 20 + j // 20) % 2 == 0:
44+
checkerboard[i : i + 20, j : j + 20] = [255, 255, 255]
45+
images.append(trackio.Image(checkerboard, caption="Checkerboard"))
46+
47+
return images
48+
49+
50+
def main():
51+
trackio.init(
52+
project=f"table-with-images-demo-{random.randint(0, 1000000)}",
53+
name="sample-run",
54+
)
55+
images = create_sample_images()
56+
57+
data = {
58+
"experiment_id": [1, 2, 3, 4],
59+
"model_type": ["CNN", "ResNet", "VGG", "Custom"],
60+
"accuracy": [0.85, 0.92, 0.88, 0.95],
61+
"loss": [0.15, 0.08, 0.12, 0.05],
62+
"sample_output": images,
63+
"notes": [
64+
"Basic convolutional model",
65+
"Deep residual network",
66+
"Very deep network",
67+
"Custom architecture",
68+
],
69+
}
70+
71+
df = pd.DataFrame(data)
72+
table = trackio.Table(dataframe=df)
73+
74+
trackio.log({"experiment_results": table})
75+
76+
for step in range(10):
77+
trackio.log(
78+
{
79+
"training_loss": 1.0 * np.exp(-step * 0.1) + 0.1,
80+
"validation_accuracy": 0.5 + 0.4 * (1 - np.exp(-step * 0.15)),
81+
"learning_rate": 0.001 * (0.95**step),
82+
},
83+
step=step,
84+
)
85+
86+
mixed_data = {
87+
"test_id": [1, 2, 3, 4, 5],
88+
"test_type": ["visual", "numerical", "visual", "numerical", "visual"],
89+
"result_image": [images[0], None, images[1], None, images[2]],
90+
"score": [95.5, 87.2, 91.8, 89.1, 93.4],
91+
"passed": [True, True, True, False, True],
92+
}
93+
94+
mixed_df = pd.DataFrame(mixed_data)
95+
mixed_table = trackio.Table(dataframe=mixed_df)
96+
trackio.log({"mixed_test_results": mixed_table})
97+
98+
trackio.finish()
99+
100+
101+
if __name__ == "__main__":
102+
main()
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""End-to-end test for Table with TrackioImage functionality."""
2+
3+
import pandas as pd
4+
5+
import trackio
6+
from trackio.media import TrackioImage
7+
from trackio.sqlite_storage import SQLiteStorage
8+
from trackio.table import Table
9+
10+
PROJECT_NAME = "test_table_images"
11+
12+
13+
def test_table_mixed_images_and_regular_data(image_ndarray, temp_dir):
14+
"""Test table with some rows having images and others not."""
15+
trackio.init(project=PROJECT_NAME, name="mixed_test")
16+
17+
img = TrackioImage(image_ndarray, caption="Only Image")
18+
19+
df = pd.DataFrame(
20+
{
21+
"experiment": ["exp1", "exp2", "exp3"],
22+
"result_image": [img, None, img],
23+
"score": [0.75, 0.80, 0.85],
24+
}
25+
)
26+
27+
table = Table(dataframe=df)
28+
trackio.log({"mixed_results": table})
29+
trackio.finish()
30+
31+
logs = SQLiteStorage.get_logs(PROJECT_NAME, "mixed_test")
32+
table_data = None
33+
34+
for log in logs:
35+
if "mixed_results" in log:
36+
value = log["mixed_results"]
37+
if isinstance(value, dict) and value.get("_type") == Table.TYPE:
38+
table_data = value["_value"]
39+
break
40+
41+
assert table_data is not None
42+
assert len(table_data) == 3
43+
44+
assert isinstance(table_data[0]["result_image"], dict)
45+
assert table_data[1]["result_image"] is None
46+
assert isinstance(table_data[2]["result_image"], dict)

tests/test_table.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import pandas as pd
2+
3+
from trackio.media import TrackioImage
4+
from trackio.table import Table
5+
6+
PROJECT_NAME = "test_project"
7+
RUN_NAME = "test_run"
8+
9+
10+
def test_table_to_dict_with_images(image_ndarray, temp_dir):
11+
img = TrackioImage(image_ndarray, caption="Mixed Test")
12+
df = pd.DataFrame(
13+
{
14+
"step": [1, 2, 3],
15+
"image": [img, None, img],
16+
"text": ["hello", "world", "test"],
17+
"number": [1.5, 2.5, 3.5],
18+
}
19+
)
20+
21+
table = Table(dataframe=df)
22+
result = table._to_dict(project=PROJECT_NAME, run=RUN_NAME, step=5)
23+
24+
assert result["_type"] == Table.TYPE
25+
assert len(result["_value"]) == 3
26+
27+
row1 = result["_value"][0]
28+
assert row1["step"] == 1
29+
assert row1["text"] == "hello"
30+
assert row1["number"] == 1.5
31+
assert isinstance(row1["image"], dict)
32+
assert row1["image"]["_type"] == TrackioImage.TYPE
33+
34+
row2 = result["_value"][1]
35+
assert row2["step"] == 2
36+
assert row2["text"] == "world"
37+
assert row2["number"] == 2.5
38+
assert row2["image"] is None
39+
40+
row3 = result["_value"][2]
41+
assert row3["step"] == 3
42+
assert row3["text"] == "test"
43+
assert row3["number"] == 3.5
44+
assert isinstance(row3["image"], dict)
45+
assert row3["image"]["_type"] == TrackioImage.TYPE
46+
47+
48+
def test_table_to_display_format_with_images():
49+
table_data = [
50+
{
51+
"step": 1,
52+
"image": {
53+
"_type": "trackio.image",
54+
"file_path": "test/path/image.png",
55+
"caption": "Test Caption",
56+
},
57+
"value": 42,
58+
"text": "regular text",
59+
},
60+
{
61+
"step": 2,
62+
"image": None,
63+
"value": 84,
64+
"text": "more text",
65+
},
66+
]
67+
68+
processed_data = Table.to_display_format(table_data)
69+
70+
assert len(processed_data) == 2
71+
72+
row1 = processed_data[0]
73+
assert row1["step"] == 1
74+
assert row1["value"] == 42
75+
assert row1["text"] == "regular text"
76+
assert "![Test Caption](/gradio_api/file=" in row1["image"]
77+
assert row1["image"].endswith(
78+
"image.png)"
79+
) # The extra ) is due to the Markdown syntax
80+
81+
row2 = processed_data[1]
82+
assert row2["step"] == 2
83+
assert row2["value"] == 84
84+
assert row2["text"] == "more text"
85+
assert row2["image"] is None

trackio/run.py

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,6 @@ def _get_username(self) -> str | None:
7070
def _batch_sender(self):
7171
"""Send batched logs every BATCH_SEND_INTERVAL."""
7272
while not self._stop_flag.is_set() or len(self._queued_logs) > 0:
73-
# If the stop flag has been set, then just quickly send all
74-
# the logs and exit.
7573
if not self._stop_flag.is_set():
7674
time.sleep(BATCH_SEND_INTERVAL)
7775

@@ -112,36 +110,21 @@ def _init_client_background(self):
112110

113111
self._batch_sender()
114112

115-
def _process_media(self, metrics, step: int | None) -> dict:
113+
def _process_media(self, value: TrackioMedia, step: int | None) -> dict:
116114
"""
117115
Serialize media in metrics and upload to space if needed.
118116
"""
119-
serializable_metrics = {}
120-
if not step:
121-
step = 0
122-
for key, value in metrics.items():
123-
if isinstance(value, TrackioMedia):
124-
value._save(self.project, self.name, step)
125-
serializable_metrics[key] = value._to_dict()
126-
if self._space_id:
127-
# Upload local media when deploying to space
128-
upload_entry: UploadEntry = {
129-
"project": self.project,
130-
"run": self.name,
131-
"step": step,
132-
"uploaded_file": handle_file(value._get_absolute_file_path()),
133-
}
134-
with self._client_lock:
135-
self._queued_uploads.append(upload_entry)
136-
else:
137-
serializable_metrics[key] = value
138-
return serializable_metrics
139-
140-
@staticmethod
141-
def _replace_tables(metrics):
142-
for k, v in metrics.items():
143-
if isinstance(v, (Table, Histogram)):
144-
metrics[k] = v._to_dict()
117+
value._save(self.project, self.name, step)
118+
if self._space_id:
119+
upload_entry: UploadEntry = {
120+
"project": self.project,
121+
"run": self.name,
122+
"step": step,
123+
"uploaded_file": handle_file(value._get_absolute_file_path()),
124+
}
125+
with self._client_lock:
126+
self._queued_uploads.append(upload_entry)
127+
return value._to_dict()
145128

146129
def log(self, metrics: dict, step: int | None = None):
147130
renamed_keys = []
@@ -159,9 +142,15 @@ def log(self, metrics: dict, step: int | None = None):
159142
warnings.warn(f"Reserved keys renamed: {renamed_keys} → '__{{key}}'")
160143

161144
metrics = new_metrics
162-
Run._replace_tables(metrics)
163-
164-
metrics = self._process_media(metrics, step)
145+
for key, value in metrics.items():
146+
if isinstance(value, Table):
147+
metrics[key] = value._to_dict(
148+
project=self.project, run=self.name, step=step
149+
)
150+
elif isinstance(value, Histogram):
151+
metrics[key] = value._to_dict()
152+
elif isinstance(value, TrackioMedia):
153+
metrics[key] = self._process_media(value, step)
165154
metrics = utils.serialize_values(metrics)
166155

167156
config_to_log = None
@@ -184,7 +173,6 @@ def finish(self):
184173
"""Cleanup when run is finished."""
185174
self._stop_flag.set()
186175

187-
# Wait for the batch sender to finish before joining the client thread.
188176
time.sleep(2 * BATCH_SEND_INTERVAL)
189177

190178
if self._client_thread is not None:

0 commit comments

Comments
 (0)