|
7 | 7 | from trackio.utils import RESERVED_KEYS |
8 | 8 |
|
9 | 9 |
|
10 | | -def get_projects(): |
| 10 | +def get_projects(request: gr.Request): |
11 | 11 | storage = SQLiteStorage("", "", {}) |
12 | 12 | projects = storage.get_projects() |
| 13 | + if project := request.query_params.get("project"): |
| 14 | + pass |
| 15 | + else: |
| 16 | + project = projects[0] if projects else None |
13 | 17 | return gr.Dropdown( |
14 | | - label="Project", choices=projects, value=projects[0] if projects else None |
| 18 | + label="Project", |
| 19 | + choices=projects, |
| 20 | + value=project, |
| 21 | + allow_custom_value=True, |
15 | 22 | ) |
16 | 23 |
|
17 | 24 |
|
@@ -55,70 +62,71 @@ def log(project: str, run: str, metrics: dict[str, Any]) -> None: |
55 | 62 | storage.log(metrics) |
56 | 63 |
|
57 | 64 |
|
58 | | -def launch_gradio() -> str: |
59 | | - with gr.Blocks(theme="citrus") as demo: |
60 | | - with gr.Sidebar(): |
61 | | - gr.Markdown("# 🎯 Trackio Dashboard") |
62 | | - project_dd = gr.Dropdown(label="Project") |
63 | | - gr.Markdown("### ⚙️ Settings") |
64 | | - realtime_cb = gr.Checkbox(label="Refresh realtime", value=True) |
65 | | - with gr.Row(): |
66 | | - run_dd = gr.Dropdown(label="Run", choices=[], multiselect=True) |
67 | | - |
68 | | - timer = gr.Timer(value=1) |
69 | | - |
70 | | - gr.on( |
71 | | - [demo.load, timer.tick], |
72 | | - fn=get_projects, |
73 | | - outputs=project_dd, |
74 | | - show_progress="hidden", |
75 | | - ) |
76 | | - gr.on( |
77 | | - [demo.load, project_dd.change, timer.tick], |
78 | | - fn=update_runs, |
79 | | - inputs=project_dd, |
80 | | - outputs=run_dd, |
81 | | - show_progress="hidden", |
82 | | - ) |
83 | | - realtime_cb.change( |
84 | | - fn=toggle_timer, |
85 | | - inputs=realtime_cb, |
86 | | - outputs=timer, |
87 | | - api_name="toggle_timer", |
88 | | - ) |
89 | | - |
90 | | - gr.api( |
91 | | - fn=log, |
92 | | - api_name="log", |
93 | | - ) |
94 | | - |
95 | | - @gr.render( |
96 | | - triggers=[run_dd.change, timer.tick], |
97 | | - inputs=[project_dd, run_dd], |
98 | | - ) |
99 | | - def update_dashboard(project, runs): |
100 | | - dfs = [] |
101 | | - for run in runs: |
102 | | - df = load_run_data(project, run) |
103 | | - if df is not None: |
104 | | - df["run"] = run |
105 | | - dfs.append(df) |
106 | | - if dfs: |
107 | | - master_df = pd.concat(dfs, ignore_index=True) |
108 | | - else: |
109 | | - master_df = pd.DataFrame() |
110 | | - numeric_cols = master_df.select_dtypes(include="number").columns |
111 | | - numeric_cols = [c for c in numeric_cols if c not in RESERVED_KEYS] |
112 | | - for col in numeric_cols: |
113 | | - gr.LinePlot( |
114 | | - master_df, |
115 | | - x="step", |
116 | | - y=col, |
117 | | - color="run" if "run" in master_df.columns else None, |
118 | | - title=col, |
119 | | - ) |
120 | | - |
121 | | - _, url, _ = demo.launch(show_api=False, inline=False, quiet=True) |
| 65 | +with gr.Blocks(theme="citrus") as demo: |
| 66 | + with gr.Sidebar(): |
| 67 | + gr.Markdown("# 🎯 Trackio Dashboard") |
| 68 | + project_dd = gr.Dropdown(label="Project") |
| 69 | + gr.Markdown("### ⚙️ Settings") |
| 70 | + realtime_cb = gr.Checkbox(label="Refresh realtime", value=True) |
| 71 | + with gr.Row(): |
| 72 | + run_dd = gr.Dropdown(label="Run", choices=[], multiselect=True) |
| 73 | + |
| 74 | + timer = gr.Timer(value=1) |
| 75 | + |
| 76 | + gr.on( |
| 77 | + [demo.load, timer.tick], |
| 78 | + fn=get_projects, |
| 79 | + outputs=project_dd, |
| 80 | + show_progress="hidden", |
| 81 | + ) |
| 82 | + gr.on( |
| 83 | + [demo.load, project_dd.change, timer.tick], |
| 84 | + fn=update_runs, |
| 85 | + inputs=project_dd, |
| 86 | + outputs=run_dd, |
| 87 | + show_progress="hidden", |
| 88 | + ) |
| 89 | + realtime_cb.change( |
| 90 | + fn=toggle_timer, |
| 91 | + inputs=realtime_cb, |
| 92 | + outputs=timer, |
| 93 | + api_name="toggle_timer", |
| 94 | + ) |
| 95 | + |
| 96 | + gr.api( |
| 97 | + fn=log, |
| 98 | + api_name="log", |
| 99 | + ) |
| 100 | + |
| 101 | + @gr.render( |
| 102 | + triggers=[run_dd.change, timer.tick], |
| 103 | + inputs=[project_dd, run_dd], |
| 104 | + ) |
| 105 | + def update_dashboard(project, runs): |
| 106 | + dfs = [] |
| 107 | + for run in runs: |
| 108 | + df = load_run_data(project, run) |
| 109 | + if df is not None: |
| 110 | + df["run"] = run |
| 111 | + dfs.append(df) |
| 112 | + if dfs: |
| 113 | + master_df = pd.concat(dfs, ignore_index=True) |
| 114 | + else: |
| 115 | + master_df = pd.DataFrame() |
| 116 | + numeric_cols = master_df.select_dtypes(include="number").columns |
| 117 | + numeric_cols = [c for c in numeric_cols if c not in RESERVED_KEYS] |
| 118 | + for col in numeric_cols: |
| 119 | + gr.LinePlot( |
| 120 | + master_df, |
| 121 | + x="step", |
| 122 | + y=col, |
| 123 | + color="run" if "run" in master_df.columns else None, |
| 124 | + title=col, |
| 125 | + ) |
| 126 | + |
| 127 | + |
| 128 | +def launch_gradio(**kwargs) -> str: |
| 129 | + _, url, _ = demo.launch(**kwargs) |
122 | 130 | return url |
123 | 131 |
|
124 | 132 |
|
|
0 commit comments