Skip to content

Commit 1df2353

Browse files
abidlabsclaudegradio-pr-bot
authored
Add histogram support with wandb-compatible API (#309)
Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
1 parent a606b3e commit 1df2353

9 files changed

Lines changed: 286 additions & 3 deletions

File tree

.changeset/brave-suns-think.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"trackio": minor
3+
---
4+
5+
feat:Add histogram support with wandb-compatible API
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import numpy as np
2+
3+
import trackio
4+
5+
run = trackio.init(project="histogram-training")
6+
7+
num_epochs = 10
8+
batch_size = 32
9+
learning_rate = 0.001
10+
11+
for epoch in range(num_epochs):
12+
epoch_losses = []
13+
epoch_weights = []
14+
15+
for batch in range(20):
16+
loss = np.random.exponential(scale=2.0) * (1 - epoch * 0.05)
17+
epoch_losses.append(loss)
18+
19+
weights = np.random.normal(0, 1 - epoch * 0.08, 1000)
20+
epoch_weights.extend(weights)
21+
22+
avg_loss = np.mean(epoch_losses)
23+
24+
trackio.log(
25+
{
26+
"loss": avg_loss,
27+
"learning_rate": learning_rate * (0.95**epoch),
28+
"loss_distribution": trackio.Histogram(epoch_losses, num_bins=20),
29+
"weight_distribution": trackio.Histogram(epoch_weights, num_bins=50),
30+
},
31+
step=epoch,
32+
)
33+
34+
gradients = np.random.laplace(0, 0.1, 5000)
35+
trackio.log({"final_gradients": trackio.Histogram(gradients, num_bins=30)})
36+
37+
trackio.finish()

tests/e2e/test_histogram_e2e.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import numpy as np
2+
3+
import trackio
4+
from trackio.sqlite_storage import SQLiteStorage
5+
6+
7+
def test_histogram_with_trackio_log(temp_dir):
8+
"""Test logging histograms with trackio."""
9+
run = trackio.init(project="test_histogram")
10+
11+
data1 = np.random.randn(1000)
12+
data2 = np.random.exponential(2, 500)
13+
14+
trackio.log(
15+
{
16+
"normal_dist": trackio.Histogram(data1),
17+
"exp_dist": trackio.Histogram(data2, num_bins=30),
18+
}
19+
)
20+
21+
hist, bins = np.histogram(data1, bins=25)
22+
trackio.log({"precomputed": trackio.Histogram(np_histogram=(hist, bins))})
23+
24+
trackio.finish()
25+
26+
logs = SQLiteStorage.get_logs("test_histogram", run.name)
27+
28+
assert len(logs) == 2
29+
30+
assert "normal_dist" in logs[0]
31+
assert logs[0]["normal_dist"]["_type"] == "trackio.histogram"
32+
assert len(logs[0]["normal_dist"]["bins"]) == 65
33+
assert len(logs[0]["normal_dist"]["values"]) == 64
34+
35+
assert "exp_dist" in logs[0]
36+
assert logs[0]["exp_dist"]["_type"] == "trackio.histogram"
37+
assert len(logs[0]["exp_dist"]["bins"]) == 31
38+
assert len(logs[0]["exp_dist"]["values"]) == 30
39+
40+
assert "precomputed" in logs[1]
41+
assert logs[1]["precomputed"]["_type"] == "trackio.histogram"
42+
assert len(logs[1]["precomputed"]["bins"]) == 26
43+
assert len(logs[1]["precomputed"]["values"]) == 25

tests/test_histogram.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import numpy as np
2+
import pytest
3+
4+
from trackio.histogram import Histogram
5+
6+
7+
def test_histogram_from_sequence():
8+
"""Test creating histogram from a sequence of values."""
9+
data = [1, 2, 3, 4, 5, 5, 5, 6, 7, 8, 8, 9]
10+
hist = Histogram(data)
11+
12+
assert hist.bins is not None
13+
assert hist.histogram is not None
14+
assert len(hist.bins) == 65
15+
assert len(hist.histogram) == 64
16+
assert sum(hist.histogram) == len(data)
17+
18+
19+
def test_histogram_from_np_histogram():
20+
"""Test creating histogram from pre-computed numpy histogram."""
21+
data = np.random.randn(500)
22+
np_hist, np_bins = np.histogram(data, bins=30)
23+
24+
hist = Histogram(np_histogram=(np_hist, np_bins))
25+
26+
assert np.array_equal(hist.bins, np_bins)
27+
assert np.array_equal(hist.histogram, np_hist)
28+
assert len(hist.bins) == 31
29+
assert len(hist.histogram) == 30
30+
31+
32+
def test_histogram_to_dict():
33+
"""Test histogram serialization to dictionary."""
34+
data = np.random.randn(100)
35+
hist = Histogram(data, num_bins=10)
36+
37+
hist_dict = hist._to_dict()
38+
39+
assert hist_dict["_type"] == "trackio.histogram"
40+
assert "bins" in hist_dict
41+
assert "values" in hist_dict
42+
assert isinstance(hist_dict["bins"], list)
43+
assert isinstance(hist_dict["values"], list)
44+
assert len(hist_dict["bins"]) == 11
45+
assert len(hist_dict["values"]) == 10
46+
47+
48+
def test_histogram_invalid_inputs():
49+
"""Test histogram with invalid inputs."""
50+
with pytest.raises(
51+
ValueError, match="Must provide either sequence or np_histogram"
52+
):
53+
Histogram()
54+
55+
with pytest.raises(
56+
ValueError, match="Cannot provide both sequence and np_histogram"
57+
):
58+
Histogram([1, 2, 3], np_histogram=([1, 2], [0, 1, 2]))

trackio/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import hashlib
2+
import json
23
import logging
34
import os
45
import warnings
@@ -13,6 +14,7 @@
1314
from huggingface_hub import SpaceStorage
1415

1516
from trackio import context_vars, deploy, utils
17+
from trackio.histogram import Histogram
1618
from trackio.imports import import_csv, import_tf_events
1719
from trackio.media import TrackioImage, TrackioVideo
1820
from trackio.run import Run
@@ -30,7 +32,9 @@
3032
module="gradio.helpers",
3133
)
3234

