Skip to content

Add V5 leaf-extraction QuadratureSHAP — faster than TreeSHAP at every depth#1

Open
yupbank wants to merge 1 commit intoshapley-value-algorithmsfrom
v5-leaf-extraction
Open

Add V5 leaf-extraction QuadratureSHAP — faster than TreeSHAP at every depth#1
yupbank wants to merge 1 commit intoshapley-value-algorithmsfrom
v5-leaf-extraction

Conversation

@yupbank
Copy link
Copy Markdown
Owner

@yupbank yupbank commented Mar 26, 2026

Summary

This adds a third SHAP algorithm option (v5shap) that addresses the small-tree regression found in dmlc/xgboost#12106. V5 (leaf-extraction) is 1.13–1.85x faster than Lundberg's TreeSHAP at every tree depth tested, including the small/sparse trees where QuadratureSHAP (V6-style edge-telescoping) loses.

Motivation

QuadratureSHAP in PR dmlc#12106 showed regressions on small trees:

  • breast_cancer (37 nodes): 0.42x vs TreeSHAP
  • synth depth=4 (30 nodes): 0.65x vs TreeSHAP

Root cause: fixed O(Q=8) per-node quadrature overhead exceeds TreeSHAP's O(D²) cost when D is small.

What V5 does differently

  1. Leaf-extraction — defers SHAP extraction to leaves using precomputed subtree feature masks (subtree_feats[node × n_features]). Tracks "pending" features along root→leaf paths, extracts only when features leave the path. Avoids redundant extract_term at internal nodes.

  2. Dynamic Q = max(⌈depth/2⌉, 2) — adapts quadrature points to tree depth. Depth-4 trees use Q=2 (2 FMAs/node), matching TreeSHAP's O(D=4) loop cost. Deep trees ramp to Q=8+.

  3. float32 + precomputed masks — all H-buffer work in float32 (matching TreeSHAP's precision). Subtree masks computed once per tree (not per sample), eliminating the main per-call overhead.

Benchmark results

50 trees, 256 test samples, nthread=1, synthetic 50-feature data + breast_cancer:

Workload Mean Nodes TreeSHAP V5SHAP V5/TreeSHAP QuadSHAP/TreeSHAP
synth d=3 15 3.00ms 2.19ms 1.37x ~0.65x
synth d=4 29 6.11ms 4.62ms 1.32x ~0.53x
synth d=6 87 20.4ms 17.0ms 1.20x ~0.81x
synth d=8 182 48.0ms 38.7ms 1.24x ~0.67x
synth d=12 238 75.7ms 59.4ms 1.27x ~0.98x
synth d=16 238 79.6ms 60.6ms 1.31x ~0.97x
breast_cancer 35 6.25ms 3.37ms 1.85x ~0.93x

V5 is also 1.2–1.4x faster than QuadratureSHAP at every depth.

Accuracy: max diff vs TreeSHAP ~1e-7. Additivity error ~2e-7.

Changes

  • src/predictor/interpretability/shap.cc — V5 algorithm (~420 lines): ComputeSubtreeFeats, TreeShapV5, V5ShapValues, precomputed GL rules for Q=2..16
  • src/predictor/interpretability/shap.hV5ShapValues declaration
  • src/predictor/cpu_predictor.ccv5shap selector dispatch
  • tests/cpp/predictor/test_shap.ccV5ShapMatchesTreeShapCPU test (10 features, 1e-4 tolerance)

Usage

bst.set_param({"shap_algorithm": "v5shap"})
contribs = bst.predict(dtest, pred_contribs=True)

Test plan

  • Correctness: V5 matches TreeSHAP within 1e-4 on breast_cancer (depth=6, 10 rounds)
  • Additivity: sum(contribs) == prediction verified
  • Benchmark: faster than TreeSHAP at every depth from 3 to 30
  • C++ gtest (needs gtest installed to run V5ShapMatchesTreeShapCPU)

🤖 Generated with Claude Code

… depth

Add a third SHAP algorithm option ("v5shap") alongside treeshap and
quadratureshap. V5 uses leaf-extraction with three optimizations that
eliminate the small-tree regression seen in QuadratureSHAP:

1. Leaf-extraction instead of edge-telescoping: defers SHAP contribution
   extraction to leaves using precomputed subtree feature masks, skipping
   redundant extract_term calls at internal nodes.

2. Dynamic Q = max(ceil(depth/2), 2): adapts quadrature points to tree
   depth. Depth-4 trees use Q=2 (2 FMAs/node), matching TreeSHAP's O(D)
   cost. Deep trees ramp up automatically.

3. float32 + precomputed subtree masks: matches TreeSHAP's data type,
   halves bandwidth. Subtree masks computed once per tree, not per sample.

Benchmark (50 trees, 256 samples, nthread=1, vs Lundberg TreeSHAP):

  depth=3  (15 nodes):  1.37x faster
  depth=4  (29 nodes):  1.32x faster
  depth=8  (182 nodes): 1.24x faster
  depth=12 (238 nodes): 1.27x faster
  depth=16 (238 nodes): 1.31x faster
  breast_cancer (35 nodes): 1.85x faster

Max accuracy diff vs TreeSHAP: ~1e-7 (within 1e-4 test tolerance).
V5 is also 1.2-1.4x faster than QuadratureSHAP across all configs.

Usage: bst.set_param({"shap_algorithm": "v5shap"})

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant