Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/easy-apes-hammer.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"trackio": minor
---

feat:Traces in Trackio
40 changes: 40 additions & 0 deletions examples/traces/basic-trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import random

import trackio

# Six-digit random id embedded in the project name used by every run below.
PROJECT_ID = random.randint(100000, 999999)
PROJECT_NAME = f"trace-demo-basic-{PROJECT_ID}"

# (prompt, completion) pairs; each pair becomes one logged trace per run.
examples = [
    ("What is 2 + 2?", "2 + 2 = 4."),
    ("What is the capital of Australia?", "The capital of Australia is Canberra."),
    (
        "Give me a one-sentence summary of Trackio.",
        "Trackio is a lightweight experiment tracking dashboard for ML and agent workflows.",
    ),
    ("Translate 'hello' to Spanish.", "Hola."),
]


def _chat_transcript(prompt, completion):
    """Return the system/user/assistant message list for one example pair."""
    system_turn = {"role": "system", "content": "You are a concise assistant."}
    user_turn = {"role": "user", "content": prompt}
    assistant_turn = {"role": "assistant", "content": completion}
    return [system_turn, user_turn, assistant_turn]


# Two runs in the same project, each logging one trace per example.
for run_idx in range(2):
    trackio.init(project=PROJECT_NAME, name=f"basic-run-{run_idx}")

    for step, (prompt, completion) in enumerate(examples):
        trace_metadata = {
            "label": f"basic-demo-{run_idx + 1}",
            "category": "basic-example",
            "index": step,
        }
        trace = trackio.Trace(
            messages=_chat_transcript(prompt, completion),
            metadata=trace_metadata,
        )
        trackio.log({"trace": trace}, step=step)

    trackio.finish()
74 changes: 74 additions & 0 deletions examples/traces/complex-trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import random

import numpy as np

import trackio

# Random six-digit id embedded in the project name, so separate executions of
# this example normally end up in different projects.
PROJECT_ID = random.randint(100000, 999999)
PROJECT_NAME = f"trace-demo-complex-{PROJECT_ID}"


def make_screenshot(seed: int):
    """Return a seeded random RGB image standing in for a browser screenshot.

    Args:
        seed: Seed for the NumPy random generator; the same seed always
            produces the same image.

    Returns:
        A ``uint8`` array of shape ``(240, 320, 3)`` whose values cover the
        full 0-255 range.
    """
    rng = np.random.default_rng(seed)
    # The upper bound of Generator.integers is exclusive, so 256 is required
    # for the value 255 to be reachable.
    return rng.integers(0, 256, size=(240, 320, 3), dtype=np.uint8)


# Two runs, four steps each; every step logs one multi-turn agent trace that
# contains an image attachment, a tool call and the matching tool result.
for run_idx in range(2):
    trackio.init(project=PROJECT_NAME, name=f"complex-run-{run_idx}")

    for step in range(4):
        screenshot = make_screenshot(run_idx * 10 + step)

        system_message = {"role": "system", "content": "You are a browser agent."}

        # User turn mixes a text part with the screenshot image.
        user_message = {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": f"Inspect page variant {step} and summarize it.",
                },
                trackio.Image(
                    screenshot,
                    caption=f"browser screenshot run={run_idx} step={step}",
                ),
            ],
        }

        # The same id links the assistant's tool call to the tool's reply.
        tool_call_id = f"call_{run_idx}_{step}"
        tool_request_message = {
            "role": "assistant",
            "content": "I will inspect the page and call a tool if needed.",
            "tool_calls": [
                {
                    "id": tool_call_id,
                    "type": "function",
                    "function": {
                        "name": "extract_title",
                        "arguments": '{"selector": "title"}',
                    },
                }
            ],
        }
        tool_result_message = {
            "role": "tool",
            "content": f'{{"title": "Trackio Demo {run_idx}-{step}"}}',
            "tool_call_id": tool_call_id,
        }

        summary_message = {
            "role": "assistant",
            "content": f"The page variant {step} appears to be a Trackio demo with a visible screenshot and an extracted title.",
        }

        agent_trace = trackio.Trace(
            messages=[
                system_message,
                user_message,
                tool_request_message,
                tool_result_message,
                summary_message,
            ],
            metadata={
                "label": f"complex-demo-{run_idx}",
                "environment": "browser",
                "category": "complex-example",
                "variant": step,
            },
        )
        trackio.log({"agent_trace": agent_trace}, step=step)

    trackio.finish()
