Skip to content

Commit 6857d40

Browse files
authored
Merge pull request #23 from gradio-app/ui-update
UI updates
2 parents aa533d0 + 41a85b9 commit 6857d40

3 files changed

Lines changed: 183 additions & 38 deletions

File tree

trackio/ui.py

Lines changed: 181 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,45 @@
1010
from sqlite_storage import SQLiteStorage
1111
from utils import RESERVED_KEYS, TRACKIO_LOGO_PATH
1212

13+
css = """
14+
#run-cb .wrap {
15+
gap: 2px;
16+
}
17+
#run-cb .wrap label {
18+
line-height: 1;
19+
padding: 6px;
20+
}
21+
"""
22+
23+
COLOR_PALETTE = [
24+
"#3B82F6",
25+
"#EF4444",
26+
"#10B981",
27+
"#F59E0B",
28+
"#8B5CF6",
29+
"#EC4899",
30+
"#06B6D4",
31+
"#84CC16",
32+
"#F97316",
33+
"#6366F1",
34+
]
35+
36+
37+
def get_color_mapping(runs: list[str], smoothing: bool) -> dict[str, str]:
38+
"""Generate color mapping for runs, with transparency for original data when smoothing is enabled."""
39+
color_map = {}
40+
41+
for i, run in enumerate(runs):
42+
base_color = COLOR_PALETTE[i % len(COLOR_PALETTE)]
43+
44+
if smoothing:
45+
color_map[f"{run}_smoothed"] = base_color
46+
color_map[f"{run}_original"] = base_color + "4D"
47+
else:
48+
color_map[run] = base_color
49+
50+
return color_map
51+
1352

1453
def get_projects(request: gr.Request):
1554
storage = SQLiteStorage("", "", {})
@@ -43,21 +82,52 @@ def load_run_data(project: str | None, run: str | None, smoothing: bool):
4382
if not metrics:
4483
return None
4584
df = pd.DataFrame(metrics)
85+
86+
if "step" not in df.columns:
87+
df["step"] = range(len(df))
88+
4689
if smoothing:
4790
numeric_cols = df.select_dtypes(include="number").columns
4891
numeric_cols = [c for c in numeric_cols if c not in RESERVED_KEYS]
49-
df[numeric_cols] = df[numeric_cols].ewm(alpha=0.1).mean()
50-
if "step" not in df.columns:
51-
df["step"] = range(len(df))
52-
return df
92+
93+
df_original = df.copy()
94+
df_original["run"] = f"{run}_original"
95+
df_original["data_type"] = "original"
96+
97+
df_smoothed = df.copy()
98+
df_smoothed[numeric_cols] = df_smoothed[numeric_cols].ewm(alpha=0.1).mean()
99+
df_smoothed["run"] = f"{run}_smoothed"
100+
df_smoothed["data_type"] = "smoothed"
101+
102+
combined_df = pd.concat([df_original, df_smoothed], ignore_index=True)
103+
return combined_df
104+
else:
105+
df["run"] = run
106+
df["data_type"] = "original"
107+
return df
53108

54109

55-
def update_runs(project):
110+
def update_runs(project, filter_text, user_interacted_with_runs=False):
56111
if project is None:
57112
runs = []
113+
num_runs = 0
58114
else:
59115
runs = get_runs(project)
60-
return gr.Dropdown(choices=runs, value=runs)
116+
num_runs = len(runs)
117+
if filter_text:
118+
runs = [r for r in runs if filter_text in r]
119+
if not user_interacted_with_runs:
120+
return gr.CheckboxGroup(
121+
choices=runs, value=[runs[0]] if runs else []
122+
), gr.Textbox(label=f"Runs ({num_runs})")
123+
else:
124+
return gr.CheckboxGroup(choices=runs), gr.Textbox(label=f"Runs ({num_runs})")
125+
126+
127+
def filter_runs(project, filter_text):
128+
runs = get_runs(project)
129+
runs = [r for r in runs if filter_text in r]
130+
return gr.CheckboxGroup(choices=runs, value=runs)
61131

62132

63133
def toggle_timer(cb_value):
@@ -72,52 +142,107 @@ def log(project: str, run: str, metrics: dict[str, Any]) -> None:
72142
storage.log(metrics)
73143

74144

