diff --git a/trackio/ui.py b/trackio/ui.py index ab87e99d6..07e48f240 100644 --- a/trackio/ui.py +++ b/trackio/ui.py @@ -10,6 +10,45 @@ from sqlite_storage import SQLiteStorage from utils import RESERVED_KEYS, TRACKIO_LOGO_PATH +css = """ +#run-cb .wrap { + gap: 2px; +} +#run-cb .wrap label { + line-height: 1; + padding: 6px; +} +""" + +COLOR_PALETTE = [ + "#3B82F6", + "#EF4444", + "#10B981", + "#F59E0B", + "#8B5CF6", + "#EC4899", + "#06B6D4", + "#84CC16", + "#F97316", + "#6366F1", +] + + +def get_color_mapping(runs: list[str], smoothing: bool) -> dict[str, str]: + """Generate color mapping for runs, with transparency for original data when smoothing is enabled.""" + color_map = {} + + for i, run in enumerate(runs): + base_color = COLOR_PALETTE[i % len(COLOR_PALETTE)] + + if smoothing: + color_map[f"{run}_smoothed"] = base_color + color_map[f"{run}_original"] = base_color + "4D" + else: + color_map[run] = base_color + + return color_map + def get_projects(request: gr.Request): storage = SQLiteStorage("", "", {}) @@ -43,21 +82,52 @@ def load_run_data(project: str | None, run: str | None, smoothing: bool): if not metrics: return None df = pd.DataFrame(metrics) + + if "step" not in df.columns: + df["step"] = range(len(df)) + if smoothing: numeric_cols = df.select_dtypes(include="number").columns numeric_cols = [c for c in numeric_cols if c not in RESERVED_KEYS] - df[numeric_cols] = df[numeric_cols].ewm(alpha=0.1).mean() - if "step" not in df.columns: - df["step"] = range(len(df)) - return df + + df_original = df.copy() + df_original["run"] = f"{run}_original" + df_original["data_type"] = "original" + + df_smoothed = df.copy() + df_smoothed[numeric_cols] = df_smoothed[numeric_cols].ewm(alpha=0.1).mean() + df_smoothed["run"] = f"{run}_smoothed" + df_smoothed["data_type"] = "smoothed" + + combined_df = pd.concat([df_original, df_smoothed], ignore_index=True) + return combined_df + else: + df["run"] = run + df["data_type"] = "original" + return df -def update_runs(project): +def update_runs(project, filter_text, user_interacted_with_runs=False): if project is None: runs = [] + num_runs = 0 else: runs = get_runs(project) - return gr.Dropdown(choices=runs, value=runs) + num_runs = len(runs) + if filter_text: + runs = [r for r in runs if filter_text in r] + if not user_interacted_with_runs: + return gr.CheckboxGroup( + choices=runs, value=[runs[0]] if runs else [] + ), gr.Textbox(label=f"Runs ({num_runs})") + else: + return gr.CheckboxGroup(choices=runs), gr.Textbox(label=f"Runs ({num_runs})") + + +def filter_runs(project, filter_text): + runs = get_runs(project) + runs = [r for r in runs if filter_text in r] + return gr.CheckboxGroup(choices=runs, value=runs) def toggle_timer(cb_value): @@ -72,6 +142,40 @@ def log(project: str, run: str, metrics: dict[str, Any]) -> None: storage.log(metrics) +def sort_metrics_by_prefix(metrics: list[str]) -> list[str]: + """ + Sort metrics by grouping prefixes together. + Metrics without prefixes come first, then grouped by prefix. + + Example: + Input: ["train/loss", "loss", "train/acc", "val/loss"] + Output: ["loss", "train/acc", "train/loss", "val/loss"] + """ + no_prefix = [] + with_prefix = [] + + for metric in metrics: + if "/" in metric: + with_prefix.append(metric) + else: + no_prefix.append(metric) + + no_prefix.sort() + + prefix_groups = {} + for metric in with_prefix: + prefix = metric.split("/")[0] + if prefix not in prefix_groups: + prefix_groups[prefix] = [] + prefix_groups[prefix].append(metric) + + sorted_with_prefix = [] + for prefix in sorted(prefix_groups.keys()): + sorted_with_prefix.extend(sorted(prefix_groups[prefix])) + + return no_prefix + sorted_with_prefix + + def configure(request: gr.Request): if metrics := request.query_params.get("metrics"): return metrics.split(",") @@ -79,20 +183,24 @@ def configure(request: gr.Request): return [] -with gr.Blocks(theme="citrus", title="Trackio Dashboard") as demo: +with gr.Blocks(theme="citrus", title="Trackio Dashboard", css=css) as demo: with gr.Sidebar() as sidebar: gr.Markdown( f"
Trackio
" ) - project_dd = gr.Dropdown(label="Project", allow_custom_value=True) + project_dd = gr.Dropdown(label="Project") + run_tb = gr.Textbox(label="Runs", placeholder="Type to filter...") + run_cb = gr.CheckboxGroup( + label="Runs", choices=[], interactive=True, elem_id="run-cb" + ) + with gr.Sidebar(position="right", open=False) as settings_sidebar: gr.Markdown("### ⚙️ Settings") realtime_cb = gr.Checkbox(label="Refresh realtime", value=True) smoothing_cb = gr.Checkbox(label="Smoothing", value=True) - with gr.Row(): - run_dd = gr.Dropdown(label="Run", choices=[], multiselect=True) timer = gr.Timer(value=1) metrics_subset = gr.State([]) + user_interacted_with_run_cb = gr.State(False) gr.on( [demo.load], @@ -100,24 +208,41 @@ def configure(request: gr.Request): outputs=metrics_subset, ) gr.on( - [demo.load, timer.tick], + [demo.load], fn=get_projects, outputs=project_dd, show_progress="hidden", ) gr.on( - [demo.load, project_dd.change, timer.tick], + [timer.tick], + fn=update_runs, + inputs=[project_dd, run_tb, user_interacted_with_run_cb], + outputs=[run_cb, run_tb], + show_progress="hidden", + ) + gr.on( + [demo.load, project_dd.change], fn=update_runs, - inputs=project_dd, - outputs=run_dd, + inputs=[project_dd, run_tb], + outputs=[run_cb, run_tb], show_progress="hidden", ) + realtime_cb.change( fn=toggle_timer, inputs=realtime_cb, outputs=timer, api_name="toggle_timer", ) + run_cb.input( + fn=lambda: True, + outputs=user_interacted_with_run_cb, + ) + run_tb.input( + fn=filter_runs, + inputs=[project_dd, run_tb], + outputs=run_cb, + ) gr.api( fn=log, @@ -132,48 +257,68 @@ def update_x_lim(select_data: gr.SelectData): @gr.render( triggers=[ demo.load, - run_dd.change, + run_cb.change, timer.tick, smoothing_cb.change, x_lim.change, ], - inputs=[project_dd, run_dd, smoothing_cb, metrics_subset, x_lim], + inputs=[project_dd, run_cb, smoothing_cb, metrics_subset, x_lim], ) def update_dashboard(project, runs, smoothing, metrics_subset, x_lim_value): dfs = [] + original_runs = runs.copy() + for run in runs: df = load_run_data(project, run, smoothing) if df is not None: - df["run"] = run dfs.append(df) + if dfs: master_df = pd.concat(dfs, ignore_index=True) else: master_df = pd.DataFrame() + + if master_df.empty: + return + numeric_cols = master_df.select_dtypes(include="number").columns - numeric_cols = [c for c in numeric_cols if c not in RESERVED_KEYS] + numeric_cols = [ + c for c in numeric_cols if c not in RESERVED_KEYS and c != "step" + ] if metrics_subset: numeric_cols = [c for c in numeric_cols if c in metrics_subset] + numeric_cols = sort_metrics_by_prefix(list(numeric_cols)) + + color_map = get_color_mapping(original_runs, smoothing) + plots: list[gr.LinePlot] = [] - for col in range(len(numeric_cols) // 2): + for col in range((len(numeric_cols) + 1) // 2): with gr.Row(key=f"row-{col}"): for i in range(2): - plot = gr.LinePlot( - master_df, - x="step", - y=numeric_cols[2 * col + i], - color="run" if "run" in master_df.columns else None, - title=numeric_cols[2 * col + i], - key=f"plot-{col}-{i}", - preserved_by_key=None, - x_lim=x_lim_value, - y_lim=[ - min(master_df[numeric_cols[2 * col + i]]), - max(master_df[numeric_cols[2 * col + i]]), - ], - show_fullscreen_button=True, - ) - plots.append(plot) + metric_idx = 2 * col + i + if metric_idx < len(numeric_cols): + metric_name = numeric_cols[metric_idx] + + metric_df = master_df.dropna(subset=[metric_name]) + + if not metric_df.empty: + plot = gr.LinePlot( + metric_df, + x="step", + y=metric_name, + color="run" if "run" in metric_df.columns else None, + color_map=color_map, + title=metric_name, + key=f"plot-{col}-{i}", + preserved_by_key=None, + x_lim=x_lim_value, + y_lim=[ + metric_df[metric_name].min(), + metric_df[metric_name].max(), + ], + show_fullscreen_button=True, + ) + plots.append(plot) for plot in plots: plot.select(update_x_lim, outputs=x_lim) @@ -181,4 +326,4 @@ def update_dashboard(project, runs, smoothing, metrics_subset, x_lim_value): if __name__ == "__main__": - demo.launch(allowed_paths=[TRACKIO_LOGO_PATH]) + demo.launch(allowed_paths=[TRACKIO_LOGO_PATH], show_api=False) diff --git a/trackio/utils.py b/trackio/utils.py index 1c0aae9a6..aca9d83f4 100644 --- a/trackio/utils.py +++ b/trackio/utils.py @@ -6,7 +6,7 @@ from huggingface_hub.constants import HF_HOME -RESERVED_KEYS = ["project", "run", "timestamp"] +RESERVED_KEYS = ["project", "run", "timestamp", "step"] TRACKIO_DIR = os.path.join(HF_HOME, "trackio") TRACKIO_LOGO_PATH = str(Path(__file__).parent.joinpath("trackio_logo.png")) diff --git a/trackio/version.txt b/trackio/version.txt index d169b2f2d..c5d54ec32 100644 --- a/trackio/version.txt +++ b/trackio/version.txt @@ -1 +1 @@ -0.0.8 +0.0.9