Description
The 70B run just barely doesn't fit on 4x v6e-128 with 12M batch size. If we use checkpointing scan or CPU offload we should get higher throughput. Let's figure out which.
(Right now we can only get 4)
Hypothesis or Goal
Find the fastest remat strategy that fits.
Links
Results
1.4b
- CPU offload crashes with an internal assert. Not going to worry about it for now.
- Nested remat (two levels of scan), where you remat inside each block, i.e.
scan(remat(scan(remat(f), n=layers/block_size)), n=block_size), gives a 20% MFU hit unless you set things so that it functionally does nothing (fold16 unroll).
Thinking about it now, 20% is more or less exactly what it should be. Normal forward-backward is ~3 * Layers * F, where F is the number of flops to do forward in one block. If you remat each block (our normal strategy) you need to do one additional forward pass in each block, so you get ~4 * Layers * F. With nested remat, where you also remat the innermost f, you have to do forward yet another time, giving ~5 * Layers * F. The useful work is unchanged, so throughput scales by 4/5 relative to per-block remat, i.e. a ~20% MFU hit.
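The nested structure above can be sketched in JAX. This is a toy sketch, not our actual stack: `layer`, the shapes, and `block_size` are all illustrative assumptions; the point is just the scan(remat(scan(remat(f)))) shape.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for one transformer block; purely illustrative.
def layer(x, w):
    return jnp.tanh(x @ w), None  # scan body: (new_carry, per-step output)

def nested_remat_forward(x, weights, block_size):
    # Two levels of scan with remat at each level: an inner scan over the
    # layers within a block, an outer scan over the blocks.
    n_layers = weights.shape[0]
    blocks = weights.reshape(n_layers // block_size, block_size, *weights.shape[1:])

    inner = jax.checkpoint(layer)                 # remat each layer

    def block(x, w_block):
        x, _ = jax.lax.scan(inner, x, w_block)    # scan layers within one block
        return x, None

    outer = jax.checkpoint(block)                 # remat each block as well
    x, _ = jax.lax.scan(outer, x, blocks)         # scan over blocks
    return x

# FLOP accounting from the note above (F = fwd flops per block, L = layers):
#   no remat:        ~3*L*F   (fwd + bwd)
#   per-block remat: ~4*L*F   (one extra fwd)
#   nested remat:    ~5*L*F   (two extra fwds) -> throughput * 4/5, ~20% MFU hit
```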