
Experiment: Debug 32B Spiking #1368

@dlwh

Description

Related to #1295

We have been dealing with loss spikes in Marin 32B for a while now. Unfortunately, the issue is still not fully fixed. In the most recent 2.5k steps, we spent ~130 steps in or recovering from spikes, which accounts for approximately 5% of the data. We are also concerned that the spikes may have done hidden damage to the model's capabilities that is not reflected in the intermediate training loss.

Here is a summary of our current efforts and recent plans to understand and further mitigate these spikes.

Current Observations

  • Adam update norm spikes (~2x) precede loss spikes, but not every update norm spike leads to a loss spike.
  • Gradient norms typically spike after update norm spikes.
  • We already skip steps with extremely large gradient norms, but this does not reduce the time to recovery.
  • Some more fine-grained preliminary observations (a logging sketch for these diagnostics follows this list):
    • The update norm spikes are of higher magnitude in the lower layers than in the upper layers.
    • When the update norm spikes, the second-order momentum remains mostly unchanged, but the first-order momentum increases by ~2x.
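As a concrete reference for the quantities above, here is a minimal sketch (assuming optax's Adam; `train_step` and the metric names are illustrative, not Marin's actual training loop) that logs the gradient norm, update norm, and the norms of Adam's first and second moments each step:

```python
import optax

# Assumes optax's Adam, whose chained state begins with
# ScaleByAdamState(count, mu, nu).
optimizer = optax.adam(learning_rate=3e-4)

def train_step(params, opt_state, grads):
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    adam_state = opt_state[0]  # ScaleByAdamState
    metrics = {
        "grad_norm": optax.global_norm(grads),
        "update_norm": optax.global_norm(updates),
        "mu_norm": optax.global_norm(adam_state.mu),  # first-order momentum
        "nu_norm": optax.global_norm(adam_state.nu),  # second-order momentum
    }
    return params, opt_state, metrics
```

Per-layer versions of the same norms (relevant to the lower-vs-upper-layer observation) follow from replacing `optax.global_norm` with a `jax.tree_util.tree_map` over per-parameter norms.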

Current Intuition

Because the update norm spikes precede both the loss and gradient spikes, the spikes are likely not caused by a single huge gradient.
We therefore conjecture that, for a few consecutive steps before an update norm spike, the gradients are more aligned and have a higher signal-to-noise ratio. Under Adam, the first moment is m_t = beta1 * m_{t-1} + (1 - beta1) * g_t, so a run of aligned gradients lets the momentum accumulate toward the full gradient norm instead of averaging noise away; this makes the first-order momentum much larger than at other steps and leads to the high update norm.
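One cheap way to probe this conjecture (a sketch of our own, not something already implemented): log the cosine similarity between consecutive flattened gradients. A sustained run of high similarity just before an update norm spike would support the alignment story.

```python
import jax
import jax.numpy as jnp

def _flatten(tree):
    # Concatenate all parameters' gradients into one flat vector.
    return jnp.concatenate([x.ravel() for x in jax.tree_util.tree_leaves(tree)])

def grad_cosine_similarity(prev_grads, grads):
    a, b = _flatten(prev_grads), _flatten(grads)
    return jnp.vdot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + 1e-12)
```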

TODO

Understanding the cause and effect of spikes

  • Replay/ablate batches from just before an update norm blow-up, to localize the spike to a few batches; inspect both the examples themselves and how the optimizer state evolves [@dlwh]
  • Run a cooldown before and after one spike to see how spikes actually impact downstream performance [@RohithKuditipudi]
  • Add a tracker for dead neurons (Track "dead neurons" #1367); a minimal sketch follows this list
  • Look at activations around spike steps (details TBD)
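For the dead neuron tracker (#1367), here is a minimal sketch under the usual definition that a unit is dead when its activation is (near-)zero over an entire batch; `acts` is a hypothetical [batch, seq, hidden] activation tensor from one layer:

```python
import jax.numpy as jnp

def dead_neuron_fraction(acts, eps=1e-8):
    # A hidden unit counts as dead if its activation magnitude never
    # exceeds eps anywhere in the batch.
    max_abs = jnp.max(jnp.abs(acts), axis=(0, 1))  # [hidden]
    return jnp.mean(max_abs < eps)
```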

Potential Mitigations

  • Update norm clipping [@dlwh]; a sketch follows this list
  • Manually decrease the signal-to-noise ratio at these steps by adding noise to the gradients, so that the signal-to-noise ratio has an upper bound [@RohithKuditipudi]
  • Increase the signal-to-noise ratio at every step by leveraging Nesterov momentum, so the optimizer gets used to a high signal-to-noise ratio [@WhenWen]
  • Try different optimizers with potentially higher stability, such as Muon [@dlwh]
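For update norm clipping, here is a minimal optax sketch (our own construction, not the actual Marin optimizer): reuse `clip_by_global_norm`, but place it after `scale_by_adam` so it clips the Adam update rather than the raw gradient, since in our observations the update norm is what spikes first.

```python
import optax

def adam_with_update_clip(learning_rate, max_update_norm):
    return optax.chain(
        optax.scale_by_adam(),
        # Clips the (unit-learning-rate) Adam update, not the gradient.
        optax.clip_by_global_norm(max_update_norm),
        optax.scale(-learning_rate),
    )
```

Because the clip sits before `scale(-learning_rate)`, `max_update_norm` is expressed in units of the unscaled Adam update rather than the final parameter delta.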

Hypothesis or Goal

Stabilize the 32B run and understand why it destabilizes.

Results

See #1390, #1395, and https://wandb.ai/marin-community/marin/reports/Marin-32B-Work-In-Progress--VmlldzoxMzM1Mzk1NQ
