Related to #1295
Description
We have been dealing with loss spikes in Marin 32B for a while now. Unfortunately, the issue is still not fully fixed: in the most recent 2.5k steps, we spent ~130 steps in or recovering from spikes, which accounts for approximately 5% of the data. We are also concerned that the spikes may have done hidden damage to the model's capabilities that is not reflected in the intermediate training loss.
Here is a summary of our current effort and recent plans to understand and further mitigate these spikes.
Current Observations
- Adam update norm spikes (2x) precede loss spikes, but not every update norm spike leads to a loss spike.
- Gradient norms typically spike after update norm spikes.
- We already skip steps with extremely large gradient norms, but this does not reduce the time to recovery.
- Some more fine-grained preliminary observations:
  - The update norm spikes are of higher magnitude in the lower layers than in the upper layers.
  - When the update norm spikes, the second-order momentum remains mostly unchanged, but the first-order momentum increases by 2x.
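As a sketch of how one might flag the 2x update norm spikes described above (the window size and threshold here are illustrative choices, not the monitoring code used in the actual run):

```python
import numpy as np

def spike_flags(update_norms, window=50, threshold=2.0):
    """Flag steps whose update norm exceeds `threshold` times the trailing median.

    `update_norms` is a 1-D sequence of per-step Adam update norms for one
    layer (or the whole model). Window and threshold are illustrative.
    """
    norms = np.asarray(update_norms, dtype=float)
    flags = np.zeros(len(norms), dtype=bool)
    for t in range(window, len(norms)):
        # Compare against a trailing median so a single spike does not
        # contaminate its own baseline.
        baseline = np.median(norms[t - window:t])
        flags[t] = norms[t] > threshold * baseline
    return flags

# Example: a flat norm trace with a 2.5x spike at step 80.
trace = np.ones(100)
trace[80] = 2.5
print(np.flatnonzero(spike_flags(trace)))  # only step 80 is flagged
```

Using a trailing median (rather than a mean) keeps the baseline robust when a spike enters the window, so subsequent normal steps are not falsely flagged.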
Current intuition
Because the update norm spikes precede both the loss and gradient spikes, they are likely not caused by a single huge gradient.
We therefore conjecture that for a few consecutive steps before an update norm spike, the gradients are more aligned and have a higher signal-to-noise ratio. This makes the first-order momentum much larger than in other steps and leads to the high update norm.
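As a toy illustration of this conjecture (not the actual 32B optimizer state; betas are illustrative and bias correction is omitted), the numpy sketch below feeds an Adam-style moment tracker a stream of noisy gradients followed by a short run of aligned ones. The first moment grows roughly 2x while the second moment barely moves, inflating the update norm:

```python
import numpy as np

def adam_update_norms(grads, beta1=0.9, beta2=0.95, eps=1e-8):
    """Track the Adam update norm ||m / (sqrt(v) + eps)|| over a gradient stream.

    Bias correction is omitted for simplicity; betas are illustrative,
    not the values from the actual run.
    """
    m = np.zeros_like(grads[0])
    v = np.zeros_like(grads[0])
    norms = []
    for g in grads:
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g**2
        norms.append(np.linalg.norm(m / (np.sqrt(v) + eps)))
    return norms

rng = np.random.default_rng(0)
d = 1024
direction = rng.normal(size=d)
direction /= np.linalg.norm(direction)

# Phase 1: weak signal buried in noise (low signal-to-noise ratio).
noisy = [0.1 * direction + rng.normal(size=d) for _ in range(200)]
# Phase 2: gradients become strongly aligned for a few consecutive steps.
aligned = [20.0 * direction + rng.normal(size=d) for _ in range(10)]

norms = adam_update_norms(noisy + aligned)
print(f"update norm before alignment: {norms[199]:.2f}")
print(f"update norm after 10 aligned steps: {norms[-1]:.2f}")
```

Because the per-coordinate signal is small relative to the noise, v (which tracks squared gradients) hardly changes during the aligned phase, while m accumulates the shared direction, so the update norm roughly doubles — matching the moment behavior observed above.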
TODO
Understanding the cause and effect of spikes
#### Potential Mitigation
Hypothesis or Goal
Stabilize the 32B and understand why it destabilizes
Results
See #1390, #1395, and https://wandb.ai/marin-community/marin/reports/Marin-32B-Work-In-Progress--VmlldzoxMzM1Mzk1NQ