Skip to content

Commit f88c131

Browse files
committed
feat: combine multi-GPU metrics onto single plots + rank prefix for distributed training
1 parent c7aa38d commit f88c131

2 files changed

Lines changed: 169 additions & 124 deletions

File tree

trackio/gpu.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@
1313
_energy_baseline: dict[int, float] = {}
1414

1515

16+
def _get_rank_prefix() -> str:
    """Return the metric-name prefix for this process's distributed rank.

    Reads the ``RANK`` environment variable first and falls back to
    ``SLURM_PROCID``. A variable that is unset, empty, or whitespace-only is
    treated as absent, so the result is never a bare ``"rank/"``.

    Returns:
        ``"rank<value>/"`` for the first variable holding a value, otherwise
        an empty string.
    """
    for var_name in ("RANK", "SLURM_PROCID"):
        value = os.environ.get(var_name, "").strip()
        if value:
            return f"rank{value}/"
    return ""
22+
23+
1624
def _ensure_pynvml():
1725
global PYNVML_AVAILABLE, pynvml
1826
if PYNVML_AVAILABLE:
@@ -142,9 +150,10 @@ def collect_gpu_metrics(device: int | None = None) -> dict:
142150
total_power = 0.0
143151
max_temp = 0.0
144152
valid_util_count = 0
153+
rank_prefix = _get_rank_prefix()
145154

146155
for logical_idx, physical_idx in gpu_indices:
147-
prefix = f"gpu/{logical_idx}"
156+
prefix = f"{rank_prefix}gpu/{logical_idx}"
148157
try:
149158
handle = pynvml.nvmlDeviceGetHandleByIndex(physical_idx)
150159

@@ -286,13 +295,13 @@ def collect_gpu_metrics(device: int | None = None) -> dict:
286295
continue
287296

288297
if valid_util_count > 0:
289-
metrics["gpu/mean_utilization"] = total_util / valid_util_count
298+
metrics[f"{rank_prefix}gpu/mean_utilization"] = total_util / valid_util_count
290299
if total_mem_used_gib > 0:
291-
metrics["gpu/total_allocated_memory"] = total_mem_used_gib
300+
metrics[f"{rank_prefix}gpu/total_allocated_memory"] = total_mem_used_gib
292301
if total_power > 0:
293-
metrics["gpu/total_power"] = total_power
302+
metrics[f"{rank_prefix}gpu/total_power"] = total_power
294303
if max_temp > 0:
295-
metrics["gpu/max_temp"] = max_temp
304+
metrics[f"{rank_prefix}gpu/max_temp"] = max_temp
296305

297306
return metrics
298307

trackio/ui/system_page.py

