
Experiment: Try out different remat strategies to get the 70b working on fewer slices #906

@dlwh

Description

The 70b run just barely doesn't fit on 4x v6e-128 with a 12M batch size. If we use a checkpointed scan or CPU offload we should get higher throughput. Let's figure out which.

(Right now we can only get 4 slices.)

Hypothesis or Goal

Find the fastest remat strategy that fits.

Results

1.4b

  • CPU offload crashes with an internal assert; not going to worry about it for now.
  • Nested remat (two levels of scan), where you remat inside each block, i.e. scan(remat(scan(remat(f), n=block_size)), n=layers/block_size), gives a 20% MFU hit unless you set things so that it functionally does nothing (fold16 unroll).
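For concreteness, here's a minimal sketch of the two-level structure in plain JAX. This is illustrative, not Levanter's actual API: block is a toy stand-in for one transformer layer, and the grouping/reshape convention is an assumption.

```python
import jax
import jax.numpy as jnp


def block(x, w):
    # Toy stand-in for one transformer layer.
    return jnp.tanh(x @ w)


def nested_remat_scan(x, stacked_w, block_size):
    n_layers = stacked_w.shape[0]
    # Group the layer stack into (n_layers // block_size) blocks of block_size layers.
    grouped = stacked_w.reshape(
        n_layers // block_size, block_size, *stacked_w.shape[1:]
    )

    @jax.checkpoint  # outer remat: recompute each block-group in the backward pass
    def outer_body(carry, group_w):
        # Inner scan over the layers of one group, each layer itself rematted.
        inner_body = jax.checkpoint(lambda c, w: (block(c, w), None))
        carry, _ = jax.lax.scan(inner_body, carry, group_w)
        return carry, None

    out, _ = jax.lax.scan(outer_body, x, grouped)
    return out
```

The forward result is identical to a plain single scan over all layers; only the backward-pass recomputation schedule differs.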

Thinking about it now, 20% is more or less exactly what it should be. A normal forward+backward is ~3 * Layers * F, where F is the number of flops to do the forward pass in one block. If you remat each block (our normal strategy) you do one additional forward pass per block, giving ~4 * Layers * F. With nested remat, where you also remat the innermost f, the forward pass runs one more time, giving ~5 * Layers * F, i.e. 5/4 the flops of single-level remat, which works out to a ~20% MFU hit.
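The accounting above can be checked in a few lines of arithmetic (assuming the usual rule of thumb that backward costs ~2x forward, which is where the 3 * Layers * F baseline comes from):

```python
# Per-layer flop counts, in units of F (forward flops for one block).
forward = 1
backward = 2                              # backward ~= 2x forward
no_remat = forward + backward             # 3F: plain forward + backward
block_remat = no_remat + 1 * forward      # 4F: one extra forward per block
nested_remat = no_remat + 2 * forward     # 5F: forward recomputed twice

# Fraction of throughput lost going from single-level to nested remat.
mfu_hit = 1 - block_remat / nested_remat
print(f"{mfu_hit:.0%}")  # prints "20%"
```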
