Skip to content

Commit 9fe1616

Browse files
committed
changes
1 parent 53b4e57 commit 9fe1616

3 files changed

Lines changed: 60 additions & 66 deletions

File tree

trackio/storage.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,32 @@
1+
import csv
12
import json
23
import os
3-
import csv
44

5-
import pandas as pd
6-
from datasets import Dataset
5+
from trackio.utils import RESERVED_KEYS, TRACKIO_DIR
76

87

98
class TrackioStorage:
109
def __init__(self, project: str, name: str, config: dict):
1110
self.project = project
1211
self.name = name
1312
self.config = config
14-
self.dir = os.path.join("trackio", project, self.name)
13+
self.dir = os.path.join(TRACKIO_DIR, project, self.name)
1514
os.makedirs(self.dir, exist_ok=True)
1615
self.csv_path = os.path.join(self.dir, "run.csv")
1716
self.parquet_path = os.path.join(self.dir, "run.parquet")
1817
self.config_path = os.path.join(self.dir, "config.json")
19-
self.headers = None
18+
self.headers = []
2019

2120
with open(self.config_path, "w") as f:
2221
json.dump(self.config, f, indent=2)
2322

2423
def log(self, metrics: dict):
24+
for k in metrics.keys():
25+
if k in RESERVED_KEYS or k.startswith("__"):
26+
raise ValueError(
27+
f"Please do not use this reserved key as a metric: {k}"
28+
)
29+
2530
if not os.path.exists(self.csv_path) or os.path.getsize(self.csv_path) == 0:
2631
self.headers = list(metrics.keys())
2732
with open(self.csv_path, "w", newline="") as f:
@@ -52,8 +57,9 @@ def log(self, metrics: dict):
5257
writer = csv.DictWriter(f, fieldnames=self.headers)
5358
writer.writerow({h: metrics.get(h, "") for h in self.headers})
5459

55-
# def finish(self):
56-
# if self.logs:
57-
# df = pd.DataFrame(self.logs)
58-
# ds = Dataset.from_pandas(df)
59-
# ds.to_parquet(self.run_path)
60+
def finish(self):
61+
pass
62+
# if self.logs:
63+
# df = pd.DataFrame(self.logs)
64+
# ds = Dataset.from_pandas(df)
65+
# ds.to_parquet(self.run_path)

trackio/ui.py

Lines changed: 42 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
import json
21
import os
32

43
import gradio as gr
54
import pandas as pd
6-
import plotly.express as px
75

8-
TRACKIO_DIR = "trackio"
6+
from trackio.utils import RESERVED_KEYS, TRACKIO_DIR
97

108

119
def get_projects():
@@ -31,68 +29,56 @@ def get_runs(project):
3129

3230
def load_run_data(project, run):
3331
run_dir = os.path.join(TRACKIO_DIR, project, run)
34-
run_path = os.path.join(run_dir, "run.parquet")
35-
config_path = os.path.join(run_dir, "config.json")
32+
csv_path = os.path.join(run_dir, "run.csv")
3633
df = None
37-
config = {}
38-
if os.path.exists(run_path):
39-
df = pd.read_parquet(run_path)
40-
if os.path.exists(config_path):
41-
with open(config_path) as f:
42-
config = json.load(f)
43-
return df, config
44-
45-
46-
def plot_metrics(df):
47-
if df is None or df.empty:
48-
return None
49-
plots = []
50-
numeric_cols = df.select_dtypes(include="number").columns
51-
for col in numeric_cols:
52-
fig = px.line(df, y=col, title=col)
53-
plots.append(fig)
54-
return plots
34+
if os.path.exists(csv_path):
35+
df = pd.read_csv(csv_path)
36+
df["step"] = range(len(df))
37+
return df
5538

5639

5740
def update_runs(project):
58-
return gr.Dropdown(choices=get_runs(project), value=None)
59-
60-
61-
def update_dashboard(project, run):
62-
if not project or not run:
63-
return (
64-
gr.update(visible=False),
65-
gr.update(visible=False),
66-
gr.update(visible=False),
67-
)
68-
df, config = load_run_data(project, run)
69-
plots = plot_metrics(df)
70-
return plots, gr.JSON(value=config), gr.update(visible=True)
41+
runs = get_runs(project)
42+
return gr.Dropdown(choices=runs, value=runs)
7143

7244

7345
def launch_ui():
74-
with gr.Blocks() as demo:
75-
gr.Markdown("# Trackio Dashboard")
76-
with gr.Row():
46+
with gr.Blocks(theme="citrus") as demo:
47+
with gr.Sidebar():
48+
gr.Markdown("# 🎯 Trackio Dashboard")
7749
project_dd = gr.Dropdown(label="Project", choices=get_projects())
78-
run_dd = gr.Dropdown(label="Run", choices=[])
7950
with gr.Row():
80-
plot_output = gr.LinePlot(
81-
label="Metrics",
82-
visible=False,
83-
show_label=True,
84-
elem_id="metrics-plot",
85-
interactive=True,
86-
scale=2,
87-
)
88-
config_output = gr.JSON(label="Config", visible=False)
89-
# Events
90-
project_dd.change(
91-
fn=lambda p: update_runs(p), inputs=project_dd, outputs=run_dd
51+
run_dd = gr.Dropdown(label="Run", choices=[], multiselect=True)
52+
53+
gr.on(
54+
[demo.load, project_dd.change],
55+
fn=update_runs,
56+
inputs=project_dd,
57+
outputs=run_dd,
9258
)
93-
run_dd.change(
94-
fn=lambda p, r: update_dashboard(p, r),
59+
60+
@gr.render(
61+
triggers=[run_dd.change],
9562
inputs=[project_dd, run_dd],
96-
outputs=[plot_output, config_output, plot_output],
9763
)
98-
demo.launch()
64+
def update_dashboard(project, runs):
65+
dfs = []
66+
for run in runs:
67+
df = load_run_data(project, run)
68+
if df is not None:
69+
df["run"] = run
70+
dfs.append(df)
71+
if dfs:
72+
master_df = pd.concat(dfs, ignore_index=True)
73+
else:
74+
master_df = pd.DataFrame()
75+
numeric_cols = master_df.select_dtypes(include="number").columns
76+
numeric_cols = [c for c in numeric_cols if c not in RESERVED_KEYS]
77+
for col in numeric_cols:
78+
gr.LinePlot(master_df, x="step", y=col, color="run" if "run" in master_df.columns else None, title=col)
79+
80+
demo.launch(show_api=False)
81+
82+
83+
if __name__ == "__main__":
84+
launch_ui()

trackio/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import random
22

3+
RESERVED_KEYS = ["step", "epoch", "batch", "run", "timestamp"]
4+
TRACKIO_DIR = ".trackio"
35

46
def generate_readable_name():
57
"""

0 commit comments

Comments
 (0)