Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/smooth-deer-tie.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"trackio": patch
---

feat:Support trackio.Table with trackio.Image columns
2 changes: 1 addition & 1 deletion docs/source/track.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ trackio.log({

### Logging tables

You can log tabular data using the [`Table`] class. This is useful for tracking results like predictions, or any structured data.
You can log tabular data using the [`Table`] class. This is useful for tracking results like predictions, or any structured data. Tables can include image columns using the [`Image`] class.

```python
import pandas as pd
Expand Down
102 changes: 102 additions & 0 deletions examples/table/table-with-images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#!/usr/bin/env python3
"""
Example: Logging Tables with Images

This example demonstrates the capability to include trackio.Image objects
in trackio.Table columns. The images will be displayed as thumbnails in the
dashboard with captions as alt text.

Run with: python examples/table/table-with-images.py
"""

import random

import numpy as np
import pandas as pd

import trackio


def create_sample_images():
"""Create some sample images for demonstration."""
images = []

red_square = np.full((100, 100, 3), [255, 0, 0], dtype=np.uint8)
images.append(trackio.Image(red_square, caption="Red Square"))

blue_data = np.zeros((100, 100, 3), dtype=np.uint8)
center = 50
radius = 40
y, x = np.ogrid[:100, :100]
mask = (x - center) ** 2 + (y - center) ** 2 <= radius**2
blue_data[mask] = [0, 0, 255]
images.append(trackio.Image(blue_data, caption="Blue Circle"))

gradient = np.zeros((100, 100, 3), dtype=np.uint8)
for i in range(100):
gradient[i, :, 1] = int(255 * i / 100)
images.append(trackio.Image(gradient, caption="Green Gradient"))

checkerboard = np.zeros((100, 100, 3), dtype=np.uint8)
for i in range(0, 100, 20):
for j in range(0, 100, 20):
if (i // 20 + j // 20) % 2 == 0:
checkerboard[i : i + 20, j : j + 20] = [255, 255, 255]
images.append(trackio.Image(checkerboard, caption="Checkerboard"))

return images


def main():
trackio.init(
project=f"table-with-images-demo-{random.randint(0, 1000000)}",
name="sample-run",
)
images = create_sample_images()

data = {
"experiment_id": [1, 2, 3, 4],
"model_type": ["CNN", "ResNet", "VGG", "Custom"],
"accuracy": [0.85, 0.92, 0.88, 0.95],
"loss": [0.15, 0.08, 0.12, 0.05],
"sample_output": images,
"notes": [
"Basic convolutional model",
"Deep residual network",
"Very deep network",
"Custom architecture",
],
}

df = pd.DataFrame(data)
table = trackio.Table(dataframe=df)

trackio.log({"experiment_results": table})

for step in range(10):
trackio.log(
{
"training_loss": 1.0 * np.exp(-step * 0.1) + 0.1,
"validation_accuracy": 0.5 + 0.4 * (1 - np.exp(-step * 0.15)),
"learning_rate": 0.001 * (0.95**step),
},
step=step,
)

mixed_data = {
"test_id": [1, 2, 3, 4, 5],
"test_type": ["visual", "numerical", "visual", "numerical", "visual"],
"result_image": [images[0], None, images[1], None, images[2]],
"score": [95.5, 87.2, 91.8, 89.1, 93.4],
"passed": [True, True, True, False, True],
}

mixed_df = pd.DataFrame(mixed_data)
mixed_table = trackio.Table(dataframe=mixed_df)
trackio.log({"mixed_test_results": mixed_table})

trackio.finish()


if __name__ == "__main__":
main()
46 changes: 46 additions & 0 deletions tests/e2e/test_table_with_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""End-to-end test for Table with TrackioImage functionality."""

import pandas as pd

import trackio
from trackio.media import TrackioImage
from trackio.sqlite_storage import SQLiteStorage
from trackio.table import Table

PROJECT_NAME = "test_table_images"


def test_table_mixed_images_and_regular_data(image_ndarray, temp_dir):
"""Test table with some rows having images and others not."""
trackio.init(project=PROJECT_NAME, name="mixed_test")

img = TrackioImage(image_ndarray, caption="Only Image")

df = pd.DataFrame(
{
"experiment": ["exp1", "exp2", "exp3"],
"result_image": [img, None, img],
"score": [0.75, 0.80, 0.85],
}
)

table = Table(dataframe=df)
trackio.log({"mixed_results": table})
trackio.finish()

logs = SQLiteStorage.get_logs(PROJECT_NAME, "mixed_test")
table_data = None

for log in logs:
if "mixed_results" in log:
value = log["mixed_results"]
if isinstance(value, dict) and value.get("_type") == Table.TYPE:
table_data = value["_value"]
break

assert table_data is not None
assert len(table_data) == 3

assert isinstance(table_data[0]["result_image"], dict)
assert table_data[1]["result_image"] is None
assert isinstance(table_data[2]["result_image"], dict)
85 changes: 85 additions & 0 deletions tests/test_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import pandas as pd

from trackio.media import TrackioImage
from trackio.table import Table

PROJECT_NAME = "test_project"
RUN_NAME = "test_run"


def test_table_to_dict_with_images(image_ndarray, temp_dir):
img = TrackioImage(image_ndarray, caption="Mixed Test")
df = pd.DataFrame(
{
"step": [1, 2, 3],
"image": [img, None, img],
"text": ["hello", "world", "test"],
"number": [1.5, 2.5, 3.5],
}
)

table = Table(dataframe=df)
result = table._to_dict(project=PROJECT_NAME, run=RUN_NAME, step=5)

assert result["_type"] == Table.TYPE
assert len(result["_value"]) == 3

row1 = result["_value"][0]
assert row1["step"] == 1
assert row1["text"] == "hello"
assert row1["number"] == 1.5
assert isinstance(row1["image"], dict)
assert row1["image"]["_type"] == TrackioImage.TYPE

row2 = result["_value"][1]
assert row2["step"] == 2
assert row2["text"] == "world"
assert row2["number"] == 2.5
assert row2["image"] is None

row3 = result["_value"][2]
assert row3["step"] == 3
assert row3["text"] == "test"
assert row3["number"] == 3.5
assert isinstance(row3["image"], dict)
assert row3["image"]["_type"] == TrackioImage.TYPE


def test_table_to_display_format_with_images():
table_data = [
{
"step": 1,
"image": {
"_type": "trackio.image",
"file_path": "test/path/image.png",
"caption": "Test Caption",
},
"value": 42,
"text": "regular text",
},
{
"step": 2,
"image": None,
"value": 84,
"text": "more text",
},
]

processed_data = Table.to_display_format(table_data)

assert len(processed_data) == 2

row1 = processed_data[0]
assert row1["step"] == 1
assert row1["value"] == 42
assert row1["text"] == "regular text"
assert "![Test Caption](/gradio_api/file=" in row1["image"]
assert row1["image"].endswith(
"image.png)"
) # The extra ) is due to the Markdown syntax

row2 = processed_data[1]
assert row2["step"] == 2
assert row2["value"] == 84
assert row2["text"] == "more text"
assert row2["image"] is None
54 changes: 21 additions & 33 deletions trackio/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ def _get_username(self) -> str | None:
def _batch_sender(self):
"""Send batched logs every BATCH_SEND_INTERVAL."""
while not self._stop_flag.is_set() or len(self._queued_logs) > 0:
# If the stop flag has been set, then just quickly send all
# the logs and exit.
if not self._stop_flag.is_set():
time.sleep(BATCH_SEND_INTERVAL)

Expand Down Expand Up @@ -112,36 +110,21 @@ def _init_client_background(self):

self._batch_sender()

def _process_media(self, metrics, step: int | None) -> dict:
def _process_media(self, value: TrackioMedia, step: int | None) -> dict:
"""
Serialize media in metrics and upload to space if needed.
"""
serializable_metrics = {}
if not step:
step = 0
for key, value in metrics.items():
if isinstance(value, TrackioMedia):
value._save(self.project, self.name, step)
serializable_metrics[key] = value._to_dict()
if self._space_id:
# Upload local media when deploying to space
upload_entry: UploadEntry = {
"project": self.project,
"run": self.name,
"step": step,
"uploaded_file": handle_file(value._get_absolute_file_path()),
}
with self._client_lock:
self._queued_uploads.append(upload_entry)
else:
serializable_metrics[key] = value
return serializable_metrics

@staticmethod
def _replace_tables(metrics):
for k, v in metrics.items():
if isinstance(v, (Table, Histogram)):
metrics[k] = v._to_dict()
value._save(self.project, self.name, step)
if self._space_id:
upload_entry: UploadEntry = {
"project": self.project,
"run": self.name,
"step": step,
"uploaded_file": handle_file(value._get_absolute_file_path()),
}
with self._client_lock:
self._queued_uploads.append(upload_entry)
return value._to_dict()

def log(self, metrics: dict, step: int | None = None):
renamed_keys = []
Expand All @@ -159,9 +142,15 @@ def log(self, metrics: dict, step: int | None = None):
warnings.warn(f"Reserved keys renamed: {renamed_keys} → '__{{key}}'")

metrics = new_metrics
Run._replace_tables(metrics)

metrics = self._process_media(metrics, step)
for key, value in metrics.items():
if isinstance(value, Table):
metrics[key] = value._to_dict(
project=self.project, run=self.name, step=step
)
elif isinstance(value, Histogram):
metrics[key] = value._to_dict()
elif isinstance(value, TrackioMedia):
metrics[key] = self._process_media(value, step)
metrics = utils.serialize_values(metrics)

config_to_log = None
Expand All @@ -184,7 +173,6 @@ def finish(self):
"""Cleanup when run is finished."""
self._stop_flag.set()

# Wait for the batch sender to finish before joining the client thread.
time.sleep(2 * BATCH_SEND_INTERVAL)

if self._client_thread is not None:
Expand Down
Loading