162 changes: 162 additions & 0 deletions examples/traces/trl-trace-integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# /// script
# dependencies = [
# "trackio",
# "trl",
# "datasets",
# "transformers",
# "torch",
# ]
# ///

import random

import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from trl import SFTConfig, SFTTrainer

import trackio

# Random six-digit id embedded in the project name and the output directories.
PROJECT_ID = random.randint(100000, 999999)
PROJECT_NAME = f"trace-demo-trl-{PROJECT_ID}"
MODEL_NAME = "sshleifer/tiny-gpt2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
tokenizer.pad_token = (
    tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
)

# (prompt, reference completion) pairs expanded into the dict records used by
# both the training set and the trace-logging callback.
_PROMPT_REFERENCE_PAIRS = (
    ("What is 2 + 2?", "2 + 2 = 4."),
    (
        "What color is the sky on a clear day?",
        "The sky is typically blue on a clear day.",
    ),
    ("Translate 'good morning' to French.", "Bonjour."),
    ("Name the capital of Japan.", "Tokyo is the capital of Japan."),
    (
        "Give one use of Trackio.",
        "Trackio can be used to inspect training logs and traces.",
    ),
)
examples = [
    {"prompt": prompt_text, "reference_completion": reference_text}
    for prompt_text, reference_text in _PROMPT_REFERENCE_PAIRS
]


def format_example(example):
    """Render one prompt/reference pair as an instruction-style training record.

    Args:
        example: Mapping with ``"prompt"`` and ``"reference_completion"`` keys.

    Returns:
        A dict with a single ``"text"`` key holding the prompt under an
        ``### Instruction:`` header followed by the reference completion under
        a ``### Response:`` header.
    """
    instruction = example["prompt"]
    response = example["reference_completion"]
    sections = [
        "### Instruction:\n",
        f"{instruction}\n\n",
        "### Response:\n",
        f"{response}",
    ]
    return {"text": "".join(sections)}


# The example list is used twice over so the training set has ten rows.
training_rows = list(map(format_example, examples * 2))
dataset = Dataset.from_list(training_rows)


