Skip to content

Commit 5f9f51d

Browse files
Support a list of Trackio.Image in a trackio.Table cell (#336)
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
1 parent 5f32c49 commit 5f9f51d

6 files changed

Lines changed: 171 additions & 21 deletions

File tree

.changeset/eight-keys-help.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"trackio": patch
3+
---
4+
5+
feat:Support a list of `Trackio.Image` in a `trackio.Table` cell
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import random
2+
3+
import numpy as np
4+
import pandas as pd
5+
from PIL import Image
6+
7+
import trackio
8+
9+
EPOCHS = 20
10+
PROJECT_ID = random.randint(100000, 999999)
11+
12+
13+
trackio.init(
14+
project=f"deploy-images-on-spaces-{PROJECT_ID}",
15+
space_id=f"deploy-images-on-spaces-{PROJECT_ID}",
16+
)
17+
image = trackio.Image(
18+
Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))
19+
)
20+
df = pd.DataFrame({"value": [0.1, 0.2, 0.3], "image": [[image, image], image, image]})
21+
table = trackio.Table(dataframe=df)
22+
trackio.log({"my_table": table})
23+
trackio.finish()

examples/deploy-on-spaces.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,4 @@ def generate_accuracy_curve(epoch, max_epochs, max_acc=0.95, min_acc=0.1):
8484

8585
time.sleep(0.2)
8686

87-
wandb.finish()
87+
wandb.finish()

tests/test_table.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,63 @@ def test_table_to_display_format_with_images():
7373
assert row1["step"] == 1
7474
assert row1["value"] == 42
7575
assert row1["text"] == "regular text"
76-
assert "![Test Caption](/gradio_api/file=" in row1["image"]
77-
assert row1["image"].endswith(
78-
"image.png)"
79-
) # The extra ) is due to the Markdown syntax
76+
assert '<img src="/gradio_api/file=' in row1["image"]
77+
assert 'image.png"' in row1["image"]
78+
assert 'alt="Test Caption"' in row1["image"]
8079

8180
row2 = processed_data[1]
8281
assert row2["step"] == 2
8382
assert row2["value"] == 84
8483
assert row2["text"] == "more text"
8584
assert row2["image"] is None
85+
86+
87+
def test_table_to_display_format_with_multiple_images():
88+
table_data = [
89+
{
90+
"step": 1,
91+
"images": [
92+
{
93+
"_type": "trackio.image",
94+
"file_path": "test/path/image1.png",
95+
"caption": "First Image",
96+
},
97+
{
98+
"_type": "trackio.image",
99+
"file_path": "test/path/image2.png",
100+
"caption": "Second Image",
101+
},
102+
{
103+
"_type": "trackio.image",
104+
"file_path": "test/path/image3.png",
105+
"caption": "Third Image",
106+
},
107+
],
108+
"value": 42,
109+
},
110+
{
111+
"step": 2,
112+
"images": [],
113+
"value": 84,
114+
},
115+
]
116+
117+
processed_data = Table.to_display_format(table_data)
118+
119+
assert len(processed_data) == 2
120+
121+
row1 = processed_data[0]
122+
assert row1["step"] == 1
123+
assert row1["value"] == 42
124+
assert '<img src="/gradio_api/file=' in row1["images"]
125+
assert 'alt="First Image"' in row1["images"]
126+
assert 'image1.png"' in row1["images"]
127+
assert 'alt="Second Image"' in row1["images"]
128+
assert 'image2.png"' in row1["images"]
129+
assert 'alt="Third Image"' in row1["images"]
130+
assert 'image3.png"' in row1["images"]
131+
132+
row2 = processed_data[1]
133+
assert row2["step"] == 2
134+
assert row2["value"] == 84
135+
assert row2["images"] == []

trackio/run.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,22 +110,61 @@ def _init_client_background(self):
110110

111111
self._batch_sender()
112112

113+
def _queue_upload(self, file_path, step: int | None):
114+
"""Queue a media file for upload to space."""
115+
upload_entry: UploadEntry = {
116+
"project": self.project,
117+
"run": self.name,
118+
"step": step,
119+
"uploaded_file": handle_file(file_path),
120+
}
121+
with self._client_lock:
122+
self._queued_uploads.append(upload_entry)
123+
113124
def _process_media(self, value: TrackioMedia, step: int | None) -> dict:
114125
"""
115126
Serialize media in metrics and upload to space if needed.
116127
"""
117128
value._save(self.project, self.name, step)
118129
if self._space_id:
119-
upload_entry: UploadEntry = {
120-
"project": self.project,
121-
"run": self.name,
122-
"step": step,
123-
"uploaded_file": handle_file(value._get_absolute_file_path()),
124-
}
125-
with self._client_lock:
126-
self._queued_uploads.append(upload_entry)
130+
self._queue_upload(value._get_absolute_file_path(), step)
127131
return value._to_dict()
128132

