Skip to content

Commit 6b7dfdb

Browse files
committed
changes
1 parent 22881db commit 6b7dfdb

2 files changed

Lines changed: 99 additions & 0 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Use the Trackio CLI to analyze the training project spike-demo. Do you see any issues?
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import math
2+
import random
3+
4+
import trackio as wandb
5+
6+
# Total number of simulated training steps the script logs.
STEPS = 12_000
# Step at which the simulated instability (LR bump, gradient/loss spike) starts.
SPIKE_STEP = 8_000
8+
9+
# Learning-rate schedule: warmup, plateau, a bump at SPIKE_STEP, then decay.
def get_learning_rate(step):
    """Return the simulated learning rate for ``step``.

    Phases, in order:
      * steps below 1000: linear warmup from 0 up to 1e-4;
      * up to ``SPIKE_STEP``: constant 1e-4;
      * the next 50 steps: a sudden bump to 4e-4 (simulates a bad
        schedule or a bug);
      * afterwards: exponential decay from 4e-4, floored at 1e-4.
    """
    if step < 1000:
        # Linear warmup.
        return 1e-4 * (step / 1000)
    if step < SPIKE_STEP:
        return 1e-4

    bump_end = SPIKE_STEP + 50
    if step < bump_end:
        return 4e-4

    # Slow recovery back toward the base rate.
    elapsed = step - bump_end
    decayed = 4e-4 * math.exp(-elapsed / 500)
    return max(1e-4, decayed)
22+
23+
24+
# Weight norm: slow upward drift plus a short surge starting at SPIKE_STEP.
def get_weight_norm(step):
    """Return the simulated weight norm for ``step``.

    The value is a linear drift (2.0 at step 0, rising by 4.0 over
    ``STEPS`` steps) plus Gaussian noise (sigma 0.05). During the 200
    steps starting at ``SPIKE_STEP`` an exponentially decaying surge
    (initial height 8.0, time constant 80 steps) is added on top.
    """
    drift = 2.0 + (step / STEPS) * 4.0
    jitter = random.gauss(0, 0.05)

    if SPIKE_STEP <= step < SPIKE_STEP + 200:
        # Norms jump while the gradients are exploding.
        surge = 8.0 * math.exp(-(step - SPIKE_STEP) / 80)
        return drift + surge + jitter

    return drift + jitter
33+
34+
35+
# Gradient norm: stable, explodes at SPIKE_STEP, then gradually recovers.
def get_grad_norm(step):
    """Return the simulated gradient norm for ``step``.

    * Before ``SPIKE_STEP``: 1.0 plus Gaussian noise (sigma 0.1),
      floored at 0.1.
    * For the next 30 steps: an exponentially decaying peak starting at
      500.0 (time constant 15 steps) plus Gaussian noise (sigma 10).
      This branch applies no floor, unlike the other two.
    * Afterwards: 1.0 plus a decaying offset starting at 20.0 (time
      constant 300 steps) plus Gaussian noise (sigma 0.2), floored at 0.1.
    """
    if step < SPIKE_STEP:
        return max(0.1, 1.0 + random.gauss(0, 0.1))

    explosion_end = SPIKE_STEP + 30
    if step < explosion_end:
        # Explosion phase.
        decaying_peak = 500.0 * math.exp(-(step - SPIKE_STEP) / 15)
        return decaying_peak + random.gauss(0, 10)

    # Gradual recovery toward the stable level.
    elapsed = step - explosion_end
    settled = 1.0 + 20.0 * math.exp(-elapsed / 300)
    return max(0.1, settled + random.gauss(0, 0.2))
49+
50+
51+
# Loss: smooth decrease, a spike at SPIKE_STEP, then a recovery tail.
def get_loss(step):
    """Return the simulated training loss for ``step``.

    A baseline curve ``2.5 * exp(-3 * progress) + 0.1`` is computed with
    ``progress`` capped at 1.0 once ``step`` reaches ``SPIKE_STEP``. One
    Gaussian noise sample (sigma 0.02) is drawn per call.

    * Before ``SPIKE_STEP``: baseline plus noise, floored at 0.05.
    * For the next 50 steps: baseline plus a decaying spike (initial
      height 3.5, time constant 20 steps) plus the absolute noise value.
    * Afterwards: ``(baseline + 0.3) * exp(-elapsed / 800) + 0.15`` plus
      noise, floored at 0.1.

    NOTE(review): the recovery tail decays toward 0.15, which is lower
    than the capped baseline (about 0.22), so the curve ends up below
    the pre-spike loss rather than above it. Confirm this is intended.
    """
    progress = min(step, SPIKE_STEP) / SPIKE_STEP
    baseline = 2.5 * math.exp(-3 * progress) + 0.1
    jitter = random.gauss(0, 0.02)

    if step < SPIKE_STEP:
        return max(0.05, baseline + jitter)

    spike_end = SPIKE_STEP + 50
    if step < spike_end:
        # Spike phase: noise is only ever added, never subtracted.
        bump = 3.5 * math.exp(-(step - SPIKE_STEP) / 20)
        return baseline + bump + abs(jitter)

    # Recovery tail after the spike.
    elapsed = step - spike_end
    tail = (baseline + 0.3) * math.exp(-elapsed / 800) + 0.15
    return max(0.1, tail + jitter)
69+
70+
71+
# Start a run in the "spike-demo" project. ``wandb`` here is the alias the
# file gives to the ``trackio`` import. The config records the constants
# that shape the synthetic curves.
wandb.init(
    project="spike-demo",  # plain literal: the f-prefix had no placeholders
    name="run-0",
    config=dict(
        total_steps=STEPS,
        spike_step=SPIKE_STEP,
        base_lr=1e-4,
        spike_lr=4e-4,
    ),
)

# Log one record per simulated step. The generator functions are called in
# a fixed order so the sequence of random draws per step stays the same.
for step in range(STEPS):
    lr = get_learning_rate(step)
    weight_norm = get_weight_norm(step)
    grad_norm = get_grad_norm(step)
    loss = get_loss(step)

    wandb.log(
        {
            # Noisy metrics are rounded to 4 decimals; the LR is logged as-is.
            "train/loss": round(loss, 4),
            "train/grad_norm": round(grad_norm, 4),
            "train/weight_norm": round(weight_norm, 4),
            "train/learning_rate": lr,
        },
        step=step,
    )

# Close the run once every step has been logged.
wandb.finish()

0 commit comments

Comments
 (0)