Skip to content

Commit 369a7b8

Browse files
authored
Add image logging support (#142)
1 parent 8913b1f commit 369a7b8

18 files changed

Lines changed: 479 additions & 43 deletions

examples/fake-training-images.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import math
2+
import random
3+
import time
4+
5+
from PIL import Image as PILImage
6+
from PIL import ImageDraw
7+
8+
import trackio as wandb
9+
10+
EPOCHS = 12
11+
W = H = 128
12+
13+
14+
def lissajous(t, w=W, h=H):
15+
x = (w // 2) + int((w // 3) * math.sin(2.0 * t))
16+
y = (h // 2) + int((h // 3) * math.sin(3.0 * t + math.pi / 4))
17+
return x, y
18+
19+
20+
def render_overlay(target_xy, pred_xy):
21+
img = PILImage.new("RGB", (W, H), "black")
22+
draw = ImageDraw.Draw(img)
23+
24+
tx, ty = target_xy
25+
draw.ellipse((tx - 5, ty - 5, tx + 5, ty + 5), fill=(0, 255, 0))
26+
27+
px, py = pred_xy
28+
draw.ellipse((px - 5, py - 5, px + 5, py + 5), fill=(255, 80, 80))
29+
30+
draw.line([(tx, ty), (px, py)], fill=(255, 255, 0), width=1)
31+
return img
32+
33+
34+
def main():
35+
project_id = random.randint(10000, 99999)
36+
project_name = f"image-logging-demo-{project_id}"
37+
38+
for run_index in range(1, 3):
39+
run_name = f"image-run-{run_index}"
40+
wandb.init(project=project_name, name=run_name)
41+
42+
pred_x, pred_y = random.randint(0, W - 1), random.randint(0, H - 1)
43+
44+
for epoch in range(EPOCHS):
45+
t = epoch / 3.0
46+
target = lissajous(t)
47+
48+
lr = 0.35
49+
noise_scale = max(0.0, 5.0 * (1.0 - epoch / (EPOCHS - 1)))
50+
pred_x += lr * (target[0] - pred_x) + random.uniform(
51+
-noise_scale, noise_scale
52+
)
53+
pred_y += lr * (target[1] - pred_y) + random.uniform(
54+
-noise_scale, noise_scale
55+
)
56+
pred = (int(round(pred_x)), int(round(pred_y)))
57+
58+
loss = math.dist(target, pred)
59+
60+
overlay = render_overlay(target, pred)
61+
wandb.log(
62+
{
63+
"loss": loss,
64+
"target_x": target[0],
65+
"target_y": target[1],
66+
"pred_x": pred[0],
67+
"pred_y": pred[1],
68+
"overlay": wandb.Image(
69+
overlay, caption=f"step={epoch}, loss={loss:.2f}"
70+
),
71+
},
72+
step=epoch,
73+
)
74+
time.sleep(0.2)
75+
76+
wandb.finish()
77+
78+
79+
if __name__ == "__main__":
80+
main()

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ requires-python = ">=3.10"
1414
dependencies = [
1515
"pandas<3.0.0",
1616
"huggingface-hub<1.0.0",
17-
"gradio>=5.35.0,<6.0.0"
17+
"gradio>=5.43.1,<6.0.0",
18+
"numpy<3.0.0",
19+
"pillow<12.0.0",
1820
]
1921
classifiers = [
2022
"Programming Language :: Python :: 3",

tests/conftest.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55

66

77
@pytest.fixture
8-
def temp_db(monkeypatch):
9-
"""Fixture that creates a temporary directory for database storage and patches the TRACKIO_DIR."""
8+
def temp_dir(monkeypatch):
9+
"""Fixture that creates a temporary TRACKIO_DIR."""
1010
with tempfile.TemporaryDirectory() as tmpdir:
11-
monkeypatch.setattr("trackio.sqlite_storage.TRACKIO_DIR", Path(tmpdir))
11+
for name in ("trackio.sqlite_storage", "trackio.media", "trackio.file_storage"):
12+
monkeypatch.setattr(f"{name}.TRACKIO_DIR", Path(tmpdir))
1213
yield tmpdir

tests/e2e/test_basic_logging.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
from trackio.sqlite_storage import SQLiteStorage
33

44

5-
def test_basic_logging(temp_db):
5+
def test_basic_logging(temp_dir):
66
trackio.init(project="test_project", name="test_run")
77
trackio.log(metrics={"loss": 0.1})
88
trackio.log(metrics={"loss": 0.2, "acc": 0.9})
99
trackio.finish()
1010

11-
results = SQLiteStorage.get_metrics(project="test_project", run="test_run")
11+
results = SQLiteStorage.get_logs(project="test_project", run="test_run")
1212
assert len(results) == 2
1313
assert results[0]["loss"] == 0.1
1414
assert results[0]["step"] == 0
@@ -20,13 +20,13 @@ def test_basic_logging(temp_db):
2020
assert "timestamp" in results[1]
2121

2222

23-
def test_basic_logging_with_step(temp_db):
23+
def test_basic_logging_with_step(temp_dir):
2424
trackio.init(project="test_project", name="test_run")
2525
trackio.log(metrics={"loss": 0.1}, step=0)
2626
trackio.log(metrics={"loss": 0.2, "acc": 0.9}, step=2)
2727
trackio.finish()
2828

29-
results = SQLiteStorage.get_metrics(project="test_project", run="test_run")
29+
results = SQLiteStorage.get_logs(project="test_project", run="test_run")
3030
assert len(results) == 2
3131
assert results[0]["loss"] == 0.1
3232
assert results[0]["step"] == 0

tests/e2e/test_bulk_logging.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_rapid_bulk_logging():
3535
time.sleep(0.6) # Wait for the client to send the logs
3636

3737
# Verify run1 metrics
38-
metrics_run1 = SQLiteStorage.get_metrics(project_name, run1_name)
38+
metrics_run1 = SQLiteStorage.get_logs(project_name, run1_name)
3939
assert len(metrics_run1) == num_logs_run1, (
4040
f"Expected {num_logs_run1} logs for run1, but found {len(metrics_run1)}"
4141
)
@@ -52,7 +52,7 @@ def test_rapid_bulk_logging():
5252
)
5353

5454
# Verify run2 metrics
55-
metrics_run2 = SQLiteStorage.get_metrics(project_name, run2_name)
55+
metrics_run2 = SQLiteStorage.get_logs(project_name, run2_name)
5656
assert len(metrics_run2) == num_logs_run2, (
5757
f"Expected {num_logs_run2} logs for run2, but found {len(metrics_run2)}"
5858
)

tests/e2e/test_image_logging.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import numpy as np
2+
3+
import trackio
4+
from trackio.sqlite_storage import SQLiteStorage
5+
6+
PROJECT_NAME = "test_project"
7+
8+
9+
def test_image_logging(temp_dir):
10+
trackio.init(project=PROJECT_NAME, name="test_run")
11+
12+
image1 = trackio.Image(
13+
np.random.randint(255, size=(100, 100, 3), dtype=np.uint8),
14+
caption="test_caption1",
15+
)
16+
image2 = trackio.Image(
17+
np.random.randint(255, size=(100, 100, 3), dtype=np.uint8),
18+
caption="test_caption2",
19+
)
20+
trackio.log(metrics={"loss": 0.1, "img1": image1})
21+
trackio.log(metrics={"loss": 0.2, "img1": image1, "img2": image2})
22+
trackio.finish()
23+
24+
metrics = SQLiteStorage.get_logs(project=PROJECT_NAME, run="test_run")
25+
26+
assert len(metrics) == 2
27+
28+
assert metrics[0]["loss"] == 0.1
29+
assert metrics[0]["step"] == 0
30+
assert metrics[0]["img1"] == image1._to_dict()
31+
32+
assert metrics[1]["loss"] == 0.2
33+
assert metrics[1]["step"] == 1
34+
assert metrics[1]["img1"] == image1._to_dict()
35+
assert metrics[1]["img2"] == image2._to_dict()

tests/e2e/test_import_from_csv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
from trackio.sqlite_storage import SQLiteStorage
55

66

7-
def test_import_from_csv(temp_db):
7+
def test_import_from_csv(temp_dir):
88
trackio.import_csv(
99
csv_path=str(Path(__file__).parent / "logs.csv"),
1010
project="test_project",
1111
name="test_run",
1212
)
1313

14-
results = SQLiteStorage.get_metrics(project="test_project", run="test_run")
14+
results = SQLiteStorage.get_logs(project="test_project", run="test_run")
1515
assert len(results) == 4
1616
assert results[0]["train/loss"] == 12.2
1717
assert results[0]["train/acc"] == 82.2

tests/e2e/test_import_from_tf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from trackio.sqlite_storage import SQLiteStorage
77

88

9-
def test_import_from_tf_events(temp_db):
9+
def test_import_from_tf_events(temp_dir):
1010
test_run_dir = "tf_test_run"
1111

1212
def create_tfevents_tensorboardx(log_dir: Path):
@@ -28,7 +28,7 @@ def create_tfevents_tensorboardx(log_dir: Path):
2828
name="test_run",
2929
)
3030

31-
results = SQLiteStorage.get_metrics(project="test_tf_project", run="test_run_main")
31+
results = SQLiteStorage.get_logs(project="test_tf_project", run="test_run_main")
3232
# There should be 5 steps × 2 metrics = 10 entries
3333
assert len(results) == 10
3434

tests/test_media.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from pathlib import Path
2+
3+
import numpy as np
4+
5+
from trackio.media import TrackioImage
6+
7+
PROJECT_NAME = "test_project"
8+
9+
10+
def test_image_save(temp_dir):
11+
image = TrackioImage(np.random.randint(255, size=(100, 100, 3), dtype=np.uint8))
12+
image._save(PROJECT_NAME, "test_run", 0, "PNG")
13+
14+
assert image._file_format == "PNG"
15+
16+
expected_rel_dir = Path("media") / PROJECT_NAME / "test_run" / "0"
17+
assert str(image._get_relative_file_path()).startswith(str(expected_rel_dir))
18+
assert str(image._get_absolute_file_path()).endswith(".png")
19+
assert image._get_absolute_file_path().is_file()
20+
21+
22+
def test_image_serialization(temp_dir):
23+
image = TrackioImage(
24+
np.random.randint(255, size=(100, 100, 3), dtype=np.uint8),
25+
caption="test_caption",
26+
)
27+
image._save(PROJECT_NAME, "test_run", 0, "PNG")
28+
value = image._to_dict()
29+
30+
assert value is not None
31+
assert value.get("_type") == TrackioImage.TYPE
32+
assert value.get("file_path") == str(image._get_relative_file_path())
33+
assert value.get("file_format") == "PNG"
34+
assert value.get("caption") == "test_caption"
35+
36+
37+
def test_image_deserialization(temp_dir):
38+
image = TrackioImage(
39+
np.random.randint(255, size=(100, 100, 3), dtype=np.uint8),
40+
caption="test_caption",
41+
)
42+
image._save(PROJECT_NAME, "test_run", 0, "PNG")
43+
value = image._to_dict()
44+
45+
image2 = TrackioImage._from_dict(value)
46+
assert image2._get_relative_file_path() == image._get_relative_file_path()
47+
assert image2._get_absolute_file_path() == image._get_absolute_file_path()
48+
assert image2._get_absolute_file_path().is_file()
49+
assert image2._file_format == "PNG"
50+
assert image2.caption == "test_caption"

tests/test_run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_run_log_calls_client():
2626
)
2727

2828

29-
def test_init_resume_modes(temp_db):
29+
def test_init_resume_modes(temp_dir):
3030
run = init(
3131
project="test-project",
3232
name="new-run",

0 commit comments

Comments
 (0)