Add MXFP8/NVFP4 quantization, quantized model init, collator state, and 70B support #1572

Open

savitha-eng wants to merge 10 commits into main from savitha/lingua-7b-fp8-clean-pr

Conversation

@savitha-eng (Collaborator) commented May 12, 2026

Summary

Builds on top of PR #1500 (jm/mxfp8-nvfp4-llama3) with additional features, CI fixes, and benchmark documentation for the llama3_native_te recipe.

Key changes on top of PR #1500

  • FusedAdam with FP32 master weights: Replaces MixedPrecisionPolicy approach with TE's FusedAdam(master_weight_dtype=torch.float32) for mixed-precision training — simpler, better supported for FP8/MXFP8/NVFP4
  • Quantized model init with preserve_high_precision_init_val: Stores BF16 copies of init values when using te.quantized_model_init, needed for FP32 master weight seeding in FP8 training
  • Unified per-layer init path: get_autocast_context(init=True) now works both standalone (model tests, no outer context) and under an outer te.quantized_model_init context (recipe training) — BF16 layers exit the outer FP8 context via quantized_model_init(enabled=False)
  • Layer-wise precision control: layer_precision config allows per-layer FP8/MXFP8/NVFP4/BF16 assignment (e.g., first/last layer BF16 for stability); see the configuration sketch after this list
  • NVFP4 support: Added NVFP4BlockScaling recipe alongside MXFP8
  • 70B configs: Added Llama-3.1-70B hydra configs with context parallelism and THD input format
  • CI test fixes: Parametrized all FP8 tests across recipes (DelayedScaling, Float8CurrentScaling, Float8BlockScaling, MXFP8BlockScaling) with automatic xfail for unsupported hardware — matching existing codebase patterns
  • Restored is_compileable property: Required by HuggingFace transformers generate() auto-compile check
  • Hydra config cleanup: Renamed 7b8b configs, removed experiment configs, restored pytest markers
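The sketch below shows how these pieces fit together. It is illustrative only: build_model, recipe_for, and the example layer_precision values are hypothetical, and NVFP4BlockScaling assumes a TE build that ships the NVFP4 recipe; it is not the recipe's actual code.

```python
# Minimal sketch, not the recipe's actual code: recipe_for, build_model, and the
# layer_precision values are illustrative.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.optimizers import FusedAdam

def recipe_for(precision: str):
    """Map a per-layer precision string to a TE quantization recipe (None = BF16)."""
    if precision == "mxfp8":
        return recipe.MXFP8BlockScaling()
    if precision == "nvfp4":
        return recipe.NVFP4BlockScaling()  # assumes a TE build that provides the NVFP4 recipe
    if precision == "fp8":
        return recipe.DelayedScaling()
    return None

# Example: first and last layer in BF16 for stability, MXFP8 everywhere else.
layer_precision = ["bf16"] + ["mxfp8"] * 30 + ["bf16"]

# Quantized model init: quantized layers keep a BF16 copy of their init values
# (preserve_high_precision_init_val) so FP32 master weights can be seeded later.
with te.quantized_model_init(enabled=True, preserve_high_precision_init_val=True):
    model = build_model(layer_precision)  # hypothetical model builder

# FP32 master weights live in the optimizer (per the PR description) rather than
# in a MixedPrecisionPolicy.
optimizer = FusedAdam(model.parameters(), lr=3e-4, master_weight_dtype=torch.float32)

def run_layers(layers, hidden):
    """Per-layer forward: each block runs under its own recipe, or plain BF16."""
    for layer, precision in zip(layers, layer_precision):
        r = recipe_for(precision)
        if r is None:
            hidden = layer(hidden)
        else:
            with te.fp8_autocast(enabled=True, fp8_recipe=r):
                hidden = layer(hidden)
    return hidden
```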

MXFP8 Performance Benchmarks

Headline: MXFP8 vs BF16 throughput uplift (single B300 node)

[chart: MXFP8 throughput uplift on 8B vs 70B]

Key findings:

  • Single-node: MXFP8 over BF16 gives ~30% throughput uplift on both 8B and 70B. Quantized model init (qinit) adds ~0.8 pp on 8B but +9.7 pp on 70B — the per-layer quantize/dequantize work saved by qinit scales with depth (80 vs 32 layers). On 70B, MXFP8 + qinit delivers +38.4% throughput gain over BF16 on a single B300 node.
  • Multi-node 8B (8 nodes / 64× B200): MXFP8 + qinit reaches 22,517 tokens/s/GPU vs 17,644 BF16 — +27.6% throughput (×1.28 speedup, −21.7% step time); the note after this list shows how these figures relate.
  • Multi-node 70B (4 nodes / 32× B200): MXFP8 + qinit reaches 2,725 tokens/s/GPU vs 1,972 BF16 — +38.2% throughput (×1.40 speedup, −27.6% step time). The larger relative gain on 70B vs 8B at scale matches the size-dependent qinit pattern from single-node.
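For reference, the multi-node speedup, throughput, and step-time figures are consistent views of the same ratio: 22,517 / 17,644 ≈ 1.28 (+27.6% throughput, with step time dropping by 1 − 1/1.28 ≈ 21.7%), and 2,725 / 1,972 ≈ 1.40 (+38.2% throughput, 1 − 1/1.40 ≈ 27.6% lower step time).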
Single-node detail: per-model 3-way comparisons

Llama-3.1-8B (1 node / 8× B300 SXM6 AC, mbs=4, gbs=32 seqs / 262k tokens, seq_len=8192):

[chart: 8B single-node throughput comparison]

MXFP8 + qinit (+31.1%) and MXFP8 without qinit (+30.4%) deliver essentially the same throughput gain — at 32 layers the per-layer quantize/dequantize saving is small.

Llama-3.1-70B (1 node / 8× B300 SXM6 AC, mbs=1, cp=2, dp=4, gbs=4 seqs, seq_len=8192):

[chart: 70B single-node throughput comparison]

MXFP8 + qinit (+39.4%) pulls ahead of MXFP8 without qinit (+28.7%) — a ~10 pp gap that doesn't appear at 8B. With 80 transformer layers, avoiding per-step quantize/dequantize adds up. preserve_high_precision_init_val=True (HPIV) is within 1% of qinit-without-HPIV, so HPIV is essentially free at steady state.

Multi-node throughput (B200, production-scale runs)

Llama-3.1-8B (8 nodes / 64× B200, mbs=2, grad_acc=2, gbs=256, seq_len=8192):

[chart: 8B multi-node throughput comparison]

MXFP8 + qinit: 22,517 tokens/s/GPU vs 17,644 BF16 — +27.6% throughput (×1.28 speedup)

Llama-3.1-70B (4 nodes / 32× B200, cp=2, dp=16, mbs=1, gbs=16, seq_len=8192):

[chart: 70B multi-node throughput comparison]

MXFP8 + qinit: 2,725 tokens/s/GPU vs 1,972 BF16 — +38.2% throughput (×1.40 speedup)

Wandb run links

Test plan

  • All existing model-level tests pass (parametrized across DelayedScaling, Float8CurrentScaling, Float8BlockScaling, MXFP8BlockScaling with xfail for unsupported hardware; a sketch of this pattern follows this list)
  • All existing recipe-level tests pass (same parametrization pattern)
  • test_quantized_model_init.py — 4 tests × 4 recipes = 16 test cases (8 pass on L4, 8 xfail for Hopper/Blackwell-only recipes)
  • check_copied_files.py passes — all 3 modeling_llama_te.py copies are identical
  • Pre-commit hooks pass
  • Single-node MXFP8 training verified on Blackwell (benchmarked, see above)
  • Multi-node training verified on B200 cluster (benchmarked, see above)
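A minimal sketch of the recipe parametrization pattern mentioned above, assuming pytest and the TE recipe classes named in this PR; the minimum-SM table and test body are illustrative, not the repo's conftest.py:

```python
# Illustrative sketch of the recipe x hardware parametrization, not the actual
# conftest.py; the MIN_SM values are assumptions based on the PR text.
import pytest
import torch
from transformer_engine.common import recipe

RECIPES = {
    "delayed": recipe.DelayedScaling,
    "current": recipe.Float8CurrentScaling,
    "block": recipe.Float8BlockScaling,   # needs sm90+ (Hopper)
    "mxfp8": recipe.MXFP8BlockScaling,    # needs Blackwell-class hardware
}
MIN_SM = {"delayed": 89, "current": 89, "block": 90, "mxfp8": 100}

def sm_version() -> int:
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor

@pytest.mark.parametrize("name", sorted(RECIPES))
def test_fp8_recipe(name):
    if sm_version() < MIN_SM[name]:
        pytest.xfail(f"{name} recipe is not supported on sm{sm_version()}")
    fp8_recipe = RECIPES[name]()
    assert fp8_recipe is not None
    # The real tests build a small TE model and run fwd/bwd under
    # te.fp8_autocast(fp8_recipe=fp8_recipe).
```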

Type of changes

  • New feature (non-breaking change which adds functionality)

CI Pipeline Configuration

Note

By default, only basic unit tests are run. Add appropriate labels to enable additional test coverage.

Add MXFP8/NVFP4 quantization, quantized model init, collator state, and 70B support

- Add quantization.py with layer-wise precision (MXFP8/NVFP4/BF16 per layer)
- Add quantized_model_init support with FusedAdam FP32 master weights
- Add stateful TokenPackingDataset for checkpoint resume across all collators
- Add 70B Llama configs with context parallelism and THD format
- Add checkpoint.py with FSDP2 save/load utilities
- Update modeling_llama_te.py with per-layer FP8 recipe injection
- Update train_fsdp2.py and train_fsdp2_cp.py for quantized training
- Add comprehensive tests for quantization and checkpoint resume

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@copy-pr-bot (Bot) commented May 12, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai (Contributor, Bot) commented May 12, 2026

Important

Review skipped

Auto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.


…emove dead code

Simplify the config parameter name since FusedAdam is now the only FP32 master
weights strategy. Also remove a stale guard in train_ddp.py and a duplicate
config line in L2_lingua_7b_pure_bf16.yaml.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@savitha-eng (Collaborator, Author)

/ok to test 3b71e25


Adding state_dict/load_state_dict causes StatefulDataLoader to switch
from fast-forward replay to stateful restore, which produces incorrect
batches because the packing generator state is not serializable.
Fast-forward replay works correctly with the deterministic shuffle seed.

Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
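For context, a minimal illustration of the fast-forward resume behavior described in the commit above, assuming torchdata's StatefulDataLoader; PackedStream stands in for the repo's TokenPackingDataset and is purely illustrative:

```python
# Illustration only: because the dataset defines no state_dict/load_state_dict,
# StatefulDataLoader resumes by fast-forwarding, i.e. replaying the deterministic
# stream and skipping already-consumed batches.
import torch
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

class PackedStream(IterableDataset):
    """Deterministic token stream; determinism is what makes fast-forward exact."""
    def __init__(self, num_samples: int, seed: int = 0):
        self.num_samples, self.seed = num_samples, seed

    def __iter__(self):
        g = torch.Generator().manual_seed(self.seed)  # same seed => same order on replay
        for _ in range(self.num_samples):
            yield torch.randint(0, 32000, (8,), generator=g)

loader = StatefulDataLoader(PackedStream(100), batch_size=4)
it = iter(loader)
first, second = next(it), next(it)

state = loader.state_dict()               # records how many batches were consumed
resumed = StatefulDataLoader(PackedStream(100), batch_size=4)
resumed.load_state_dict(state)            # fast-forwards past the first two batches
third = next(iter(resumed))
```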
@savitha-eng (Collaborator, Author)

/ok to test f68ff4e

…estore pytest

- Rename L2_lingua_7b → L2_lingua_8b (model is Llama-3.1-8B)
- Rename L2_lingua_7b_mxfp8_qinit → L2_lingua_8b_mxfp8_qinit
- Keep only 4 example configs: 8b base, 8b mxfp8+qinit, 70b base, 70b mxfp8+qinit
- Remove 11 experiment-specific configs (bf16_baseline, fp8, mxfp8, thd, cp4, etc.)
- Restore pytest to requirements.txt (needed by CI runner)

Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
@savitha-eng (Collaborator, Author)

/ok to test a611fb4

- Restore is_compileable property on HFInferenceParams (accidentally
  dropped from PR 1500), required by newer transformers generate().
- Unify get_autocast_context init path to work both standalone (model
  tests, no outer context) and with outer quantized_model_init (recipe
  training). FP8/FP4 layers use per-layer quantized_model_init with
  preserve_high_precision_init_val=True; BF16 layers use
  quantized_model_init(enabled=False) to override any outer context.

Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
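A rough sketch of the per-layer init choice this commit describes, using the TE context manager named above; init_context_for is a hypothetical stand-in for the repo's get_autocast_context:

```python
# Sketch only: mirrors the per-layer init selection described above, not the
# repo's get_autocast_context implementation.
import transformer_engine.pytorch as te

def init_context_for(precision: str):
    if precision in ("fp8", "mxfp8", "nvfp4"):
        # Quantized layer: per-layer quantized init, keeping a BF16 copy of the
        # init values for later FP32 master-weight seeding.
        return te.quantized_model_init(enabled=True,
                                       preserve_high_precision_init_val=True)
    # BF16 layer: explicitly disable quantized init so the layer is built in
    # high precision even inside an outer te.quantized_model_init context.
    return te.quantized_model_init(enabled=False)
```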
@savitha-eng (Collaborator, Author)

/ok to test 0bfaceb

…h xfail

Match the pattern used by model-level tests and conftest.py: parametrize
across DelayedScaling, Float8CurrentScaling, Float8BlockScaling, and
MXFP8BlockScaling with automatic xfail for unsupported hardware.
Previously the tests hardcoded Float8BlockScaling, which requires sm90+ (Hopper),
but CI runs on L4 (sm89).

Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
savitha-eng force-pushed the savitha/lingua-7b-fp8-clean-pr branch from 71a24ee to 0134793 on May 14, 2026 05:40
@savitha-eng (Collaborator, Author)

/ok to test 0134793

Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
Recover llama3_8gpu_tflops.png from PR #1500 branch and copy
lingua-1b-loss-curve.png and lingua-1b-step-time.png from
images/recipes/ to images/llama3/ to match README paths.

Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
- Remove _log_per_layer_gradient_norms and gradient_debug config (debug-only)
- Remove _log_memory helper and all call sites
- Remove no-op MixedPrecisionPolicy (FP32 master weights handled by FusedAdam)
- Remove ValueError catch for StatefulDataLoader fast-forward workaround

Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
@savitha-eng (Collaborator, Author)

/ok to test c4a8068

Add single-node and multi-node MXFP8 vs BF16 throughput comparisons
for Llama-3.1-8B and 70B with quantized model init analysis and W&B
run links.

Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
@savitha-eng (Collaborator, Author)

/ok to test da14eaa

savitha-eng marked this pull request as ready for review on May 14, 2026 20:02