Skip to content

Commit 88ab26f

Browse files
abidlabsclaudeqgallouedec
authored
Add inline notebook support for trackio.show() (#204)
* Add inline notebook support for trackio.show() Adds automatic notebook detection and displays the Trackio dashboard inline when running in Jupyter notebooks, Google Colab, or similar environments. Uses Gradio's inline parameter to embed the dashboard directly in notebook cells instead of opening in a separate browser window. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * changes * format * changes * Add support for launching Trackio dashboard in Jupyter Notebooks --------- Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
1 parent 11692c5 commit 88ab26f

4 files changed

Lines changed: 216 additions & 7 deletions

File tree

docs/source/launch.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,15 @@ trackio.show(theme="soft")
7070
</hfoptions>
7171

7272
To see the available themes, check out the [themes gallery](https://huggingface.co/spaces/gradio/theme-gallery).
73+
74+
## Launching a Dashboard in Jupyter Notebooks
75+
76+
You can also launch the dashboard directly within a Jupyter Notebook. Just use the same command as above:
77+
78+
```py
79+
import trackio
80+
81+
trackio.show()
82+
```
83+
84+
Check the [demo notebook](https://github.com/gradio-app/trackio/blob/main/examples/notebook_integration.ipynb).
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "a254e37b",
6+
"metadata": {},
7+
"source": [
8+
"## Trackio in Jupyter Notebooks\n",
9+
"\n",
10+
"This notebook demonstrates how to log training metrics with TrackIO, and display the results interactively directly in the notebook."
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": null,
16+
"id": "7fe95112",
17+
"metadata": {},
18+
"outputs": [],
19+
"source": [
20+
"!pip install -q trackio"
21+
]
22+
},
23+
{
24+
"cell_type": "markdown",
25+
"id": "daa46a91",
26+
"metadata": {},
27+
"source": [
28+
"The following code simulates training runs by generating synthetic loss, accuracy, and gradient norm metrics over multiple epochs and logs these metrics to Trackio."
29+
]
30+
},
31+
{
32+
"cell_type": "code",
33+
"execution_count": null,
34+
"id": "12cce6ee",
35+
"metadata": {},
36+
"outputs": [],
37+
"source": [
38+
"import math\n",
39+
"import random\n",
40+
"import time\n",
41+
"\n",
42+
"import trackio as wandb\n",
43+
"\n",
44+
"EPOCHS = 20\n",
45+
"PROJECT_ID = random.randint(100000, 999999)\n",
46+
"\n",
47+
"\n",
48+
"def generate_loss_curve(epoch, max_epochs, base_loss=2.5, min_loss=0.1):\n",
49+
" \"\"\"Generate a realistic loss curve that decreases over time with noise\"\"\"\n",
50+
" progress = epoch / max_epochs\n",
51+
" base_curve = base_loss * math.exp(-3 * progress) + min_loss\n",
52+
"\n",
53+
" noise_scale = 0.3 * (1 - progress * 0.7)\n",
54+
" noise = random.gauss(0, noise_scale)\n",
55+
"\n",
56+
" return max(min_loss * 0.5, base_curve + noise)\n",
57+
"\n",
58+
"\n",
59+
"def generate_accuracy_curve(epoch, max_epochs, max_acc=0.95, min_acc=0.1):\n",
60+
" \"\"\"Generate a realistic accuracy curve that increases over time with noise\"\"\"\n",
61+
" progress = epoch / max_epochs\n",
62+
" base_curve = max_acc / (1 + math.exp(-6 * (progress - 0.5))) + min_acc\n",
63+
"\n",
64+
" noise_scale = 0.08 * (1 - progress * 0.5)\n",
65+
" noise = random.gauss(0, noise_scale)\n",
66+
"\n",
67+
" return max(0, min(max_acc, base_curve + noise))\n",
68+
"\n",
69+
"\n",
70+
"def generate_grad_norm_curve(epoch, max_epochs):\n",
71+
" \"\"\"Generate a gradient norm that starts at infinity and decreases to reasonable values\"\"\"\n",
72+
" if epoch == 0:\n",
73+
" return float(\"inf\")\n",
74+
" elif epoch == 1:\n",
75+
" return 1000.0\n",
76+
" elif epoch == 2:\n",
77+
" return 100.0\n",
78+
" else:\n",
79+
" progress = (epoch - 2) / (max_epochs - 2)\n",
80+
" base_value = 50 * math.exp(-4 * progress) + 1.0\n",
81+
" noise = random.gauss(0, 0.5)\n",
82+
" return max(0.1, base_value + noise)\n",
83+
"\n",
84+
"\n",
85+
"for run in range(3):\n",
86+
" wandb.init(\n",
87+
" project=f\"fake-training-{PROJECT_ID}\",\n",
88+
" name=f\"test-run-{run}\",\n",
89+
" config=dict(\n",
90+
" epochs=EPOCHS,\n",
91+
" learning_rate=0.001,\n",
92+
" batch_size=32,\n",
93+
" ),\n",
94+
" )\n",
95+
"\n",
96+
" for epoch in range(EPOCHS):\n",
97+
" train_loss = generate_loss_curve(\n",
98+
" epoch,\n",
99+
" EPOCHS,\n",
100+
" base_loss=random.uniform(2.5, 3.5),\n",
101+
" min_loss=random.uniform(0.05, 0.15),\n",
102+
" )\n",
103+
" val_loss = generate_loss_curve(\n",
104+
" epoch,\n",
105+
" EPOCHS,\n",
106+
" base_loss=random.uniform(2.5, 3.5),\n",
107+
" min_loss=random.uniform(0.05, 0.15),\n",
108+
" )\n",
109+
"\n",
110+
" train_accuracy = generate_accuracy_curve(\n",
111+
" epoch,\n",
112+
" EPOCHS,\n",
113+
" max_acc=random.uniform(0.7, 0.9),\n",
114+
" min_acc=random.uniform(0.1, 0.3),\n",
115+
" )\n",
116+
" val_accuracy = generate_accuracy_curve(\n",
117+
" epoch,\n",
118+
" EPOCHS,\n",
119+
" max_acc=random.uniform(0.7, 0.9),\n",
120+
" min_acc=random.uniform(0.1, 0.3),\n",
121+
" )\n",
122+
"\n",
123+
" grad_norm = generate_grad_norm_curve(epoch, EPOCHS)\n",
124+
"\n",
125+
" if epoch > 2 and random.random() < 0.3:\n",
126+
" val_loss *= 1.1\n",
127+
" val_accuracy *= 0.95\n",
128+
"\n",
129+
" wandb.log(\n",
130+
" {\n",
131+
" \"train/loss\": round(train_loss, 4),\n",
132+
" \"train/accuracy\": round(train_accuracy, 4),\n",
133+
" \"train/rewards/reward1\": random.random(),\n",
134+
" \"train/rewards/reward2\": random.random(),\n",
135+
" \"val/loss\": round(val_loss, 4),\n",
136+
" \"val/accuracy\": round(val_accuracy, 4),\n",
137+
" \"grad_norm\": grad_norm,\n",
138+
" }\n",
139+
" )\n",
140+
"\n",
141+
" time.sleep(0.2)\n",
142+
"\n",
143+
"wandb.finish()"
144+
]
145+
},
146+
{
147+
"cell_type": "markdown",
148+
"id": "c471846d",
149+
"metadata": {},
150+
"source": [
151+
"The following cell launches the TrackIO dashboard directly in the notebook, allowing to interactively explore your logged training metrics."
152+
]
153+
},
154+
{
155+
"cell_type": "code",
156+
"execution_count": null,
157+
"id": "957f6b6e",
158+
"metadata": {},
159+
"outputs": [],
160+
"source": [
161+
"import trackio\n",
162+
"\n",
163+
"trackio.show()"
164+
]
165+
}
166+
],
167+
"metadata": {
168+
"kernelspec": {
169+
"display_name": "trl",
170+
"language": "python",
171+
"name": "python3"
172+
},
173+
"language_info": {
174+
"name": "python",
175+
"version": "3.12.11"
176+
}
177+
},
178+
"nbformat": 4,
179+
"nbformat_minor": 5
180+
}

trackio/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -251,14 +251,16 @@ def show(project: str | None = None, theme: str | ThemeClass = DEFAULT_THEME):
251251
_, url, share_url = demo.launch(
252252
show_api=False,
253253
quiet=True,
254-
inline=False,
254+
inline=utils.is_in_notebook(),
255255
prevent_thread_lock=True,
256256
favicon_path=TRACKIO_LOGO_DIR / "trackio_logo_light.png",
257257
allowed_paths=[TRACKIO_LOGO_DIR],
258258
)
259259

260260
base_url = share_url + "/" if share_url else url
261261
dashboard_url = base_url + f"?project={project}" if project else base_url
262-
print(f"* Trackio UI launched at: {dashboard_url}")
263-
webbrowser.open(dashboard_url)
264-
utils.block_except_in_notebook()
262+
263+
if not utils.is_in_notebook():
264+
print(f"* Trackio UI launched at: {dashboard_url}")
265+
webbrowser.open(dashboard_url)
266+
utils.block_except_in_notebook()

trackio/utils.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import math
22
import os
33
import re
4-
import sys
54
import time
65
from pathlib import Path
76
from typing import TYPE_CHECKING
@@ -242,9 +241,25 @@ def generate_readable_name(used_names: list[str], space_id: str | None = None) -
242241
return name
243242

244243

244+
def is_in_notebook():
245+
"""
246+
Detect if code is running in a notebook environment (Jupyter, Colab, etc.).
247+
"""
248+
try:
249+
from IPython import get_ipython
250+
251+
if get_ipython() is not None:
252+
return get_ipython().__class__.__name__ in [
253+
"ZMQInteractiveShell", # Jupyter notebook/lab
254+
"Shell", # IPython terminal
255+
] or "google.colab" in str(get_ipython())
256+
except ImportError:
257+
pass
258+
return False
259+
260+
245261
def block_except_in_notebook():
246-
in_notebook = bool(getattr(sys, "ps1", sys.flags.interactive))
247-
if in_notebook:
262+
if is_in_notebook():
248263
return
249264
try:
250265
while True:

0 commit comments

Comments
 (0)