class TraceLoggingCallback(TrainerCallback):
    """Log a model-generated chat trace to Trackio from the ``on_log`` hook.

    For each qualifying ``on_log`` call, one example is picked by rotating
    through ``prompt_examples`` with the global step, a completion is generated
    from the model passed to the hook, and the exchange is logged under the
    ``"trace"`` key together with run metadata.
    """

    def __init__(self, prompt_examples, run_label, tokenizer):
        """Store the callback configuration.

        Args:
            prompt_examples: Sequence of dicts providing ``"prompt"`` and
                ``"reference_completion"`` entries.
            run_label: Value stored under the ``"label"`` metadata key.
            tokenizer: Tokenizer used to encode prompts and decode generations.
        """
        self.prompt_examples = prompt_examples
        self.run_label = run_label
        self.tokenizer = tokenizer

    def _generate_completion(self, model, prompt):
        """Generate a short, non-sampled completion for ``prompt``.

        The model is switched to eval mode for generation and put back into
        train mode afterwards if it was training, even when generation raises.

        Args:
            model: Model exposing ``device``, ``training``, ``eval``, ``train``
                and ``generate``.
            prompt: Prompt text; truncated to 64 tokens when encoded.

        Returns:
            The decoded completion with surrounding whitespace stripped, or
            ``"(empty generation)"`` when nothing remains after stripping.
        """
        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=64,
        )
        encoded = {key: value.to(model.device) for key, value in encoded.items()}

        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                generated = model.generate(
                    **encoded,
                    max_new_tokens=24,
                    do_sample=False,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
        finally:
            # Restore train mode on every exit path so a failed generation
            # cannot leave the rest of training running in eval mode.
            if was_training:
                model.train()

        # The generated sequence starts with the prompt tokens; keep only the
        # newly produced tail.
        prompt_length = encoded["input_ids"].shape[1]
        completion_ids = generated[0][prompt_length:]
        completion = self.tokenizer.decode(completion_ids, skip_special_tokens=True)
        completion = completion.strip()
        return completion or "(empty generation)"

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log one trace for the current log event.

        Does nothing when ``logs`` is empty, when ``state.global_step`` is not
        positive, or when no ``model`` keyword argument is supplied.

        Args:
            args: Unused here.
            state: Object whose ``global_step`` selects the example and is used
                as the logged step.
            control: Unused here.
            logs: Mapping of logged values; ``"loss"`` is copied into the trace
                metadata.
            **kwargs: Expected to carry the ``model`` to generate with.
        """
        if not logs or state.global_step <= 0:
            return

        model = kwargs.get("model")
        if model is None:
            return

        # Rotate through the examples; step 1 maps to the first example.
        sample = self.prompt_examples[
            (state.global_step - 1) % len(self.prompt_examples)
        ]
        trackio.log(
            {
                "trace": trackio.Trace(
                    messages=[
                        {
                            "role": "system",
                            "content": "You are a supervised fine-tuning demo model.",
                        },
                        {"role": "user", "content": sample["prompt"]},
                        {
                            "role": "assistant",
                            "content": self._generate_completion(
                                model, sample["prompt"]
                            ),
                        },
                    ],
                    metadata={
                        "label": self.run_label,
                        "trainer": "trl-sft",
                        # NOTE(review): log events without a "loss" entry are
                        # recorded with loss 0.0 — confirm that is intended.
                        "loss": float(logs.get("loss", 0.0)),
                        "global_step": int(state.global_step),
                        "reference_completion": sample["reference_completion"],
                    },
                )
            },
            step=int(state.global_step),
        )


# Train two short SFT runs, each from freshly loaded weights, reporting to the
# same Trackio project and logging traces through the callback.
for run_idx in range(2):
    run_name = f"trl-run-{run_idx}"
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

    training_args = SFTConfig(
        output_dir=f"./trl_trace_output_{PROJECT_ID}_{run_idx}",
        per_device_train_batch_size=2,
        max_steps=5,
        logging_steps=1,
        save_strategy="no",
        report_to="trackio",
        project=PROJECT_NAME,
        run_name=run_name,
        trackio_space_id=None,
        learning_rate=5e-5,
        dataset_text_field="text",
        max_length=64,
    )
    trace_callback = TraceLoggingCallback(examples, run_name, tokenizer)

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        callbacks=[trace_callback],
    )

    trainer.train()
25 changes: 25 additions & 0 deletions tests/e2e-local/test_trace_e2e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import trackio
from trackio.sqlite_storage import SQLiteStorage


def test_trace_logging_round_trip(temp_dir):
    """Log a single trace and check it can be read back from SQLite storage.

    Args:
        temp_dir: Fixture defined outside this file; presumably isolates the
            storage location for the test — TODO confirm.
    """
    project_name, run_name = "trace_project", "trace_run"
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 = 4."},
    ]
    logged_trace = trackio.Trace(
        messages=conversation,
        metadata={"label": "demo-trace"},
    )

    run = trackio.init(project=project_name, name=run_name)
    run.log({"trace": logged_trace})
    run.finish()

    stored_traces = SQLiteStorage.get_traces(project_name, run_name)
    assert len(stored_traces) == 1

    stored = stored_traces[0]
    # Index 2 is the assistant reply, the last message of the conversation.
    assert stored["messages"][2]["content"] == "2 + 2 = 4."
    assert stored["metadata"]["label"] == "demo-trace"
6 changes: 3 additions & 3 deletions tests/ui/test_ui_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_runs_plots_images_are_displayed(temp_dir):
page.goto(full_url)
page.wait_for_load_state("networkidle")
nav_links = page.locator(".nav-link")
expect(nav_links).to_have_count(7)
expect(nav_links).to_have_count(8)

run_label = page.locator(".run-name", has_text="test_run")
expect(run_label).to_be_visible()
Expand Down Expand Up @@ -112,7 +112,7 @@ def test_navbar_page_navigation(temp_dir):
page.goto(full_url)
page.wait_for_load_state("networkidle")
nav_links = page.locator(".nav-link")
expect(nav_links).to_have_count(7)
expect(nav_links).to_have_count(8)

expect(page.locator(".metrics-page")).to_be_visible()

Expand Down Expand Up @@ -153,7 +153,7 @@ def test_runs_table_shows_run_data(temp_dir):
page.wait_for_load_state("networkidle")

nav_links = page.locator(".nav-link")
expect(nav_links).to_have_count(7)
expect(nav_links).to_have_count(8)
page.get_by_role("button", name="Runs", exact=True).click()
page.wait_for_load_state("networkidle")

Expand Down
Loading
Loading