145+
def sort_metrics_by_prefix(metrics: list[str]) -> list[str]:
146+
"""
147+
Sort metrics by grouping prefixes together.
148+
Metrics without prefixes come first, then grouped by prefix.
149+
150+
Example:
151+
Input: ["train/loss", "loss", "train/acc", "val/loss"]
152+
Output: ["loss", "train/acc", "train/loss", "val/loss"]
153+
"""
154+
no_prefix = []
155+
with_prefix = []
156+
157+
for metric in metrics:
158+
if "/" in metric:
159+
with_prefix.append(metric)
160+
else:
161+
no_prefix.append(metric)
162+
163+
no_prefix.sort()
164+
165+
prefix_groups = {}
166+
for metric in with_prefix:
167+
prefix = metric.split("/")[0]
168+
if prefix not in prefix_groups:
169+
prefix_groups[prefix] = []
170+
prefix_groups[prefix].append(metric)
171+
172+
sorted_with_prefix = []
173+
for prefix in sorted(prefix_groups.keys()):
174+
sorted_with_prefix.extend(sorted(prefix_groups[prefix]))
175+
176+
return no_prefix + sorted_with_prefix
177+
178+
75179
def configure(request: gr.Request):
76180
if metrics := request.query_params.get("metrics"):
77181
return metrics.split(",")
78182
else:
79183
return []
80184

81185

82-
with gr.Blocks(theme="citrus", title="Trackio Dashboard") as demo:
186+
with gr.Blocks(theme="citrus", title="Trackio Dashboard", css=css) as demo:
83187
with gr.Sidebar() as sidebar:
84188
gr.Markdown(
85189
f"<div style='display: flex; align-items: center; gap: 8px;'><img src='/gradio_api/file={TRACKIO_LOGO_PATH}' width='32' height='32'><span style='font-size: 2em; font-weight: bold;'>Trackio</span></div>"
86190
)
87-
project_dd = gr.Dropdown(label="Project", allow_custom_value=True)
191+
project_dd = gr.Dropdown(label="Project")
192+
run_tb = gr.Textbox(label="Runs", placeholder="Type to filter...")
193+
run_cb = gr.CheckboxGroup(
194+
label="Runs", choices=[], interactive=True, elem_id="run-cb"
195+
)
196+
with gr.Sidebar(position="right", open=False) as settings_sidebar:
88197
gr.Markdown("### ⚙️ Settings")
89198
realtime_cb = gr.Checkbox(label="Refresh realtime", value=True)
90199
smoothing_cb = gr.Checkbox(label="Smoothing", value=True)
91-
with gr.Row():
92-
run_dd = gr.Dropdown(label="Run", choices=[], multiselect=True)
93200

94201
timer = gr.Timer(value=1)
95202
metrics_subset = gr.State([])
203+
user_interacted_with_run_cb = gr.State(False)
96204

97205
gr.on(
98206
[demo.load],
99207
fn=configure,
100208
outputs=metrics_subset,
101209
)
102210
gr.on(
103-
[demo.load, timer.tick],
211+
[demo.load],
104212
fn=get_projects,
105213
outputs=project_dd,
106214
show_progress="hidden",
107215
)
108216
gr.on(
109-
[demo.load, project_dd.change, timer.tick],
217+
[timer.tick],
218+
fn=update_runs,
219+
inputs=[project_dd, run_tb, user_interacted_with_run_cb],
220+
outputs=[run_cb, run_tb],
221+
show_progress="hidden",
222+
)
223+
gr.on(
224+
[demo.load, project_dd.change],
110225
fn=update_runs,
111-
inputs=project_dd,
112-
outputs=run_dd,
226+
inputs=[project_dd, run_tb],
227+
outputs=[run_cb, run_tb],
113228
show_progress="hidden",
114229
)
230+
115231
realtime_cb.change(
116232
fn=toggle_timer,
117233
inputs=realtime_cb,
118234
outputs=timer,
119235
api_name="toggle_timer",
120236
)
237+
run_cb.input(
238+
fn=lambda: True,
239+
outputs=user_interacted_with_run_cb,
240+
)
241+
run_tb.input(
242+
fn=filter_runs,
243+
inputs=[project_dd, run_tb],
244+
outputs=run_cb,
245+
)
121246