133+
def _scan_and_queue_media_uploads(self, table_dict: dict, step: int | None):
134+
"""
135+
Scan a serialized table for media objects and queue them for upload to space.
136+
"""
137+
if not self._space_id:
138+
return
139+
140+
table_data = table_dict.get("_value", [])
141+
for row in table_data:
142+
for value in row.values():
143+
if isinstance(value, dict) and value.get("_type") in [
144+
"trackio.image",
145+
"trackio.video",
146+
"trackio.audio",
147+
]:
148+
file_path = value.get("file_path")
149+
if file_path:
150+
from trackio.utils import MEDIA_DIR
151+
152+
absolute_path = MEDIA_DIR / file_path
153+
self._queue_upload(absolute_path, step)
154+
elif isinstance(value, list):
155+
for item in value:
156+
if isinstance(item, dict) and item.get("_type") in [
157+
"trackio.image",
158+
"trackio.video",
159+
"trackio.audio",
160+
]:
161+
file_path = item.get("file_path")
162+
if file_path:
163+
from trackio.utils import MEDIA_DIR
164+
165+
absolute_path = MEDIA_DIR / file_path
166+
self._queue_upload(absolute_path, step)
167+
129168
def log(self, metrics: dict, step: int | None = None):
130169
renamed_keys = []
131170
new_metrics = {}
@@ -147,6 +186,7 @@ def log(self, metrics: dict, step: int | None = None):
147186
metrics[key] = value._to_dict(
148187
project=self.project, run=self.name, step=step
149188
)
189+
self._scan_and_queue_media_uploads(metrics[key], step)
150190
elif isinstance(value, Histogram):
151191
metrics[key] = value._to_dict()
152192
elif isinstance(value, TrackioMedia):

trackio/table.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313

1414
class Table:
1515
"""
16-
Initializes a Table object. Tables can include image columns using the Image class.
16+
Initializes a Table object. Tables can be used to log tabular data including images, numbers, and text.
1717
1818
Args:
1919
columns (`list[str]`, *optional*):
2020
Names of the columns in the table. Optional if `data` is provided. Not
2121
expected if `dataframe` is provided. Currently ignored.
2222
data (`list[list[Any]]`, *optional*):
23-
2D row-oriented array of values.
23+
2D row-oriented array of values. Each value can be: a number, a string (treated as Markdown and truncated if too long),
24+
or a `Trackio.Image` or list of `Trackio.Image` objects.
2425
dataframe (`pandas.`DataFrame``, *optional*):
2526
DataFrame object used to create the table. When set, `data` and `columns`
2627
arguments are ignored.
@@ -54,10 +55,20 @@ def __init__(
5455
self.data = dataframe
5556

5657
def _has_media_objects(self, dataframe: DataFrame) -> bool:
57-
"""Check if dataframe contains any TrackioMedia objects."""
58+
"""Check if dataframe contains any TrackioMedia objects or lists of TrackioMedia objects."""
5859
for col in dataframe.columns:
5960
if dataframe[col].apply(lambda x: isinstance(x, TrackioMedia)).any():
6061
return True
62+
if (
63+
dataframe[col]
64+
.apply(
65+
lambda x: isinstance(x, list)
66+
and len(x) > 0
67+
and isinstance(x[0], TrackioMedia)
68+
)
69+
.any()
70+
):
71+
return True
6172
return False
6273

6374
def _process_data(self, project: str, run: str, step: int = 0):
@@ -73,6 +84,13 @@ def _process_data(self, project: str, run: str, step: int = 0):
7384
if isinstance(value, TrackioMedia):
7485
value._save(project, run, step)
7586
processed_df.at[idx, col] = value._to_dict()
87+
if (
88+
isinstance(value, list)
89+
and len(value) > 0
90+
and isinstance(value[0], TrackioMedia)
91+
):
92+
[v._save(project, run, step) for v in value]
93+
processed_df.at[idx, col] = [v._to_dict() for v in value]
7694

7795
return processed_df.to_dict(orient="records")
7896

@@ -86,19 +104,33 @@ def to_display_format(table_data: list[dict]) -> list[dict]:
86104
table_data: List of dictionaries representing table rows (from stored _value)
87105
88106
Returns:
89-
Table data with images converted to markdown syntax
107+
Table data with images converted to markdown syntax and long text truncated.
90108
"""
91109
truncate_length = int(os.getenv("TRACKIO_TABLE_TRUNCATE_LENGTH", "250"))
110+
111+
def convert_image_to_markdown(image_data: dict) -> str:
112+
relative_path = image_data.get("file_path", "")
113+
caption = image_data.get("caption", "")
114+
absolute_path = MEDIA_DIR / relative_path
115+
return f'<img src="/gradio_api/file={absolute_path}" alt="{caption}" />'
116+
92117
processed_data = []
93118
for row in table_data:
94119
processed_row = {}
95120
for key, value in row.items():
96121
if isinstance(value, dict) and value.get("_type") == "trackio.image":
97-
relative_path = value.get("file_path", "")
98-
caption = value.get("caption", "")
99-
absolute_path = MEDIA_DIR / relative_path
122+
processed_row[key] = convert_image_to_markdown(value)
123+
elif (
124+
isinstance(value, list)
125+
and len(value) > 0
126+
and isinstance(value[0], dict)
127+
and value[0].get("_type") == "trackio.image"
128+
):
129+
# This assumes that if the first item is an image, all items are images. Ok for now since we don't support mixed types in a single cell.
100130
processed_row[key] = (
101-
f"![{caption}](/gradio_api/file={absolute_path})"
131+
'<div style="display: flex; gap: 10px;">'
132+
+ "".join([convert_image_to_markdown(item) for item in value])
133+
+ "</div>"
102134
)
103135
elif isinstance(value, str) and len(value) > truncate_length:
104136
truncated = value[:truncate_length]

0 commit comments

Comments
 (0)