1010 from sqlite_storage import SQLiteStorage
1111 from utils import RESERVED_KEYS , TRACKIO_LOGO_PATH
1212
13+ css = """
14+ #run-cb .wrap {
15+ gap: 2px;
16+ }
17+ #run-cb .wrap label {
18+ line-height: 1;
19+ padding: 6px;
20+ }
21+ """
22+
23+ COLOR_PALETTE = [
24+ "#3B82F6" ,
25+ "#EF4444" ,
26+ "#10B981" ,
27+ "#F59E0B" ,
28+ "#8B5CF6" ,
29+ "#EC4899" ,
30+ "#06B6D4" ,
31+ "#84CC16" ,
32+ "#F97316" ,
33+ "#6366F1" ,
34+ ]
35+
36+
37+ def get_color_mapping (runs : list [str ], smoothing : bool ) -> dict [str , str ]:
38+ """Generate color mapping for runs, with transparency for original data when smoothing is enabled."""
39+ color_map = {}
40+
41+ for i , run in enumerate (runs ):
42+ base_color = COLOR_PALETTE [i % len (COLOR_PALETTE )]
43+
44+ if smoothing :
45+ color_map [f"{ run } _smoothed" ] = base_color
46+ color_map [f"{ run } _original" ] = base_color + "4D"
47+ else :
48+ color_map [run ] = base_color
49+
50+ return color_map
51+
1352
1453def get_projects (request : gr .Request ):
1554 storage = SQLiteStorage ("" , "" , {})
@@ -43,21 +82,52 @@ def load_run_data(project: str | None, run: str | None, smoothing: bool):
4382 if not metrics :
4483 return None
4584 df = pd .DataFrame (metrics )
85+
86+ if "step" not in df .columns :
87+ df ["step" ] = range (len (df ))
88+
4689 if smoothing :
4790 numeric_cols = df .select_dtypes (include = "number" ).columns
4891 numeric_cols = [c for c in numeric_cols if c not in RESERVED_KEYS ]
49- df [numeric_cols ] = df [numeric_cols ].ewm (alpha = 0.1 ).mean ()
50- if "step" not in df .columns :
51- df ["step" ] = range (len (df ))
52- return df
92+
93+ df_original = df .copy ()
94+ df_original ["run" ] = f"{ run } _original"
95+ df_original ["data_type" ] = "original"
96+
97+ df_smoothed = df .copy ()
98+ df_smoothed [numeric_cols ] = df_smoothed [numeric_cols ].ewm (alpha = 0.1 ).mean ()
99+ df_smoothed ["run" ] = f"{ run } _smoothed"
100+ df_smoothed ["data_type" ] = "smoothed"
101+
102+ combined_df = pd .concat ([df_original , df_smoothed ], ignore_index = True )
103+ return combined_df
104+ else :
105+ df ["run" ] = run
106+ df ["data_type" ] = "original"
107+ return df
53108
54109
55- def update_runs (project ):
110+ def update_runs (project , filter_text , user_interacted_with_runs = False ):
56111 if project is None :
57112 runs = []
113+ num_runs = 0
58114 else :
59115 runs = get_runs (project )
60- return gr .Dropdown (choices = runs , value = runs )
116+ num_runs = len (runs )
117+ if filter_text :
118+ runs = [r for r in runs if filter_text in r ]
119+ if not user_interacted_with_runs :
120+ return gr .CheckboxGroup (
121+ choices = runs , value = [runs [0 ]] if runs else []
122+ ), gr .Textbox (label = f"Runs ({ num_runs } )" )
123+ else :
124+ return gr .CheckboxGroup (choices = runs ), gr .Textbox (label = f"Runs ({ num_runs } )" )
125+
126+
127+ def filter_runs (project , filter_text ):
128+ runs = get_runs (project )
129+ runs = [r for r in runs if filter_text in r ]
130+ return gr .CheckboxGroup (choices = runs , value = runs )
61131
62132
63133def toggle_timer (cb_value ):
@@ -72,52 +142,107 @@ def log(project: str, run: str, metrics: dict[str, Any]) -> None:
72142 storage .log (metrics )
73143
74144
145+ def sort_metrics_by_prefix (metrics : list [str ]) -> list [str ]:
146+ """
147+ Sort metrics by grouping prefixes together.
148+ Metrics without prefixes come first, then grouped by prefix.
149+
150+ Example:
151+ Input: ["train/loss", "loss", "train/acc", "val/loss"]
152+ Output: ["loss", "train/acc", "train/loss", "val/loss"]
153+ """
154+ no_prefix = []
155+ with_prefix = []
156+
157+ for metric in metrics :
158+ if "/" in metric :
159+ with_prefix .append (metric )
160+ else :
161+ no_prefix .append (metric )
162+
163+ no_prefix .sort ()
164+
165+ prefix_groups = {}
166+ for metric in with_prefix :
167+ prefix = metric .split ("/" )[0 ]
168+ if prefix not in prefix_groups :
169+ prefix_groups [prefix ] = []
170+ prefix_groups [prefix ].append (metric )
171+
172+ sorted_with_prefix = []
173+ for prefix in sorted (prefix_groups .keys ()):
174+ sorted_with_prefix .extend (sorted (prefix_groups [prefix ]))
175+
176+ return no_prefix + sorted_with_prefix
177+
178+
75179def configure (request : gr .Request ):
76180 if metrics := request .query_params .get ("metrics" ):
77181 return metrics .split ("," )
78182 else :
79183 return []
80184
81185
82- with gr .Blocks (theme = "citrus" , title = "Trackio Dashboard" ) as demo :
186+ with gr .Blocks (theme = "citrus" , title = "Trackio Dashboard" , css = css ) as demo :
83187 with gr .Sidebar () as sidebar :
84188 gr .Markdown (
85189 f"<div style='display: flex; align-items: center; gap: 8px;'><img src='/gradio_api/file={ TRACKIO_LOGO_PATH } ' width='32' height='32'><span style='font-size: 2em; font-weight: bold;'>Trackio</span></div>"
86190 )
87- project_dd = gr .Dropdown (label = "Project" , allow_custom_value = True )
191+ project_dd = gr .Dropdown (label = "Project" )
192+ run_tb = gr .Textbox (label = "Runs" , placeholder = "Type to filter..." )
193+ run_cb = gr .CheckboxGroup (
194+ label = "Runs" , choices = [], interactive = True , elem_id = "run-cb"
195+ )
196+ with gr .Sidebar (position = "right" , open = False ) as settings_sidebar :
88197 gr .Markdown ("### ⚙️ Settings" )
89198 realtime_cb = gr .Checkbox (label = "Refresh realtime" , value = True )
90199 smoothing_cb = gr .Checkbox (label = "Smoothing" , value = True )
91- with gr .Row ():
92- run_dd = gr .Dropdown (label = "Run" , choices = [], multiselect = True )
93200
94201 timer = gr .Timer (value = 1 )
95202 metrics_subset = gr .State ([])
203+ user_interacted_with_run_cb = gr .State (False )
96204
97205 gr .on (
98206 [demo .load ],
99207 fn = configure ,
100208 outputs = metrics_subset ,
101209 )
102210 gr .on (
103- [demo .load , timer . tick ],
211+ [demo .load ],
104212 fn = get_projects ,
105213 outputs = project_dd ,
106214 show_progress = "hidden" ,
107215 )
108216 gr .on (
109- [demo .load , project_dd .change , timer .tick ],
217+ [timer .tick ],
218+ fn = update_runs ,
219+ inputs = [project_dd , run_tb , user_interacted_with_run_cb ],
220+ outputs = [run_cb , run_tb ],
221+ show_progress = "hidden" ,
222+ )
223+ gr .on (
224+ [demo .load , project_dd .change ],
110225 fn = update_runs ,
111- inputs = project_dd ,
112- outputs = run_dd ,
226+ inputs = [ project_dd , run_tb ] ,
227+ outputs = [ run_cb , run_tb ] ,
113228 show_progress = "hidden" ,
114229 )
230+
115231 realtime_cb .change (
116232 fn = toggle_timer ,
117233 inputs = realtime_cb ,
118234 outputs = timer ,
119235 api_name = "toggle_timer" ,
120236 )
237+ run_cb .input (
238+ fn = lambda : True ,
239+ outputs = user_interacted_with_run_cb ,
240+ )
241+ run_tb .input (
242+ fn = filter_runs ,
243+ inputs = [project_dd , run_tb ],
244+ outputs = run_cb ,
245+ )
121246
122247 gr .api (
123248 fn = log ,
@@ -132,53 +257,73 @@ def update_x_lim(select_data: gr.SelectData):
132257 @gr .render (
133258 triggers = [
134259 demo .load ,
135- run_dd .change ,
260+ run_cb .change ,
136261 timer .tick ,
137262 smoothing_cb .change ,
138263 x_lim .change ,
139264 ],
140- inputs = [project_dd , run_dd , smoothing_cb , metrics_subset , x_lim ],
265+ inputs = [project_dd , run_cb , smoothing_cb , metrics_subset , x_lim ],
141266 )
142267 def update_dashboard (project , runs , smoothing , metrics_subset , x_lim_value ):
143268 dfs = []
269+ original_runs = runs .copy ()
270+
144271 for run in runs :
145272 df = load_run_data (project , run , smoothing )
146273 if df is not None :
147- df ["run" ] = run
148274 dfs .append (df )
275+
149276 if dfs :
150277 master_df = pd .concat (dfs , ignore_index = True )
151278 else :
152279 master_df = pd .DataFrame ()
280+
281+ if master_df .empty :
282+ return
283+
153284 numeric_cols = master_df .select_dtypes (include = "number" ).columns
154- numeric_cols = [c for c in numeric_cols if c not in RESERVED_KEYS ]
285+ numeric_cols = [
286+ c for c in numeric_cols if c not in RESERVED_KEYS and c != "step"
287+ ]
155288 if metrics_subset :
156289 numeric_cols = [c for c in numeric_cols if c in metrics_subset ]
290+ numeric_cols = sort_metrics_by_prefix (list (numeric_cols ))
291+
292+ color_map = get_color_mapping (original_runs , smoothing )
293+
157294 plots : list [gr .LinePlot ] = []
158- for col in range (len (numeric_cols ) // 2 ):
295+ for col in range (( len (numeric_cols ) + 1 ) // 2 ):
159296 with gr .Row (key = f"row-{ col } " ):
160297 for i in range (2 ):
161- plot = gr .LinePlot (
162- master_df ,
163- x = "step" ,
164- y = numeric_cols [2 * col + i ],
165- color = "run" if "run" in master_df .columns else None ,
166- title = numeric_cols [2 * col + i ],
167- key = f"plot-{ col } -{ i } " ,
168- preserved_by_key = None ,
169- x_lim = x_lim_value ,
170- y_lim = [
171- min (master_df [numeric_cols [2 * col + i ]]),
172- max (master_df [numeric_cols [2 * col + i ]]),
173- ],
174- show_fullscreen_button = True ,
175- )
176- plots .append (plot )
298+ metric_idx = 2 * col + i
299+ if metric_idx < len (numeric_cols ):
300+ metric_name = numeric_cols [metric_idx ]
301+
302+ metric_df = master_df .dropna (subset = [metric_name ])
303+
304+ if not metric_df .empty :
305+ plot = gr .LinePlot (
306+ metric_df ,
307+ x = "step" ,
308+ y = metric_name ,
309+ color = "run" if "run" in metric_df .columns else None ,
310+ color_map = color_map ,
311+ title = metric_name ,
312+ key = f"plot-{ col } -{ i } " ,
313+ preserved_by_key = None ,
314+ x_lim = x_lim_value ,
315+ y_lim = [
316+ metric_df [metric_name ].min (),
317+ metric_df [metric_name ].max (),
318+ ],
319+ show_fullscreen_button = True ,
320+ )
321+ plots .append (plot )
177322
178323 for plot in plots :
179324 plot .select (update_x_lim , outputs = x_lim )
180325 plot .double_click (lambda : None , outputs = x_lim )
181326
182327
183328if __name__ == "__main__" :
184- demo .launch (allowed_paths = [TRACKIO_LOGO_PATH ])
329+ demo .launch (allowed_paths = [TRACKIO_LOGO_PATH ], show_api = False )
0 commit comments