122247
gr.api(
123248
fn=log,
@@ -132,53 +257,73 @@ def update_x_lim(select_data: gr.SelectData):
132257
@gr.render(
133258
triggers=[
134259
demo.load,
135-
run_dd.change,
260+
run_cb.change,
136261
timer.tick,
137262
smoothing_cb.change,
138263
x_lim.change,
139264
],
140-
inputs=[project_dd, run_dd, smoothing_cb, metrics_subset, x_lim],
265+
inputs=[project_dd, run_cb, smoothing_cb, metrics_subset, x_lim],
141266
)
142267
def update_dashboard(project, runs, smoothing, metrics_subset, x_lim_value):
143268
dfs = []
269+
original_runs = runs.copy()
270+
144271
for run in runs:
145272
df = load_run_data(project, run, smoothing)
146273
if df is not None:
147-
df["run"] = run
148274
dfs.append(df)
275+
149276
if dfs:
150277
master_df = pd.concat(dfs, ignore_index=True)
151278
else:
152279
master_df = pd.DataFrame()
280+
281+
if master_df.empty:
282+
return
283+
153284
numeric_cols = master_df.select_dtypes(include="number").columns
154-
numeric_cols = [c for c in numeric_cols if c not in RESERVED_KEYS]
285+
numeric_cols = [
286+
c for c in numeric_cols if c not in RESERVED_KEYS and c != "step"
287+
]
155288
if metrics_subset:
156289
numeric_cols = [c for c in numeric_cols if c in metrics_subset]
290+
numeric_cols = sort_metrics_by_prefix(list(numeric_cols))
291+
292+
color_map = get_color_mapping(original_runs, smoothing)
293+
157294
plots: list[gr.LinePlot] = []
158-
for col in range(len(numeric_cols) // 2):
295+
for col in range((len(numeric_cols) + 1) // 2):
159296
with gr.Row(key=f"row-{col}"):
160297
for i in range(2):
161-
plot = gr.LinePlot(
162-
master_df,
163-
x="step",
164-
y=numeric_cols[2 * col + i],
165-
color="run" if "run" in master_df.columns else None,
166-
title=numeric_cols[2 * col + i],
167-
key=f"plot-{col}-{i}",
168-
preserved_by_key=None,
169-
x_lim=x_lim_value,
170-
y_lim=[
171-
min(master_df[numeric_cols[2 * col + i]]),
172-
max(master_df[numeric_cols[2 * col + i]]),
173-
],
174-
show_fullscreen_button=True,
175-
)
176-
plots.append(plot)
298+
metric_idx = 2 * col + i
299+
if metric_idx < len(numeric_cols):
300+
metric_name = numeric_cols[metric_idx]
301+
302+
metric_df = master_df.dropna(subset=[metric_name])
303+
304+
if not metric_df.empty:
305+
plot = gr.LinePlot(
306+
metric_df,
307+
x="step",
308+
y=metric_name,
309+
color="run" if "run" in metric_df.columns else None,
310+
color_map=color_map,
311+
title=metric_name,
312+
key=f"plot-{col}-{i}",
313+
preserved_by_key=None,
314+
x_lim=x_lim_value,
315+
y_lim=[
316+
metric_df[metric_name].min(),
317+
metric_df[metric_name].max(),
318+
],
319+
show_fullscreen_button=True,
320+
)
321+
plots.append(plot)
177322

178323
for plot in plots:
179324
plot.select(update_x_lim, outputs=x_lim)
180325
plot.double_click(lambda: None, outputs=x_lim)
181326

182327

183328
if __name__ == "__main__":
184-
demo.launch(allowed_paths=[TRACKIO_LOGO_PATH])
329+
demo.launch(allowed_paths=[TRACKIO_LOGO_PATH], show_api=False)

trackio/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from huggingface_hub.constants import HF_HOME
88

9-
RESERVED_KEYS = ["project", "run", "timestamp"]
9+
RESERVED_KEYS = ["project", "run", "timestamp", "step"]
1010
TRACKIO_DIR = os.path.join(HF_HOME, "trackio")
1111

1212
TRACKIO_LOGO_PATH = str(Path(__file__).parent.joinpath("trackio_logo.png"))

trackio/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.0.8
1+
0.0.9

0 commit comments

Comments
 (0)