Add MXFP8/NVFP4 quantization, quantized model init, collator state, and 70B support #1572
savitha-eng wants to merge 10 commits into
Conversation
Add MXFP8/NVFP4 quantization, quantized model init, collator state, and 70B support

- Add quantization.py with layer-wise precision (MXFP8/NVFP4/BF16 per layer)
- Add quantized_model_init support with FusedAdam FP32 master weights
- Add stateful TokenPackingDataset for checkpoint resume across all collators
- Add 70B Llama configs with context parallelism and THD format
- Add checkpoint.py with FSDP2 save/load utilities (see the sketch after this message)
- Update modeling_llama_te.py with per-layer FP8 recipe injection
- Update train_fsdp2.py and train_fsdp2_cp.py for quantized training
- Add comprehensive tests for quantization and checkpoint resume

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
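The checkpoint.py utilities referenced above are not shown in this thread; below is a minimal sketch of what FSDP2-friendly save/load looks like with torch.distributed.checkpoint, assuming that is the underlying mechanism. The function names (save_checkpoint/load_checkpoint) are illustrative, not the repo's API.

```python
# Sketch only: FSDP2-compatible distributed checkpointing via torch.distributed.checkpoint.
# The repo's checkpoint.py may structure paths and metadata differently.
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict


def save_checkpoint(model, optimizer, path):
    # Gather the sharded (DTensor) parameters into a checkpoint-friendly state dict.
    model_sd, optim_sd = get_state_dict(model, optimizer)
    dcp.save({"model": model_sd, "optim": optim_sd}, checkpoint_id=path)


def load_checkpoint(model, optimizer, path):
    # Load into a state dict shaped like the current model, then push it back
    # onto the wrapped model and optimizer.
    model_sd, optim_sd = get_state_dict(model, optimizer)
    dcp.load({"model": model_sd, "optim": optim_sd}, checkpoint_id=path)
    set_state_dict(model, optimizer, model_state_dict=model_sd, optim_state_dict=optim_sd)
```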
…emove dead code

Simplify the config parameter name since FusedAdam is now the only FP32 master weights strategy. Also remove a stale guard in train_ddp.py and a duplicate config line in L2_lingua_7b_pure_bf16.yaml.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
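For context on "FusedAdam is the only FP32 master weights strategy", a hedged sketch of the optimizer setup follows: master_weight_dtype is the parameter named in this PR, while the master_weights flag, the stand-in module, and the hyperparameter values are assumptions for illustration.

```python
# Sketch: BF16 model parameters, FP32 master weight copies owned by TE's FusedAdam.
import torch
from transformer_engine.pytorch.optimizers import FusedAdam

model = torch.nn.Linear(32, 32).to(device="cuda", dtype=torch.bfloat16)  # stand-in module
optimizer = FusedAdam(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.95),
    weight_decay=0.1,
    master_weights=True,                 # keep FP32 master copies inside the optimizer
    master_weight_dtype=torch.float32,   # parameter named in this PR
)
```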
/ok to test 3b71e25
Adding state_dict/load_state_dict causes StatefulDataLoader to switch from fast-forward replay to stateful restore, which produces incorrect batches because the packing generator state is not serializable. Fast-forward replay works correctly with the deterministic shuffle seed. Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
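A toy illustration of the fast-forward behaviour described above (not the repo's TokenPackingDataset): because the dataset exposes no state_dict/load_state_dict, StatefulDataLoader resumes by replaying the iterator and skipping already-consumed batches, which is safe only because the packing shuffle is fully determined by the seed.

```python
import torch
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader


class PackedTokenDataset(IterableDataset):
    """Toy packing dataset: yields fixed-length token blocks from shuffled documents."""

    def __init__(self, documents, seq_len, seed=0):
        self.documents = documents   # list of token-id lists
        self.seq_len = seq_len
        self.seed = seed

    def __iter__(self):
        rng = torch.Generator().manual_seed(self.seed)  # deterministic shuffle
        order = torch.randperm(len(self.documents), generator=rng).tolist()
        buffer = []
        for idx in order:
            buffer.extend(self.documents[idx])
            while len(buffer) >= self.seq_len:
                yield torch.tensor(buffer[: self.seq_len])
                buffer = buffer[self.seq_len:]

    # No state_dict()/load_state_dict() on purpose: the packing generator state is not
    # serializable, so StatefulDataLoader falls back to fast-forward replay on resume.


docs = [[i] * 10 for i in range(100)]   # dummy "tokenized documents"
loader = StatefulDataLoader(PackedTokenDataset(docs, seq_len=16), batch_size=4)
state = loader.state_dict()             # checkpointed alongside the model
```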
/ok to test f68ff4e
…estore pytest

- Rename L2_lingua_7b → L2_lingua_8b (model is Llama-3.1-8B)
- Rename L2_lingua_7b_mxfp8_qinit → L2_lingua_8b_mxfp8_qinit
- Keep only 4 example configs: 8b base, 8b mxfp8+qinit, 70b base, 70b mxfp8+qinit
- Remove 11 experiment-specific configs (bf16_baseline, fp8, mxfp8, thd, cp4, etc.)
- Restore pytest to requirements.txt (needed by CI runner)

Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
/ok to test a611fb4
- Restore is_compileable property on HFInferenceParams (accidentally dropped from PR 1500), required by newer transformers generate().
- Unify get_autocast_context init path to work both standalone (model tests, no outer context) and with outer quantized_model_init (recipe training). FP8/FP4 layers use per-layer quantized_model_init with preserve_high_precision_init_val=True; BF16 layers use quantized_model_init(enabled=False) to override any outer context (sketched below).

Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
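A hedged sketch of the per-layer init path from the second bullet: the helper name and the recipe keyword are illustrative; only enabled and preserve_high_precision_init_val are taken from this PR.

```python
# Sketch: quantized layers use quantized_model_init with preserved BF16 init values;
# BF16 layers explicitly disable it so an outer recipe-training context can't quantize them.
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling


def init_context(layer_precision: str):
    """Pick the weight-init context for a single transformer layer."""
    if layer_precision == "bf16":
        return te.quantized_model_init(enabled=False)   # override any outer FP8 context
    return te.quantized_model_init(
        enabled=True,
        recipe=MXFP8BlockScaling(),                     # per-layer recipe (assumed keyword)
        preserve_high_precision_init_val=True,          # BF16 copies for FP32 master seeding
    )


with init_context("mxfp8"):
    layer = te.TransformerLayer(
        hidden_size=4096, ffn_hidden_size=14336, num_attention_heads=32
    )
```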
/ok to test 0bfaceb
…h xfail

Match the pattern used by model-level tests and conftest.py: parametrize across DelayedScaling, Float8CurrentScaling, Float8BlockScaling, and MXFP8BlockScaling with automatic xfail for unsupported hardware. Previously hardcoded Float8BlockScaling, which requires sm90+ (Hopper), but CI runs on L4 (sm89).

Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
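Roughly what that parametrization looks like, as a sketch rather than the repo's conftest.py; the compute-capability thresholds are assumptions based on the hardware notes above.

```python
import pytest
import torch
from transformer_engine.common import recipe

_SM = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)


def _xfail_below(min_sm, r):
    # Wrap a recipe in pytest.param with an xfail mark when the GPU is too old for it.
    return pytest.param(
        r,
        id=type(r).__name__,
        marks=pytest.mark.xfail(_SM < min_sm, reason=f"requires sm{min_sm[0]}{min_sm[1]}+"),
    )


RECIPES = [
    _xfail_below((8, 9), recipe.DelayedScaling()),
    _xfail_below((8, 9), recipe.Float8CurrentScaling()),
    _xfail_below((9, 0), recipe.Float8BlockScaling()),    # Hopper+
    _xfail_below((10, 0), recipe.MXFP8BlockScaling()),    # Blackwell+
]


@pytest.mark.parametrize("fp8_recipe", RECIPES)
def test_quantized_init_smoke(fp8_recipe):
    # Placeholder body; the real tests build a model under quantized_model_init
    # with this recipe and exercise forward/backward.
    assert fp8_recipe is not None
```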
Force-pushed from 71a24ee to 0134793
/ok to test 0134793
Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
Recover llama3_8gpu_tflops.png from PR #1500 branch and copy lingua-1b-loss-curve.png and lingua-1b-step-time.png from images/recipes/ to images/llama3/ to match README paths. Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
- Remove _log_per_layer_gradient_norms and gradient_debug config (debug-only)
- Remove _log_memory helper and all call sites
- Remove no-op MixedPrecisionPolicy (FP32 master weights handled by FusedAdam)
- Remove ValueError catch for StatefulDataLoader fast-forward workaround

Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
/ok to test c4a8068
Add single-node and multi-node MXFP8 vs BF16 throughput comparisons for Llama-3.1-8B and 70B with quantized model init analysis and W&B run links. Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
/ok to test da14eaa
Summary
Builds on top of PR #1500 (jm/mxfp8-nvfp4-llama3) with additional features, CI fixes, and benchmark documentation for the llama3_native_te recipe.

Key changes on top of PR #1500

- Replaced the MixedPrecisionPolicy approach with TE's FusedAdam(master_weight_dtype=torch.float32) for mixed-precision training — simpler, better supported for FP8/MXFP8/NVFP4
- preserve_high_precision_init_val: Stores BF16 copies of init values when using te.quantized_model_init, needed for FP32 master weight seeding in FP8 training
- get_autocast_context(init=True) now works both standalone (model tests, no outer context) and under an outer te.quantized_model_init context (recipe training) — BF16 layers exit the outer FP8 context via quantized_model_init(enabled=False)
- layer_precision config allows per-layer FP8/MXFP8/NVFP4/BF16 assignment (e.g., first/last layer BF16 for stability); see the sketch after this list
- Added NVFP4BlockScaling recipe alongside MXFP8
- Recipe tests parametrized with xfail for unsupported hardware — matching existing codebase patterns
- is_compileable property: Required by HuggingFace transformers generate() auto-compile check
- Renamed 7b → 8b configs, removed experiment configs, restored pytest markers
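A hedged sketch of the layer_precision idea from the list above; the helper name make_layer_recipes and the plan shape are illustrative, only the recipe class names come from this PR.

```python
# Sketch of a per-layer precision plan: first/last layers stay BF16 for stability,
# the rest use a quantized recipe. Not the repo's quantization.py API.
from transformer_engine.common.recipe import MXFP8BlockScaling, NVFP4BlockScaling

RECIPES = {"mxfp8": MXFP8BlockScaling, "nvfp4": NVFP4BlockScaling}


def make_layer_recipes(num_layers: int, default: str = "mxfp8"):
    """Return one entry per layer: a TE recipe instance, or None to keep the layer BF16."""
    plan = []
    for i in range(num_layers):
        if i in (0, num_layers - 1):
            plan.append(None)                 # BF16 first/last layer
        else:
            plan.append(RECIPES[default]())   # quantized layer
    return plan


layer_recipes = make_layer_recipes(32)        # e.g. Llama-3.1-8B has 32 layers
```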
MXFP8 Performance Benchmarks

Headline: MXFP8 vs BF16 throughput uplift (single B300 node)
Key findings:
Quantized model init (qinit) adds ~0.8 pp on 8B but +9.7 pp on 70B — the per-layer quantize/dequantize work saved by qinit scales with depth (80 vs 32 layers). On 70B, MXFP8 + qinit delivers a +38.4% throughput gain over BF16 on a single B300 node.

Single-node detail: per-model 3-way comparisons
Llama-3.1-8B (1 node / 8× B300 SXM6 AC, mbs=4, gbs=32 seqs / 262k tokens, seq_len=8192):
MXFP8 + qinit (+31.1%) and MXFP8 without qinit (+30.4%) deliver essentially the same throughput gain — at 32 layers the per-layer quantize/dequantize saving is small.
Llama-3.1-70B (1 node / 8× B300 SXM6 AC, mbs=1, cp=2, dp=4, gbs=4 seqs, seq_len=8192):
MXFP8 + qinit (+39.4%) pulls ahead of MXFP8 without qinit (+28.7%) — a ~10 pp gap that doesn't appear at 8B. With 80 transformer layers, avoiding per-step quantize/dequantize adds up.
qinit with preserve_high_precision_init_val=True (HPIV) is within 1% of qinit without HPIV, so HPIV is essentially free at steady state.

Multi-node throughput (B200, production-scale runs)
Llama-3.1-8B (8 nodes / 64× B200, mbs=2, grad_acc=2, gbs=256, seq_len=8192):
MXFP8 + qinit: 22,517 tokens/s/GPU vs 17,644 BF16 — +27.6% throughput (×1.28 speedup)
Llama-3.1-70B (4 nodes / 32× B200, cp=2, dp=16, mbs=1, gbs=16, seq_len=8192):
MXFP8 + qinit: 2,725 tokens/s/GPU vs 1,972 BF16 — +38.2% throughput (×1.38 speedup)
Wandb run links
Test plan
- test_quantized_model_init.py — 4 tests × 4 recipes = 16 test cases (8 pass on L4, 8 xfail for Hopper/Blackwell-only recipes)
- check_copied_files.py passes — all 3 modeling_llama_te.py copies are identical

Type of changes
CI Pipeline Configuration
Note
By default, only basic unit tests are run. Add appropriate labels to enable additional test coverage.