Skip to content

Commit 3ab8978

Browse files
authored
Merge pull request #12 from gradio-app/ema
Add a checkbox for EMA-based smoothing
2 parents a46ab50 + e8c1138 commit 3ab8978

1 file changed

Lines changed: 11 additions & 6 deletions

File tree

trackio/ui.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,18 @@ def get_runs(project):
3131
return storage.get_runs(project)
3232

3333

34-
def load_run_data(project, run):
34+
def load_run_data(project: str | None, run: str | None, smoothing: bool):
3535
if not project or not run:
3636
return None
3737
storage = SQLiteStorage("", "", {})
3838
metrics = storage.get_metrics(project, run)
3939
if not metrics:
4040
return None
4141
df = pd.DataFrame(metrics)
42+
if smoothing:
43+
numeric_cols = df.select_dtypes(include="number").columns
44+
numeric_cols = [c for c in numeric_cols if c not in RESERVED_KEYS]
45+
df[numeric_cols] = df[numeric_cols].ewm(alpha=0.1).mean()
4246
if "step" not in df.columns:
4347
df["step"] = range(len(df))
4448
return df
@@ -64,14 +68,15 @@ def log(project: str, run: str, metrics: dict[str, Any]) -> None:
6468
storage.log(metrics)
6569

6670

67-
with gr.Blocks(theme="citrus") as demo:
71+
with gr.Blocks(theme="citrus", title="Trackio Dashboard") as demo:
6872
with gr.Sidebar() as sidebar:
6973
gr.Markdown(
7074
f"<div style='display: flex; align-items: center; gap: 8px;'><img src='/gradio_api/file={TRACKIO_LOGO_PATH}' width='32' height='32'><span style='font-size: 2em; font-weight: bold;'>Trackio</span></div>"
7175
)
7276
project_dd = gr.Dropdown(label="Project", allow_custom_value=True)
7377
gr.Markdown("### ⚙️ Settings")
7478
realtime_cb = gr.Checkbox(label="Refresh realtime", value=True)
79+
smoothing_cb = gr.Checkbox(label="Smoothing", value=True)
7580
with gr.Row():
7681
run_dd = gr.Dropdown(label="Run", choices=[], multiselect=True)
7782

@@ -103,13 +108,13 @@ def log(project: str, run: str, metrics: dict[str, Any]) -> None:
103108
)
104109

105110
@gr.render(
106-
triggers=[run_dd.change, timer.tick],
107-
inputs=[project_dd, run_dd],
111+
triggers=[run_dd.change, timer.tick, smoothing_cb.change],
112+
inputs=[project_dd, run_dd, smoothing_cb],
108113
)
109-
def update_dashboard(project, runs):
114+
def update_dashboard(project, runs, smoothing):
110115
dfs = []
111116
for run in runs:
112-
df = load_run_data(project, run)
117+
df = load_run_data(project, run, smoothing)
113118
if df is not None:
114119
df["run"] = run
115120
dfs.append(df)

0 commit comments

Comments
 (0)