1+ import math
2+ import random
3+
4+ import trackio as wandb
5+
6+ STEPS = 12000
7+ SPIKE_STEP = 8000
8+
# LR schedule: linear warmup, a plateau, a short bump at SPIKE_STEP, then decay
def get_learning_rate(step):
    """Return the synthetic learning rate for training step ``step``.

    The schedule has four phases, checked in order:

    * ``step < 1000``: linear warmup, ``1e-4 * step / 1000``;
    * ``step < SPIKE_STEP``: constant ``1e-4``;
    * the 50 steps starting at ``SPIKE_STEP``: a sudden bump to ``4e-4``
      (simulates a bad schedule or a bug);
    * afterwards: exponential decay from ``4e-4`` with a 500-step time
      constant, never dropping below ``1e-4``.
    """
    base_lr = 1e-4
    bumped_lr = 4e-4
    warmup_steps = 1000

    if step < warmup_steps:
        return base_lr * (step / warmup_steps)
    if step < SPIKE_STEP:
        return base_lr

    bump_end = SPIKE_STEP + 50
    if step < bump_end:
        return bumped_lr

    # Decay starts from the bumped value, so the curve is continuous at
    # bump_end; the max() keeps it from undershooting the base rate.
    elapsed = step - bump_end
    decayed = bumped_lr * math.exp(-elapsed / 500)
    return max(base_lr, decayed)
22+
23+
# Weight norm drifts upward slowly, with a temporary surge at the spike
def get_weight_norm(step):
    """Return the synthetic weight norm for training step ``step``.

    The value is a linear drift (2.0 at step 0, 6.0 at ``STEPS``) plus
    Gaussian noise with sigma 0.05. During the 200 steps starting at
    ``SPIKE_STEP`` an exponentially decaying surge (8.0 at the first step,
    80-step time constant) is added on top.
    """
    drift = 2.0 + (step / STEPS) * 4.0
    # Drawn before branching so every call consumes exactly one random sample.
    jitter = random.gauss(0, 0.05)

    in_surge_window = SPIKE_STEP <= step < SPIKE_STEP + 200
    if not in_surge_window:
        return drift + jitter

    # Norms jump when gradients explode.
    # NOTE(review): the surge is still about 0.66 when the window closes, so
    # the curve steps down abruptly at SPIKE_STEP + 200 - confirm intended.
    surge = 8.0 * math.exp(-(step - SPIKE_STEP) / 80)
    return drift + surge + jitter
33+
34+
# Gradient norm is stable, then explodes at the spike, then recovers
def get_grad_norm(step):
    """Return the synthetic gradient norm for training step ``step``.

    Phases, checked in order:

    * ``step < SPIKE_STEP``: about 1.0 with Gaussian noise (sigma 0.1);
    * the 30 steps starting at ``SPIKE_STEP``: an explosion that starts at
      500.0 and decays with a 15-step time constant, noise sigma 10;
    * afterwards: gradual recovery, ``1.0 + 20.0 * exp(-t / 300)`` with
      noise sigma 0.2, where ``t`` counts steps since the explosion ended.

    Every phase is floored at 0.1, so the result is always a valid
    (positive) norm.
    """
    floor = 0.1

    if step < SPIKE_STEP:
        return max(floor, 1.0 + random.gauss(0, 0.1))

    if step < SPIKE_STEP + 30:
        # Explosion
        peak = 500.0 * math.exp(-(step - SPIKE_STEP) / 15)
        # Gaussian noise is unbounded, so apply the same floor the other
        # phases use; previously this branch could in principle go negative.
        return max(floor, peak + random.gauss(0, 10))

    # Gradual recovery
    steps_after = step - (SPIKE_STEP + 30)
    base = 1.0 + 20.0 * math.exp(-steps_after / 300)
    return max(floor, base + random.gauss(0, 0.2))
49+
50+
# Loss decreases smoothly, spikes at SPIKE_STEP, then recovers
def get_loss(step):
    """Return the synthetic training loss for step ``step``.

    A healthy baseline ``2.5 * exp(-3 * progress) + 0.1`` is computed first,
    where ``progress`` is ``step / SPIKE_STEP`` capped at 1.0, so the
    baseline stops improving once the spike is reached. Then:

    * ``step < SPIKE_STEP``: baseline plus noise, floored at 0.05;
    * the 50 steps starting at ``SPIKE_STEP``: baseline plus a spike that
      starts at 3.5 and decays with a 20-step time constant; the noise is
      added as an absolute value, so it only pushes the loss upward;
    * afterwards: ``(baseline + 0.3) * exp(-t / 800) + 0.15`` plus noise,
      floored at 0.1, where ``t`` counts steps since the spike window ended.
    """
    fraction_done = min(step, SPIKE_STEP) / SPIKE_STEP
    baseline = 2.5 * math.exp(-3 * fraction_done) + 0.1
    # Drawn before branching so every call consumes exactly one random sample.
    jitter = random.gauss(0, 0.02)

    if step < SPIKE_STEP:
        return max(0.05, baseline + jitter)

    spike_end = SPIKE_STEP + 50
    if step < spike_end:
        # Spike
        bump = 3.5 * math.exp(-(step - SPIKE_STEP) / 20)
        return baseline + bump + abs(jitter)

    # Recovery tail settles toward 0.15.
    # NOTE(review): 0.15 is lower than the ~0.22 the baseline has at
    # SPIKE_STEP, so the run eventually ends below its pre-spike loss -
    # confirm that matches the intended "recovers slightly worse" story.
    elapsed = step - spike_end
    recovering = (baseline + 0.3) * math.exp(-elapsed / 800) + 0.15
    return max(0.1, recovering + jitter)
69+
70+
# Start the tracked run. The config records the constants that shape the
# synthetic curves so they are visible next to the logged metrics.
wandb.init(
    project="spike-demo",  # plain literal: the old f-prefix had no placeholders
    name="run-0",
    config=dict(
        total_steps=STEPS,
        spike_step=SPIKE_STEP,
        base_lr=1e-4,
        spike_lr=4e-4,
    ),
)

# Generate and log one point per step for each of the four synthetic metrics.
for step in range(STEPS):
    lr = get_learning_rate(step)
    weight_norm = get_weight_norm(step)
    grad_norm = get_grad_norm(step)
    loss = get_loss(step)

    wandb.log(
        {
            "train/loss": round(loss, 4),
            "train/grad_norm": round(grad_norm, 4),
            "train/weight_norm": round(weight_norm, 4),
            # Not rounded: values around 1e-4 would lose almost all of their
            # precision at four decimal places.
            "train/learning_rate": lr,
        },
        step=step,
    )

wandb.finish()
0 commit comments