diff --git a/trackio/ui.py b/trackio/ui.py
index ab87e99d6..07e48f240 100644
--- a/trackio/ui.py
+++ b/trackio/ui.py
@@ -10,6 +10,45 @@
from sqlite_storage import SQLiteStorage
from utils import RESERVED_KEYS, TRACKIO_LOGO_PATH
+css = """
+#run-cb .wrap {
+ gap: 2px;
+}
+#run-cb .wrap label {
+ line-height: 1;
+ padding: 6px;
+}
+"""
+
+COLOR_PALETTE = [
+ "#3B82F6",
+ "#EF4444",
+ "#10B981",
+ "#F59E0B",
+ "#8B5CF6",
+ "#EC4899",
+ "#06B6D4",
+ "#84CC16",
+ "#F97316",
+ "#6366F1",
+]
+
+
+def get_color_mapping(runs: list[str], smoothing: bool) -> dict[str, str]:
+ """Generate color mapping for runs, with transparency for original data when smoothing is enabled."""
+ color_map = {}
+
+ for i, run in enumerate(runs):
+ base_color = COLOR_PALETTE[i % len(COLOR_PALETTE)]
+
+ if smoothing:
+ color_map[f"{run}_smoothed"] = base_color
+ color_map[f"{run}_original"] = base_color + "4D"
+ else:
+ color_map[run] = base_color
+
+ return color_map
+
def get_projects(request: gr.Request):
storage = SQLiteStorage("", "", {})
@@ -43,21 +82,52 @@ def load_run_data(project: str | None, run: str | None, smoothing: bool):
if not metrics:
return None
df = pd.DataFrame(metrics)
+
+ if "step" not in df.columns:
+ df["step"] = range(len(df))
+
if smoothing:
numeric_cols = df.select_dtypes(include="number").columns
numeric_cols = [c for c in numeric_cols if c not in RESERVED_KEYS]
- df[numeric_cols] = df[numeric_cols].ewm(alpha=0.1).mean()
- if "step" not in df.columns:
- df["step"] = range(len(df))
- return df
+
+ df_original = df.copy()
+ df_original["run"] = f"{run}_original"
+ df_original["data_type"] = "original"
+
+ df_smoothed = df.copy()
+ df_smoothed[numeric_cols] = df_smoothed[numeric_cols].ewm(alpha=0.1).mean()
+ df_smoothed["run"] = f"{run}_smoothed"
+ df_smoothed["data_type"] = "smoothed"
+
+ combined_df = pd.concat([df_original, df_smoothed], ignore_index=True)
+ return combined_df
+ else:
+ df["run"] = run
+ df["data_type"] = "original"
+ return df
-def update_runs(project):
+def update_runs(project, filter_text, user_interacted_with_runs=False):
if project is None:
runs = []
+ num_runs = 0
else:
runs = get_runs(project)
- return gr.Dropdown(choices=runs, value=runs)
+ num_runs = len(runs)
+ if filter_text:
+ runs = [r for r in runs if filter_text in r]
+ if not user_interacted_with_runs:
+ return gr.CheckboxGroup(
+ choices=runs, value=[runs[0]] if runs else []
+ ), gr.Textbox(label=f"Runs ({num_runs})")
+ else:
+ return gr.CheckboxGroup(choices=runs), gr.Textbox(label=f"Runs ({num_runs})")
+
+
+def filter_runs(project, filter_text):
+ runs = get_runs(project)
+ runs = [r for r in runs if filter_text in r]
+ return gr.CheckboxGroup(choices=runs, value=runs)
def toggle_timer(cb_value):
@@ -72,6 +142,40 @@ def log(project: str, run: str, metrics: dict[str, Any]) -> None:
storage.log(metrics)
+def sort_metrics_by_prefix(metrics: list[str]) -> list[str]:
+ """
+ Sort metrics by grouping prefixes together.
+ Metrics without prefixes come first, then grouped by prefix.
+
+ Example:
+ Input: ["train/loss", "loss", "train/acc", "val/loss"]
+ Output: ["loss", "train/acc", "train/loss", "val/loss"]
+ """
+ no_prefix = []
+ with_prefix = []
+
+ for metric in metrics:
+ if "/" in metric:
+ with_prefix.append(metric)
+ else:
+ no_prefix.append(metric)
+
+ no_prefix.sort()
+
+ prefix_groups = {}
+ for metric in with_prefix:
+ prefix = metric.split("/")[0]
+ if prefix not in prefix_groups:
+ prefix_groups[prefix] = []
+ prefix_groups[prefix].append(metric)
+
+ sorted_with_prefix = []
+ for prefix in sorted(prefix_groups.keys()):
+ sorted_with_prefix.extend(sorted(prefix_groups[prefix]))
+
+ return no_prefix + sorted_with_prefix
+
+
def configure(request: gr.Request):
if metrics := request.query_params.get("metrics"):
return metrics.split(",")
@@ -79,20 +183,24 @@ def configure(request: gr.Request):
return []
-with gr.Blocks(theme="citrus", title="Trackio Dashboard") as demo:
+with gr.Blocks(theme="citrus", title="Trackio Dashboard", css=css) as demo:
with gr.Sidebar() as sidebar:
gr.Markdown(
f"
Trackio "
)
- project_dd = gr.Dropdown(label="Project", allow_custom_value=True)
+ project_dd = gr.Dropdown(label="Project")
+ run_tb = gr.Textbox(label="Runs", placeholder="Type to filter...")
+ run_cb = gr.CheckboxGroup(
+ label="Runs", choices=[], interactive=True, elem_id="run-cb"
+ )
+ with gr.Sidebar(position="right", open=False) as settings_sidebar:
gr.Markdown("### ⚙️ Settings")
realtime_cb = gr.Checkbox(label="Refresh realtime", value=True)
smoothing_cb = gr.Checkbox(label="Smoothing", value=True)
- with gr.Row():
- run_dd = gr.Dropdown(label="Run", choices=[], multiselect=True)
timer = gr.Timer(value=1)
metrics_subset = gr.State([])
+ user_interacted_with_run_cb = gr.State(False)
gr.on(
[demo.load],
@@ -100,24 +208,41 @@ def configure(request: gr.Request):
outputs=metrics_subset,
)
gr.on(
- [demo.load, timer.tick],
+ [demo.load],
fn=get_projects,
outputs=project_dd,
show_progress="hidden",
)
gr.on(
- [demo.load, project_dd.change, timer.tick],
+ [timer.tick],
+ fn=update_runs,
+ inputs=[project_dd, run_tb, user_interacted_with_run_cb],
+ outputs=[run_cb, run_tb],
+ show_progress="hidden",
+ )
+ gr.on(
+ [demo.load, project_dd.change],
fn=update_runs,
- inputs=project_dd,
- outputs=run_dd,
+ inputs=[project_dd, run_tb],
+ outputs=[run_cb, run_tb],
show_progress="hidden",
)
+
realtime_cb.change(
fn=toggle_timer,
inputs=realtime_cb,
outputs=timer,
api_name="toggle_timer",
)
+ run_cb.input(
+ fn=lambda: True,
+ outputs=user_interacted_with_run_cb,
+ )
+ run_tb.input(
+ fn=filter_runs,
+ inputs=[project_dd, run_tb],
+ outputs=run_cb,
+ )
gr.api(
fn=log,
@@ -132,48 +257,68 @@ def update_x_lim(select_data: gr.SelectData):
@gr.render(
triggers=[
demo.load,
- run_dd.change,
+ run_cb.change,
timer.tick,
smoothing_cb.change,
x_lim.change,
],
- inputs=[project_dd, run_dd, smoothing_cb, metrics_subset, x_lim],
+ inputs=[project_dd, run_cb, smoothing_cb, metrics_subset, x_lim],
)
def update_dashboard(project, runs, smoothing, metrics_subset, x_lim_value):
dfs = []
+ original_runs = runs.copy()
+
for run in runs:
df = load_run_data(project, run, smoothing)
if df is not None:
- df["run"] = run
dfs.append(df)
+
if dfs:
master_df = pd.concat(dfs, ignore_index=True)
else:
master_df = pd.DataFrame()
+
+ if master_df.empty:
+ return
+
numeric_cols = master_df.select_dtypes(include="number").columns
- numeric_cols = [c for c in numeric_cols if c not in RESERVED_KEYS]
+ numeric_cols = [
+ c for c in numeric_cols if c not in RESERVED_KEYS and c != "step"
+ ]
if metrics_subset:
numeric_cols = [c for c in numeric_cols if c in metrics_subset]
+ numeric_cols = sort_metrics_by_prefix(list(numeric_cols))
+
+ color_map = get_color_mapping(original_runs, smoothing)
+
plots: list[gr.LinePlot] = []
- for col in range(len(numeric_cols) // 2):
+ for col in range((len(numeric_cols) + 1) // 2):
with gr.Row(key=f"row-{col}"):
for i in range(2):
- plot = gr.LinePlot(
- master_df,
- x="step",
- y=numeric_cols[2 * col + i],
- color="run" if "run" in master_df.columns else None,
- title=numeric_cols[2 * col + i],
- key=f"plot-{col}-{i}",
- preserved_by_key=None,
- x_lim=x_lim_value,
- y_lim=[
- min(master_df[numeric_cols[2 * col + i]]),
- max(master_df[numeric_cols[2 * col + i]]),
- ],
- show_fullscreen_button=True,
- )
- plots.append(plot)
+ metric_idx = 2 * col + i
+ if metric_idx < len(numeric_cols):
+ metric_name = numeric_cols[metric_idx]
+
+ metric_df = master_df.dropna(subset=[metric_name])
+
+ if not metric_df.empty:
+ plot = gr.LinePlot(
+ metric_df,
+ x="step",
+ y=metric_name,
+ color="run" if "run" in metric_df.columns else None,
+ color_map=color_map,
+ title=metric_name,
+ key=f"plot-{col}-{i}",
+ preserved_by_key=None,
+ x_lim=x_lim_value,
+ y_lim=[
+ metric_df[metric_name].min(),
+ metric_df[metric_name].max(),
+ ],
+ show_fullscreen_button=True,
+ )
+ plots.append(plot)
for plot in plots:
plot.select(update_x_lim, outputs=x_lim)
@@ -181,4 +326,4 @@ def update_dashboard(project, runs, smoothing, metrics_subset, x_lim_value):
if __name__ == "__main__":
- demo.launch(allowed_paths=[TRACKIO_LOGO_PATH])
+ demo.launch(allowed_paths=[TRACKIO_LOGO_PATH], show_api=False)
diff --git a/trackio/utils.py b/trackio/utils.py
index 1c0aae9a6..aca9d83f4 100644
--- a/trackio/utils.py
+++ b/trackio/utils.py
@@ -6,7 +6,7 @@
from huggingface_hub.constants import HF_HOME
-RESERVED_KEYS = ["project", "run", "timestamp"]
+RESERVED_KEYS = ["project", "run", "timestamp", "step"]
TRACKIO_DIR = os.path.join(HF_HOME, "trackio")
TRACKIO_LOGO_PATH = str(Path(__file__).parent.joinpath("trackio_logo.png"))
diff --git a/trackio/version.txt b/trackio/version.txt
index d169b2f2d..c5d54ec32 100644
--- a/trackio/version.txt
+++ b/trackio/version.txt
@@ -1 +1 @@
-0.0.8
+0.0.9