33-
__version__ = Path(__file__).parent.joinpath("version.txt").read_text().strip()
35+
__version__ = json.loads(Path(__file__).parent.joinpath("package.json").read_text())[
36+
"version"
37+
]
3438

3539
__all__ = [
3640
"init",
@@ -42,6 +46,7 @@
4246
"Image",
4347
"Video",
4448
"Table",
49+
"Histogram",
4550
]
4651

4752
Image = TrackioImage

trackio/histogram.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from typing import Any
2+
3+
import numpy as np
4+
5+
6+
class Histogram:
7+
"""
8+
Histogram data type for Trackio, compatible with wandb.Histogram.
9+
10+
Example:
11+
```python
12+
import trackio
13+
import numpy as np
14+
15+
# Create histogram from sequence
16+
data = np.random.randn(1000)
17+
trackio.log({"distribution": trackio.Histogram(data)})
18+
19+
# Create histogram from numpy histogram
20+
hist, bins = np.histogram(data, bins=30)
21+
trackio.log({"distribution": trackio.Histogram(np_histogram=(hist, bins))})
22+
23+
# Specify custom number of bins
24+
trackio.log({"distribution": trackio.Histogram(data, num_bins=50)})
25+
```
26+
27+
Args:
28+
sequence: Optional sequence of values to create histogram from
29+
np_histogram: Optional pre-computed numpy histogram (hist, bins) tuple
30+
num_bins: Number of bins for the histogram (default 64, max 512)
31+
"""
32+
33+
TYPE = "trackio.histogram"
34+
35+
def __init__(
36+
self,
37+
sequence: Any = None,
38+
np_histogram: tuple | None = None,
39+
num_bins: int = 64,
40+
):
41+
if sequence is None and np_histogram is None:
42+
raise ValueError("Must provide either sequence or np_histogram")
43+
44+
if sequence is not None and np_histogram is not None:
45+
raise ValueError("Cannot provide both sequence and np_histogram")
46+
47+
num_bins = min(num_bins, 512)
48+
49+
if np_histogram is not None:
50+
self.histogram, self.bins = np_histogram
51+
self.histogram = np.asarray(self.histogram)
52+
self.bins = np.asarray(self.bins)
53+
else:
54+
data = np.asarray(sequence).flatten()
55+
data = data[np.isfinite(data)]
56+
if len(data) == 0:
57+
self.histogram = np.array([])
58+
self.bins = np.array([])
59+
else:
60+
self.histogram, self.bins = np.histogram(data, bins=num_bins)
61+
62+
def _to_dict(self) -> dict:
63+
"""Convert histogram to dictionary for storage."""
64+
return {
65+
"_type": self.TYPE,
66+
"bins": self.bins.tolist(),
67+
"values": self.histogram.tolist(),
68+
}