Lines changed: 155 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""The System Metrics page for the Trackio UI (GPU metrics, etc.)."""
22

3+
import re
4+
35
import gradio as gr
46
import pandas as pd
57

@@ -9,6 +11,58 @@
911
from trackio.ui.components.colored_checkbox import ColoredCheckboxGroup
1012
from trackio.ui.helpers.run_selection import RunSelection
1113

14+
# Matches per-GPU metric columns: an optional "rank<N>/" prefix, then
# "gpu/<gpu id>/<metric name>". Groups: (prefix, gpu id, metric name).
GPU_METRIC_PATTERN = re.compile(r"^((?:rank\d+/)?gpu)/(\d+)/(.+)$")


def melt_gpu_metrics(
    df: pd.DataFrame, metric_cols: list[str]
) -> tuple[pd.DataFrame, dict[str, list[str]]]:
    """
    Reshape per-GPU metric columns into long form so all GPUs share one plot.

    Columns in ``metric_cols`` that match ``GPU_METRIC_PATTERN`` are grouped
    under a combined key ``<prefix>/<metric>`` (e.g. ``gpu/0/utilization`` and
    ``gpu/1/utilization`` both map to ``gpu/utilization``). For each such
    column the non-NaN rows are copied, the value column is renamed to the
    combined key, and a ``gpu_id`` column holding the GPU index (as a string)
    is added. Columns that do not match the pattern are ignored.

    Args:
        df: Source dataframe. Must contain a ``time`` column; a ``run`` column
            is carried over when present.
        metric_cols: Candidate column names to inspect.

    Returns:
        A ``(dataframe, groups)`` tuple. When at least one GPU column has
        data, ``dataframe`` is the concatenated long-form frame and ``groups``
        maps each combined key to the original columns that contributed rows
        to it. Otherwise the original ``df`` is returned unchanged together
        with an empty dict.
    """
    # First pass: bucket matching columns by combined key, remembering the
    # GPU id parsed from each column name.
    grouped: dict[str, list[tuple[str, str]]] = {}
    for col in metric_cols:
        match = GPU_METRIC_PATTERN.match(col)
        if match is None:
            continue
        prefix, gpu_id, metric_type = match.groups()
        grouped.setdefault(f"{prefix}/{metric_type}", []).append((col, gpu_id))

    if not grouped:
        return df, {}

    base_cols = ["time", "run"] if "run" in df.columns else ["time"]
    melted_frames = []
    # Only keys/columns that actually produced rows are reported, so callers
    # counting the groups do not count metrics that have nothing to plot.
    melted_groups: dict[str, list[str]] = {}

    for metric_key, cols_and_ids in grouped.items():
        for original_col, gpu_id in cols_and_ids:
            if original_col not in df.columns:
                continue
            subset = df[base_cols + [original_col]].dropna(subset=[original_col])
            if subset.empty:
                continue
            melted_frames.append(
                subset.rename(columns={original_col: metric_key}).assign(
                    gpu_id=gpu_id
                )
            )
            melted_groups.setdefault(metric_key, []).append(original_col)

    if not melted_frames:
        return df, {}

    return pd.concat(melted_frames, ignore_index=True), melted_groups
65+
1266

1367
def get_runs(project) -> list[str]:
1468
if not project:
@@ -172,152 +226,134 @@ def update_system_dashboard(
172226
numeric_cols = master_df.select_dtypes(include="number").columns
173227
numeric_cols = [c for c in numeric_cols if c not in ["time", "timestamp"]]
174228

229+
melted_df, gpu_metric_groups = melt_gpu_metrics(master_df, list(numeric_cols))
230+
non_gpu_cols = [c for c in numeric_cols if not GPU_METRIC_PATTERN.match(c)]
231+
175232
if smoothing_granularity > 0:
176233
window_size = max(3, min(smoothing_granularity, len(master_df)))
177-
for col in numeric_cols:
178-
master_df[col] = master_df.groupby("run")[col].transform(
179-
lambda x: x.rolling(
180-
window=window_size, center=True, min_periods=1
181-
).mean()
182-
)
234+
for col in non_gpu_cols:
235+
if col in master_df.columns:
236+
master_df[col] = master_df.groupby("run")[col].transform(
237+
lambda x: x.rolling(
238+
window=window_size, center=True, min_periods=1
239+
).mean()
240+
)
183241

184-
ordered_groups, nested_metric_groups = utils.order_metrics_by_plot_preference(
185-
list(numeric_cols)
186-
)
187242
all_runs = selection.choices if selection else original_runs
188243
color_map = utils.get_color_mapping(all_runs, False)
244+
gpu_ids = sorted(melted_df["gpu_id"].unique()) if "gpu_id" in melted_df.columns else []
245+
gpu_color_map = {gid: utils.get_color_palette()[i % len(utils.get_color_palette())] for i, gid in enumerate(gpu_ids)}
189246

190247
metric_idx = 0
191-
for group_name in ordered_groups:
192-
group_data = nested_metric_groups[group_name]
193-
194-
total_plot_count = sum(
195-
1
196-
for m in group_data["direct_metrics"]
197-
if not master_df.dropna(subset=[m]).empty
198-
) + sum(
199-
sum(1 for m in metrics if not master_df.dropna(subset=[m]).empty)
200-
for metrics in group_data["subgroups"].values()
201-
)
202-
group_label = (
203-
f"{group_name} ({total_plot_count})"
204-
if total_plot_count > 0
205-
else group_name
206-
)
207248

249+
if gpu_metric_groups:
250+
gpu_plot_count = len(gpu_metric_groups)
208251
with gr.Accordion(
209-
label=group_label,
252+
label=f"gpu ({gpu_plot_count})",
210253
open=True,
211-
key=f"sys-accordion-{group_name}",
254+
key="sys-accordion-gpu-combined",
212255
preserved_by_key=["value", "open"],
213256
):
214-
if group_data["direct_metrics"]:
215-
with gr.Draggable(
216-
key=f"sys-row-{group_name}-direct", orientation="row"
217-
):
218-
for metric_name in group_data["direct_metrics"]:
219-
metric_df = master_df.dropna(subset=[metric_name])
220-
color = "run" if "run" in metric_df.columns else None
221-
downsampled_df, updated_x_lim = utils.downsample(
222-
metric_df,
223-
x_column,
224-
metric_name,
225-
color,
226-
x_lim_value,
227-
)
228-
if not metric_df.empty:
229-
plot = gr.LinePlot(
230-
downsampled_df,
231-
x=x_column,
232-
y=metric_name,
233-
x_title="Time (seconds)",
234-
y_title=metric_name.split("/")[-1],
235-
color=color,
236-
color_map=color_map,
237-
colors_in_legend=original_runs,
238-
title=metric_name,
239-
key=f"sys-plot-{metric_idx}",
240-
preserved_by_key=None,
241-
buttons=["fullscreen", "export"],
242-
x_lim=updated_x_lim,
243-
min_width=400,
244-
)
245-
plot.select(
246-
update_x_lim,
247-
outputs=x_lim,
248-
key=f"sys-select-{metric_idx}",
257+
with gr.Draggable(key="sys-row-gpu-combined", orientation="row"):
258+
for metric_key in sorted(gpu_metric_groups.keys()):
259+
metric_df = melted_df[melted_df[metric_key].notna()].copy() if metric_key in melted_df.columns else pd.DataFrame()
260+
if metric_df.empty:
261+
continue
262+
color = "gpu_id"
263+
downsampled_df, updated_x_lim = utils.downsample(
264+
metric_df, x_column, metric_key, color, x_lim_value
265+
)
266+
metric_display = metric_key.split("/")[-1]
267+
plot = gr.LinePlot(
268+
downsampled_df,
269+
x=x_column,
270+
y=metric_key,
271+
x_title="Time (seconds)",
272+
y_title=metric_display,
273+
color=color,
274+
color_map=gpu_color_map,
275+
title=metric_key,
276+
key=f"sys-plot-{metric_idx}",
277+
preserved_by_key=None,
278+
buttons=["fullscreen", "export"],
279+
x_lim=updated_x_lim,
280+
min_width=400,
281+
)
282+
plot.select(update_x_lim, outputs=x_lim, key=f"sys-select-{metric_idx}")
283+
plot.double_click(lambda: None, outputs=x_lim, key=f"sys-double-{metric_idx}")
284+
metric_idx += 1
285+
286+
if non_gpu_cols:
287+
ordered_groups, nested_metric_groups = utils.order_metrics_by_plot_preference(non_gpu_cols)
288+
for group_name in ordered_groups:
289+
group_data = nested_metric_groups[group_name]
290+
total_plot_count = sum(
291+
1 for m in group_data["direct_metrics"] if not master_df.dropna(subset=[m]).empty
292+
) + sum(
293+
sum(1 for m in metrics if not master_df.dropna(subset=[m]).empty)
294+
for metrics in group_data["subgroups"].values()
295+
)
296+
if total_plot_count == 0:
297+
continue
298+
group_label = f"{group_name} ({total_plot_count})"
299+
300+
with gr.Accordion(
301+
label=group_label, open=True,
302+
key=f"sys-accordion-{group_name}", preserved_by_key=["value", "open"],
303+
):
304+
if group_data["direct_metrics"]:
305+
with gr.Draggable(key=f"sys-row-{group_name}-direct", orientation="row"):
306+
for metric_name in group_data["direct_metrics"]:
307+
metric_df = master_df.dropna(subset=[metric_name])
308+
if metric_df.empty:
309+
continue
310+
color = "run" if "run" in metric_df.columns else None
311+
downsampled_df, updated_x_lim = utils.downsample(
312+
metric_df, x_column, metric_name, color, x_lim_value
249313
)
250-
plot.double_click(
251-
lambda: None,
252-
outputs=x_lim,
253-
key=f"sys-double-{metric_idx}",
314+
plot = gr.LinePlot(
315+
downsampled_df, x=x_column, y=metric_name,
316+
x_title="Time (seconds)", y_title=metric_name.split("/")[-1],
317+
color=color, color_map=color_map, colors_in_legend=original_runs,
318+
title=metric_name, key=f"sys-plot-{metric_idx}",
319+
preserved_by_key=None, buttons=["fullscreen", "export"],
320+
x_lim=updated_x_lim, min_width=400,
254321
)
255-
metric_idx += 1
322+
plot.select(update_x_lim, outputs=x_lim, key=f"sys-select-{metric_idx}")
323+
plot.double_click(lambda: None, outputs=x_lim, key=f"sys-double-{metric_idx}")
324+
metric_idx += 1
256325

257-
if group_data["subgroups"]:
258326
for subgroup_name in sorted(group_data["subgroups"].keys()):
259327
subgroup_metrics = group_data["subgroups"][subgroup_name]
260-
261328
subgroup_plot_count = sum(
262-
1
263-
for m in subgroup_metrics
264-
if not master_df.dropna(subset=[m]).empty
329+
1 for m in subgroup_metrics if not master_df.dropna(subset=[m]).empty
265330
)
266-
subgroup_label = (
267-
f"{subgroup_name} ({subgroup_plot_count})"
268-
if subgroup_plot_count > 0
269-
else subgroup_name
270-
)
271-
331+
if subgroup_plot_count == 0:
332+
continue
272333
with gr.Accordion(
273-
label=subgroup_label,
274-
open=True,
334+
label=f"{subgroup_name} ({subgroup_plot_count})", open=True,
275335
key=f"sys-accordion-{group_name}-{subgroup_name}",
276336
preserved_by_key=["value", "open"],
277337
):
278-
with gr.Draggable(
279-
key=f"sys-row-{group_name}-{subgroup_name}",
280-
orientation="row",
281-
):
338+
with gr.Draggable(key=f"sys-row-{group_name}-{subgroup_name}", orientation="row"):
282339
for metric_name in subgroup_metrics:
283340
metric_df = master_df.dropna(subset=[metric_name])
284-
color = (
285-
"run" if "run" in metric_df.columns else None
286-
)
341+
if metric_df.empty:
342+
continue
343+
color = "run" if "run" in metric_df.columns else None
287344
downsampled_df, updated_x_lim = utils.downsample(
288-
metric_df,
289-
x_column,
290-
metric_name,
291-
color,
292-
x_lim_value,
345+
metric_df, x_column, metric_name, color, x_lim_value
346+
)
347+
plot = gr.LinePlot(
348+
downsampled_df, x=x_column, y=metric_name,
349+
x_title="Time (seconds)", y_title=metric_name.split("/")[-1],
350+
color=color, color_map=color_map, colors_in_legend=original_runs,
351+
title=metric_name, key=f"sys-plot-{metric_idx}",
352+
preserved_by_key=None, buttons=["fullscreen", "export"],
353+
x_lim=updated_x_lim, min_width=400,
293354
)
294-
if not metric_df.empty:
295-
plot = gr.LinePlot(
296-
downsampled_df,
297-
x=x_column,
298-
y=metric_name,
299-
x_title="Time (seconds)",
300-
y_title=metric_name.split("/")[-1],
301-
color=color,
302-
color_map=color_map,
303-
colors_in_legend=original_runs,
304-
title=metric_name,
305-
key=f"sys-plot-{metric_idx}",
306-
preserved_by_key=None,
307-
buttons=["fullscreen", "export"],
308-
x_lim=updated_x_lim,
309-
min_width=400,
310-
)
311-
plot.select(
312-
update_x_lim,
313-
outputs=x_lim,
314-
key=f"sys-select-{metric_idx}",
315-
)
316-
plot.double_click(
317-
lambda: None,
318-
outputs=x_lim,
319-
key=f"sys-double-{metric_idx}",
320-
)
355+
plot.select(update_x_lim, outputs=x_lim, key=f"sys-select-{metric_idx}")
356+
plot.double_click(lambda: None, outputs=x_lim, key=f"sys-double-{metric_idx}")
321357
metric_idx += 1
322358

323359
gr.on(

0 commit comments

Comments
 (0)