trackio/run.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from gradio_client import Client, handle_file
88

99
from trackio import utils
10+
from trackio.histogram import Histogram
1011
from trackio.media import TrackioMedia
1112
from trackio.sqlite_storage import SQLiteStorage
1213
from trackio.table import Table
@@ -139,7 +140,7 @@ def _process_media(self, metrics, step: int | None) -> dict:
139140
@staticmethod
140141
def _replace_tables(metrics):
141142
for k, v in metrics.items():
142-
if isinstance(v, Table):
143+
if isinstance(v, (Table, Histogram)):
143144
metrics[k] = v._to_dict()
144145

145146
def log(self, metrics: dict, step: int | None = None):

trackio/ui/main.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
try:
1515
import trackio.utils as utils
1616
from trackio.file_storage import FileStorage
17+
from trackio.histogram import Histogram
1718
from trackio.media import TrackioImage, TrackioVideo
1819
from trackio.sqlite_storage import SQLiteStorage
1920
from trackio.table import Table
@@ -25,6 +26,7 @@
2526
except ImportError:
2627
import utils
2728
from file_storage import FileStorage
29+
from histogram import Histogram
2830
from media import TrackioImage, TrackioVideo
2931
from sqlite_storage import SQLiteStorage
3032
from table import Table
@@ -1121,6 +1123,71 @@ def update_dashboard(
11211123
f"Column {metric_name} failed to render as a table: {e}"
11221124
)
11231125

1126+
# Display histograms
1127+
histogram_cols = set(master_df.columns) - {
1128+
"run",
1129+
"step",
1130+
"timestamp",
1131+
"data_type",
1132+
}
1133+
if metrics_subset:
1134+
histogram_cols = [c for c in histogram_cols if c in metrics_subset]
1135+
if metric_filter and metric_filter.strip():
1136+
histogram_cols = filter_metrics_by_regex(
1137+
list(histogram_cols), metric_filter
1138+
)
1139+
1140+
actual_histogram_count = sum(
1141+
1
1142+
for metric_name in histogram_cols
1143+
if not (metric_df := master_df.dropna(subset=[metric_name])).empty
1144+
and isinstance(value := metric_df[metric_name].iloc[-1], dict)
1145+
and value.get("_type") == Histogram.TYPE
1146+
)
1147+
1148+
if actual_histogram_count > 0:
1149+
with gr.Accordion(f"histograms ({actual_histogram_count})", open=True):
1150+
with gr.Row(key="histogram-row"):
1151+
for metric_idx, metric_name in enumerate(histogram_cols):
1152+
metric_df = master_df.dropna(subset=[metric_name])
1153+
if not metric_df.empty:
1154+
value = metric_df[metric_name].iloc[-1]
1155+
if (
1156+
isinstance(value, dict)
1157+
and "_type" in value
1158+
and value["_type"] == Histogram.TYPE
1159+
):
1160+
try:
1161+
bins = value.get("bins", [])
1162+
values = value.get("values", [])
1163+
1164+
if len(bins) > 0 and len(values) > 0:
1165+
bin_centers = [
1166+
(bins[i] + bins[i + 1]) / 2
1167+
for i in range(len(bins) - 1)
1168+
]
1169+
1170+
df = pd.DataFrame(
1171+
{"bin_center": bin_centers, "count": values}
1172+
)
1173+
1174+
gr.BarPlot(
1175+
df,
1176+
x="bin_center",
1177+
y="count",
1178+
title=f"{metric_name} (latest)",
1179+
x_title="Value",
1180+
y_title="Count",
1181+
key=f"histogram-{metric_idx}",
1182+
show_fullscreen_button=True,
1183+
min_width=400,
1184+
show_export_button=True,
1185+
)
1186+
except Exception as e:
1187+
gr.Warning(
1188+
f"Column {metric_name} failed to render as a histogram: {e}"
1189+
)
1190+
11241191
with grouped_runs_panel:
11251192

11261193
@gr.render(

trackio/version.txt

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)