From 57590dc95984946908a4e2fbef6e450f116a08fc Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 3 Feb 2026 03:50:28 -0800 Subject: [PATCH 01/17] Refactor tests --- src/interpretability/shap.cc | 415 +++++++++++++++++++++++++++++++ src/interpretability/shap.h | 29 +++ src/predictor/cpu_predictor.cc | 177 +------------ tests/cpp/predictor/test_shap.cc | 73 +++++- 4 files changed, 518 insertions(+), 176 deletions(-) create mode 100644 src/interpretability/shap.cc create mode 100644 src/interpretability/shap.h diff --git a/src/interpretability/shap.cc b/src/interpretability/shap.cc new file mode 100644 index 000000000000..5b23543703e6 --- /dev/null +++ b/src/interpretability/shap.cc @@ -0,0 +1,415 @@ +/** + * Copyright 2017-2026, XGBoost Contributors + */ +#include "shap.h" + +#include // for fill +#include // for numeric_limits +#include // for remove_const_t +#include // for vector + +#include "../common/categorical.h" // for IsCat +#include "../common/column_matrix.h" // for ColumnMatrix +#include "../common/hist_util.h" // for DispatchBinType, HistogramCuts +#include "../common/math.h" // for CheckNAN +#include "../common/threading_utils.h" // for ParallelFor +#include "../data/gradient_index.h" // for GHistIndexMatrix +#include "../gbm/gbtree_model.h" // for GBTreeModel +#include "../predictor/predict_fn.h" // for GetTreeLimit +#include "../predictor/treeshap.h" // for CalculateContributions +#include "../tree/tree_view.h" // for ScalarTreeView +#include "dmlc/omp.h" // for omp_get_thread_num +#include "xgboost/base.h" // for bst_omp_uint +#include "xgboost/logging.h" // for CHECK +#include "xgboost/multi_target_tree_model.h" // for MTNotImplemented + +namespace xgboost::interpretability { +namespace { +float FillNodeMeanValues(tree::ScalarTreeView const &tree, bst_node_t nidx, + std::vector *mean_values) { + float result; + auto &node_mean_values = *mean_values; + if (tree.IsLeaf(nidx)) { + result = tree.LeafValue(nidx); + } else { + result = FillNodeMeanValues(tree, tree.LeftChild(nidx), mean_values) * + tree.Stat(tree.LeftChild(nidx)).sum_hess; + result += FillNodeMeanValues(tree, tree.RightChild(nidx), mean_values) * + tree.Stat(tree.RightChild(nidx)).sum_hess; + result /= tree.Stat(nidx).sum_hess; + } + node_mean_values[nidx] = result; + return result; +} + +void FillNodeMeanValues(tree::ScalarTreeView const &tree, std::vector *mean_values) { + auto n_nodes = tree.Size(); + if (static_cast(mean_values->size()) == n_nodes) { + return; + } + mean_values->resize(n_nodes); + FillNodeMeanValues(tree, 0, mean_values); +} + +void CalculateApproxContributions(tree::ScalarTreeView const &tree, RegTree::FVec const &feats, + std::vector *mean_values, + std::vector *out_contribs) { + CHECK_EQ(out_contribs->size(), feats.Size() + 1); + CalculateContributionsApprox(tree, feats, mean_values, out_contribs->data()); +} + +class GHistIndexMatrixView { + private: + GHistIndexMatrix const &page_; + common::Span ft_; + + std::vector const &ptrs_; + std::vector const &mins_; + std::vector const &values_; + common::ColumnMatrix const &columns_; + + public: + bst_idx_t const base_rowid; + + public: + GHistIndexMatrixView(GHistIndexMatrix const &page, common::Span ft) + : page_{page}, + ft_{ft}, + ptrs_{page.cut.Ptrs()}, + mins_{page.cut.MinValues()}, + values_{page.cut.Values()}, + columns_{page.Transpose()}, + base_rowid{page.base_rowid} {} + + [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float *out) const { + auto gridx = ridx + this->base_rowid; + auto n_features = page_.Features(); + + bst_idx_t n_non_missings = 0; + if (page_.IsDense()) { + common::DispatchBinType(page_.index.GetBinTypeSize(), [&](auto t) { + using T = decltype(t); + auto ptr = this->page_.index.template data(); + auto rbeg = this->page_.row_ptr[ridx]; + for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) { + bst_bin_t bin_idx; + float fvalue; + if (common::IsCat(ft_, fidx)) { + bin_idx = page_.GetGindex(gridx, fidx); + fvalue = this->values_[bin_idx]; + } else { + bin_idx = ptr[rbeg + fidx] + page_.index.Offset()[fidx]; + fvalue = + common::HistogramCuts::NumericBinValue(this->ptrs_, values_, mins_, fidx, bin_idx); + } + out[fidx] = fvalue; + } + }); + n_non_missings += n_features; + } else { + for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) { + float fvalue = std::numeric_limits::quiet_NaN(); + bool is_cat = common::IsCat(ft_, fidx); + if (columns_.GetColumnType(fidx) == common::kSparseColumn) { + auto bin_idx = page_.GetGindex(gridx, fidx); + if (bin_idx != -1) { + if (is_cat) { + fvalue = values_[bin_idx]; + } else { + fvalue = common::HistogramCuts::NumericBinValue(this->ptrs_, values_, mins_, fidx, + bin_idx); + } + } + } else { + fvalue = page_.GetFvalue(ptrs_, values_, mins_, gridx, fidx, is_cat); + } + if (!common::CheckNAN(fvalue)) { + out[fidx] = fvalue; + n_non_missings++; + } + } + } + return n_non_missings; + } + + [[nodiscard]] bst_idx_t Size() const { return page_.Size(); } +}; +} // namespace + +void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, + gbm::GBTreeModel const &model, bst_tree_t tree_end, + std::vector const *tree_weights, int condition, unsigned condition_feature) { + CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict contribution" << MTNotImplemented(); + CHECK(!p_fmat->Info().IsColumnSplit()) + << "Predict contribution support for column-wise data split is not yet implemented."; + MetaInfo const &info = p_fmat->Info(); + // number of valid trees + tree_end = predictor::GetTreeLimit(model.trees, tree_end); + size_t const ncolumns = model.learner_model_param->num_feature + 1; + // allocate space for (number of features + bias) times the number of rows + std::vector &contribs = out_contribs->HostVector(); + contribs.resize(info.num_row_ * ncolumns * model.learner_model_param->num_output_group); + // make sure contributions is zeroed, we could be reusing a previously allocated one + std::fill(contribs.begin(), contribs.end(), 0); + // initialize tree node mean values + std::vector> mean_values(tree_end); + for (bst_omp_uint i = 0; i < tree_end; ++i) { + FillNodeMeanValues(model.trees[i]->HostScView(), &(mean_values[i])); + } + + auto const n_groups = model.learner_model_param->num_output_group; + CHECK_NE(n_groups, 0); + auto const base_score = model.learner_model_param->BaseScore(DeviceOrd::CPU()); + auto const h_tree_groups = model.TreeGroups(DeviceOrd::CPU()); + auto const n_threads = ctx->Threads(); + std::vector feats_tloc(n_threads); + std::vector> contribs_tloc(n_threads, std::vector(ncolumns)); + + auto device = ctx->Device().IsSycl() ? DeviceOrd::CPU() : ctx->Device(); + auto base_margin = info.base_margin_.View(device); + + if (p_fmat->PageExists()) { + for (auto const &page : p_fmat->GetBatches()) { + auto view = page.GetView(); + common::ParallelFor(view.Size(), n_threads, [&](auto i) { + auto tid = omp_get_thread_num(); + auto &feats = feats_tloc[tid]; + if (feats.Size() == 0) { + feats.Init(model.learner_model_param->num_feature); + } + auto &this_tree_contribs = contribs_tloc[tid]; + auto row_idx = page.base_rowid + i; + feats.Fill(view[i]); + for (bst_target_t gid = 0; gid < n_groups; ++gid) { + float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; + for (bst_tree_t j = 0; j < tree_end; ++j) { + if (h_tree_groups[j] != gid) { + continue; + } + std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); + auto const sc_tree = model.trees[j]->HostScView(); + CalculateContributions(sc_tree, feats, &mean_values[j], this_tree_contribs.data(), + condition, condition_feature); + for (size_t ci = 0; ci < ncolumns; ++ci) { + p_contribs[ci] += + this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); + } + } + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Shape(1), n_groups); + p_contribs[ncolumns - 1] += base_margin(row_idx, gid); + } else { + p_contribs[ncolumns - 1] += base_score(gid); + } + } + feats.Drop(); + }); + } + } else { + auto ft = p_fmat->Info().feature_types.ConstHostVector(); + for (auto const &page : p_fmat->GetBatches(ctx, {})) { + GHistIndexMatrixView view{page, ft}; + common::ParallelFor(view.Size(), n_threads, [&](auto i) { + auto tid = omp_get_thread_num(); + auto &feats = feats_tloc[tid]; + if (feats.Size() == 0) { + feats.Init(model.learner_model_param->num_feature); + } + auto &this_tree_contribs = contribs_tloc[tid]; + auto row_idx = view.base_rowid + i; + auto n_valid = view.DoFill(i, feats.Data().data()); + feats.HasMissing(n_valid != feats.Size()); + for (bst_target_t gid = 0; gid < n_groups; ++gid) { + float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; + for (bst_tree_t j = 0; j < tree_end; ++j) { + if (h_tree_groups[j] != gid) { + continue; + } + std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); + auto const sc_tree = model.trees[j]->HostScView(); + CalculateContributions(sc_tree, feats, &mean_values[j], this_tree_contribs.data(), + condition, condition_feature); + for (size_t ci = 0; ci < ncolumns; ++ci) { + p_contribs[ci] += + this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); + } + } + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Shape(1), n_groups); + p_contribs[ncolumns - 1] += base_margin(row_idx, gid); + } else { + p_contribs[ncolumns - 1] += base_score(gid); + } + } + feats.Drop(); + }); + } + } +} + +void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, + bst_tree_t tree_end, std::vector const *tree_weights) { + CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict contribution" << MTNotImplemented(); + CHECK(!p_fmat->Info().IsColumnSplit()) + << "Predict contribution support for column-wise data split is not yet implemented."; + MetaInfo const &info = p_fmat->Info(); + tree_end = predictor::GetTreeLimit(model.trees, tree_end); + size_t const ncolumns = model.learner_model_param->num_feature + 1; + std::vector &contribs = out_contribs->HostVector(); + contribs.resize(info.num_row_ * ncolumns * model.learner_model_param->num_output_group); + std::fill(contribs.begin(), contribs.end(), 0); + std::vector> mean_values(tree_end); + for (bst_omp_uint i = 0; i < tree_end; ++i) { + FillNodeMeanValues(model.trees[i]->HostScView(), &(mean_values[i])); + } + + auto const n_groups = model.learner_model_param->num_output_group; + CHECK_NE(n_groups, 0); + auto const base_score = model.learner_model_param->BaseScore(DeviceOrd::CPU()); + auto const h_tree_groups = model.TreeGroups(DeviceOrd::CPU()); + auto const n_threads = ctx->Threads(); + std::vector feats_tloc(n_threads); + std::vector> contribs_tloc(n_threads, std::vector(ncolumns)); + + auto device = ctx->Device().IsSycl() ? DeviceOrd::CPU() : ctx->Device(); + auto base_margin = info.base_margin_.View(device); + + if (p_fmat->PageExists()) { + for (auto const &page : p_fmat->GetBatches()) { + auto view = page.GetView(); + common::ParallelFor(view.Size(), n_threads, [&](auto i) { + auto tid = omp_get_thread_num(); + auto &feats = feats_tloc[tid]; + if (feats.Size() == 0) { + feats.Init(model.learner_model_param->num_feature); + } + auto &this_tree_contribs = contribs_tloc[tid]; + auto row_idx = page.base_rowid + i; + feats.Fill(view[i]); + for (bst_target_t gid = 0; gid < n_groups; ++gid) { + float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; + for (bst_tree_t j = 0; j < tree_end; ++j) { + if (h_tree_groups[j] != gid) { + continue; + } + std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); + auto const sc_tree = model.trees[j]->HostScView(); + CalculateApproxContributions(sc_tree, feats, &mean_values[j], &this_tree_contribs); + for (size_t ci = 0; ci < ncolumns; ++ci) { + p_contribs[ci] += + this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); + } + } + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Shape(1), n_groups); + p_contribs[ncolumns - 1] += base_margin(row_idx, gid); + } else { + p_contribs[ncolumns - 1] += base_score(gid); + } + } + feats.Drop(); + }); + } + } else { + auto ft = p_fmat->Info().feature_types.ConstHostVector(); + for (auto const &page : p_fmat->GetBatches(ctx, {})) { + GHistIndexMatrixView view{page, ft}; + common::ParallelFor(view.Size(), n_threads, [&](auto i) { + auto tid = omp_get_thread_num(); + auto &feats = feats_tloc[tid]; + if (feats.Size() == 0) { + feats.Init(model.learner_model_param->num_feature); + } + auto &this_tree_contribs = contribs_tloc[tid]; + auto row_idx = view.base_rowid + i; + auto n_valid = view.DoFill(i, feats.Data().data()); + feats.HasMissing(n_valid != feats.Size()); + for (bst_target_t gid = 0; gid < n_groups; ++gid) { + float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; + for (bst_tree_t j = 0; j < tree_end; ++j) { + if (h_tree_groups[j] != gid) { + continue; + } + std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); + auto const sc_tree = model.trees[j]->HostScView(); + CalculateApproxContributions(sc_tree, feats, &mean_values[j], &this_tree_contribs); + for (size_t ci = 0; ci < ncolumns; ++ci) { + p_contribs[ci] += + this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); + } + } + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Shape(1), n_groups); + p_contribs[ncolumns - 1] += base_margin(row_idx, gid); + } else { + p_contribs[ncolumns - 1] += base_score(gid); + } + } + feats.Drop(); + }); + } + } +} + +void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, + bst_tree_t tree_end, std::vector const *tree_weights, + bool approximate) { + CHECK(!model.learner_model_param->IsVectorLeaf()) + << "Predict interaction contribution" << MTNotImplemented(); + CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict interaction contribution support for " + "column-wise data split is not yet implemented."; + MetaInfo const &info = p_fmat->Info(); + auto const ngroup = model.learner_model_param->num_output_group; + auto const ncolumns = model.learner_model_param->num_feature; + const unsigned row_chunk = ngroup * (ncolumns + 1) * (ncolumns + 1); + const unsigned mrow_chunk = (ncolumns + 1) * (ncolumns + 1); + const unsigned crow_chunk = ngroup * (ncolumns + 1); + + // allocate space for (number of features^2) times the number of rows and tmp off/on contribs + std::vector &contribs = out_contribs->HostVector(); + contribs.resize(info.num_row_ * ngroup * (ncolumns + 1) * (ncolumns + 1)); + HostDeviceVector contribs_off_hdv(info.num_row_ * ngroup * (ncolumns + 1)); + auto &contribs_off = contribs_off_hdv.HostVector(); + HostDeviceVector contribs_on_hdv(info.num_row_ * ngroup * (ncolumns + 1)); + auto &contribs_on = contribs_on_hdv.HostVector(); + HostDeviceVector contribs_diag_hdv(info.num_row_ * ngroup * (ncolumns + 1)); + auto &contribs_diag = contribs_diag_hdv.HostVector(); + + // Compute the difference in effects when conditioning on each of the features on and off + // see: Axiomatic characterizations of probabilistic and + // cardinal-probabilistic interaction indices + if (approximate) { + ApproxFeatureImportance(ctx, p_fmat, &contribs_diag_hdv, model, tree_end, tree_weights); + } else { + ShapValues(ctx, p_fmat, &contribs_diag_hdv, model, tree_end, tree_weights, 0, 0); + } + for (size_t i = 0; i < ncolumns + 1; ++i) { + if (approximate) { + ApproxFeatureImportance(ctx, p_fmat, &contribs_off_hdv, model, tree_end, tree_weights); + ApproxFeatureImportance(ctx, p_fmat, &contribs_on_hdv, model, tree_end, tree_weights); + } else { + ShapValues(ctx, p_fmat, &contribs_off_hdv, model, tree_end, tree_weights, -1, i); + ShapValues(ctx, p_fmat, &contribs_on_hdv, model, tree_end, tree_weights, 1, i); + } + + for (size_t j = 0; j < info.num_row_; ++j) { + for (std::remove_const_t l = 0; l < ngroup; ++l) { + const unsigned o_offset = j * row_chunk + l * mrow_chunk + i * (ncolumns + 1); + const unsigned c_offset = j * crow_chunk + l * (ncolumns + 1); + contribs[o_offset + i] = 0; + for (size_t k = 0; k < ncolumns + 1; ++k) { + // fill in the diagonal with additive effects, and off-diagonal with the interactions + if (k == i) { + contribs[o_offset + i] += contribs_diag[c_offset + k]; + } else { + contribs[o_offset + k] = (contribs_on[c_offset + k] - contribs_off[c_offset + k]) / 2.0; + contribs[o_offset + i] -= contribs[o_offset + k]; + } + } + } + } + } +} +} // namespace xgboost::interpretability diff --git a/src/interpretability/shap.h b/src/interpretability/shap.h new file mode 100644 index 000000000000..aa842dd14ca6 --- /dev/null +++ b/src/interpretability/shap.h @@ -0,0 +1,29 @@ +/** + * Copyright 2017-2026, XGBoost Contributors + */ +#pragma once + +#include // for vector + +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for DMatrix, MetaInfo +#include "xgboost/host_device_vector.h" // for HostDeviceVector + +namespace xgboost::gbm { +struct GBTreeModel; +} // namespace xgboost::gbm + +namespace xgboost::interpretability { +void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, + gbm::GBTreeModel const &model, bst_tree_t tree_end, + std::vector const *tree_weights, int condition, unsigned condition_feature); + +void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, + bst_tree_t tree_end, std::vector const *tree_weights); + +void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, + bst_tree_t tree_end, std::vector const *tree_weights, + bool approximate); +} // namespace xgboost::interpretability diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 7dfa03bbbec1..45387542a82a 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -20,11 +20,11 @@ #include "../data/gradient_index.h" // for GHistIndexMatrix #include "../data/proxy_dmatrix.h" // for DMatrixProxy #include "../gbm/gbtree_model.h" // for GBTreeModel, GBTreeModelParam +#include "../interpretability/shap.h" // for PredictContribution #include "array_tree_layout.h" // for ProcessArrayTree #include "dmlc/registry.h" // for DMLC_REGISTRY_FILE_TAG #include "gbtree_view.h" // for GBTreeModelView #include "predict_fn.h" // for GetNextNode, GetNextNodeMulti -#include "treeshap.h" // for CalculateContributions #include "utils.h" // for CheckProxyDMatrix #include "xgboost/base.h" // for bst_float, bst_node_t, bst_omp_uint, bst_fe... #include "xgboost/context.h" // for Context @@ -572,31 +572,6 @@ void PredictBatchByBlockKernel(DataView const &batch, HostModel const &model, }); } -float FillNodeMeanValues(tree::ScalarTreeView const &tree, bst_node_t nidx, - std::vector *mean_values) { - float result; - auto &node_mean_values = *mean_values; - if (tree.IsLeaf(nidx)) { - result = tree.LeafValue(nidx); - } else { - result = FillNodeMeanValues(tree, tree.LeftChild(nidx), mean_values) * - tree.Stat(tree.LeftChild(nidx)).sum_hess; - result += FillNodeMeanValues(tree, tree.RightChild(nidx), mean_values) * - tree.Stat(tree.RightChild(nidx)).sum_hess; - result /= tree.Stat(nidx).sum_hess; - } - node_mean_values[nidx] = result; - return result; -} - -void FillNodeMeanValues(tree::ScalarTreeView const &tree, std::vector *mean_values) { - auto n_nodes = tree.Size(); - if (static_cast(mean_values->size()) == n_nodes) { - return; - } - mean_values->resize(n_nodes); - FillNodeMeanValues(tree, 0, mean_values); -} } // anonymous namespace /** @@ -916,65 +891,6 @@ class CPUPredictor : public Predictor { }); } - template - void PredictContributionKernel(DataView batch, const MetaInfo &info, HostModel const &h_model, - linalg::VectorView base_score, - std::vector const *tree_weights, - std::vector> *mean_values, - ThreadTmp<1> *feat_vecs, std::vector *contribs, - bool approximate, int condition, - unsigned condition_feature) const { - const int num_feature = h_model.n_features; - const auto n_groups = h_model.n_groups; - CHECK_NE(n_groups, 0); - size_t const ncolumns = num_feature + 1; - CHECK_NE(ncolumns, 0); - auto device = ctx_->Device().IsSycl() ? DeviceOrd::CPU() : ctx_->Device(); - auto base_margin = info.base_margin_.View(device); - - // parallel over local batch - common::ParallelFor(batch.Size(), this->ctx_->Threads(), [&](auto i) { - auto row_idx = batch.base_rowid + i; - RegTree::FVec &feats = feat_vecs->ThreadBuffer(1).front(); - if (feats.Size() == 0) { - feats.Init(num_feature); - } - std::vector this_tree_contribs(ncolumns); - // loop over all classes - for (bst_target_t gid = 0; gid < n_groups; ++gid) { - float *p_contribs = &(*contribs)[(row_idx * n_groups + gid) * ncolumns]; - batch.Fill(i, &feats); - // calculate contributions - for (bst_tree_t j = 0; j < h_model.tree_end; ++j) { - auto *tree_mean_values = &mean_values->at(j); - std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); - if (h_model.tree_groups[j] != gid) { - continue; - } - auto sc_tree = std::get(h_model.Trees()[j]); - if (!approximate) { - CalculateContributions(sc_tree, feats, tree_mean_values, &this_tree_contribs[0], - condition, condition_feature); - } else { - CalculateContributionsApprox(sc_tree, feats, tree_mean_values, &this_tree_contribs[0]); - } - for (size_t ci = 0; ci < ncolumns; ++ci) { - p_contribs[ci] += - this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); - } - } - feats.Drop(); - // add base margin to BIAS - if (base_margin.Size() != 0) { - CHECK_EQ(base_margin.Shape(1), n_groups); - p_contribs[ncolumns - 1] += base_margin(row_idx, gid); - } else { - p_contribs[ncolumns - 1] += base_score(gid); - } - } - }); - } - public: explicit CPUPredictor(Context const *ctx) : Predictor::Predictor{ctx} {} @@ -1086,94 +1002,21 @@ class CPUPredictor : public Predictor { const gbm::GBTreeModel &model, bst_tree_t ntree_limit, std::vector const *tree_weights, bool approximate, int condition, unsigned condition_feature) const override { - CHECK(!model.learner_model_param->IsVectorLeaf()) - << "Predict contribution" << MTNotImplemented(); - CHECK(!p_fmat->Info().IsColumnSplit()) - << "Predict contribution support for column-wise data split is not yet implemented."; - auto const n_threads = this->ctx_->Threads(); - ThreadTmp<1> feat_vecs{n_threads}; - const MetaInfo &info = p_fmat->Info(); - // number of valid trees - ntree_limit = GetTreeLimit(model.trees, ntree_limit); - size_t const ncolumns = model.learner_model_param->num_feature + 1; - // allocate space for (number of features + bias) times the number of rows - std::vector &contribs = out_contribs->HostVector(); - contribs.resize(info.num_row_ * ncolumns * model.learner_model_param->num_output_group); - // make sure contributions is zeroed, we could be reusing a previously - // allocated one - std::fill(contribs.begin(), contribs.end(), 0); - // initialize tree node mean values - std::vector> mean_values(ntree_limit); - common::ParallelFor(ntree_limit, n_threads, [&](bst_omp_uint i) { - FillNodeMeanValues(model.trees[i]->HostScView(), &(mean_values[i])); - }); - - auto const h_model = - HostModel{DeviceOrd::CPU(), model, true, 0, ntree_limit, &this->mu_, CopyViews{}}; - LaunchPredict(this->ctx_, p_fmat, model, [&](auto &&policy) { - policy.ForEachBatch([&](auto &&batch) { - PredictContributionKernel(batch, info, h_model, - model.learner_model_param->BaseScore(DeviceOrd::CPU()), - tree_weights, &mean_values, &feat_vecs, &contribs, approximate, - condition, condition_feature); - }); - }); + if (approximate) { + interpretability::ApproxFeatureImportance(this->ctx_, p_fmat, out_contribs, model, + ntree_limit, tree_weights); + } else { + interpretability::ShapValues(this->ctx_, p_fmat, out_contribs, model, ntree_limit, + tree_weights, condition, condition_feature); + } } void PredictInteractionContributions(DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, bst_tree_t ntree_limit, std::vector const *tree_weights, bool approximate) const override { - CHECK(!model.learner_model_param->IsVectorLeaf()) - << "Predict interaction contribution" << MTNotImplemented(); - CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict interaction contribution support for " - "column-wise data split is not yet implemented."; - const MetaInfo &info = p_fmat->Info(); - auto const ngroup = model.learner_model_param->num_output_group; - auto const ncolumns = model.learner_model_param->num_feature; - const unsigned row_chunk = ngroup * (ncolumns + 1) * (ncolumns + 1); - const unsigned mrow_chunk = (ncolumns + 1) * (ncolumns + 1); - const unsigned crow_chunk = ngroup * (ncolumns + 1); - - // allocate space for (number of features^2) times the number of rows and tmp off/on contribs - std::vector &contribs = out_contribs->HostVector(); - contribs.resize(info.num_row_ * ngroup * (ncolumns + 1) * (ncolumns + 1)); - HostDeviceVector contribs_off_hdv(info.num_row_ * ngroup * (ncolumns + 1)); - auto &contribs_off = contribs_off_hdv.HostVector(); - HostDeviceVector contribs_on_hdv(info.num_row_ * ngroup * (ncolumns + 1)); - auto &contribs_on = contribs_on_hdv.HostVector(); - HostDeviceVector contribs_diag_hdv(info.num_row_ * ngroup * (ncolumns + 1)); - auto &contribs_diag = contribs_diag_hdv.HostVector(); - - // Compute the difference in effects when conditioning on each of the features on and off - // see: Axiomatic characterizations of probabilistic and - // cardinal-probabilistic interaction indices - PredictContribution(p_fmat, &contribs_diag_hdv, model, ntree_limit, tree_weights, approximate, - 0, 0); - for (size_t i = 0; i < ncolumns + 1; ++i) { - PredictContribution(p_fmat, &contribs_off_hdv, model, ntree_limit, tree_weights, approximate, - -1, i); - PredictContribution(p_fmat, &contribs_on_hdv, model, ntree_limit, tree_weights, approximate, - 1, i); - - for (size_t j = 0; j < info.num_row_; ++j) { - for (std::remove_const_t l = 0; l < ngroup; ++l) { - const unsigned o_offset = j * row_chunk + l * mrow_chunk + i * (ncolumns + 1); - const unsigned c_offset = j * crow_chunk + l * (ncolumns + 1); - contribs[o_offset + i] = 0; - for (size_t k = 0; k < ncolumns + 1; ++k) { - // fill in the diagonal with additive effects, and off-diagonal with the interactions - if (k == i) { - contribs[o_offset + i] += contribs_diag[c_offset + k]; - } else { - contribs[o_offset + k] = - (contribs_on[c_offset + k] - contribs_off[c_offset + k]) / 2.0; - contribs[o_offset + i] -= contribs[o_offset + k]; - } - } - } - } - } + interpretability::ShapInteractionValues(this->ctx_, p_fmat, out_contribs, model, ntree_limit, + tree_weights, approximate); } private: diff --git a/tests/cpp/predictor/test_shap.cc b/tests/cpp/predictor/test_shap.cc index 9368ec1ee0d8..a13f88b3c6a1 100644 --- a/tests/cpp/predictor/test_shap.cc +++ b/tests/cpp/predictor/test_shap.cc @@ -7,11 +7,18 @@ #include #include // for DMatrix #include // for HostDeviceVector +#include // for Json #include // for Learner +#include // for Vector +#include #include // for unique_ptr +#include #include // for to_string +#include "../../../src/common/param_array.h" +#include "../../../src/gbm/gbtree_model.h" +#include "../../../src/interpretability/shap.h" #include "../helpers.h" namespace xgboost { @@ -42,6 +49,49 @@ Args BaseParams(Context const* ctx, std::string objective, std::string max_depth {"colsample_bytree", "1"}, {"device", ctx->IsSycl() ? "cpu" : ctx->DeviceName()}}; } + +gbm::GBTreeModel LoadGBTreeModel(Learner* learner, Context const* ctx, + LearnerModelParam* out_param) { + Json model{Object{}}; + learner->SaveModel(&model); + + CHECK(IsA(model)) << model; + auto const& model_obj = get(model); + auto learner_it = model_obj.find("learner"); + CHECK(learner_it != model_obj.cend()) << model; + CHECK(IsA(learner_it->second)) << model; + auto const& learner_obj = get(learner_it->second); + + auto const& lmp = get(learner_obj.at("learner_model_param")); + auto const& num_feature = get(lmp.at("num_feature")); + auto const& num_class = get(lmp.at("num_class")); + auto const& num_target = get(lmp.at("num_target")); + auto const& base_score_str = get(lmp.at("base_score")); + + common::ParamArray base_score_arr{"base_score"}; + std::stringstream ss; + ss << base_score_str; + ss >> base_score_arr; + + std::size_t shape[1]{base_score_arr.size()}; + linalg::Vector base_score_vec{shape, ctx->Device()}; + auto& h_base = base_score_vec.Data()->HostVector(); + h_base.assign(base_score_arr.cbegin(), base_score_arr.cend()); + + auto n_features = static_cast(std::stol(num_feature)); + auto n_classes = static_cast(std::stol(num_class)); + auto n_targets = static_cast(std::stol(num_target)); + auto n_groups = static_cast(std::max(n_classes, n_targets)); + *out_param = LearnerModelParam{n_features, std::move(base_score_vec), n_groups, n_targets, + MultiStrategy::kOneOutputPerTree}; + + gbm::GBTreeModel gbtree{out_param, ctx}; + auto gbm_it = learner_obj.find("gradient_booster"); + CHECK(gbm_it != learner_obj.cend()) << model; + CHECK(IsA(gbm_it->second)) << model; + gbtree.LoadModel(gbm_it->second); + return gbtree; +} } // namespace std::vector BuildShapTestCases(Context const* ctx) { @@ -99,10 +149,8 @@ std::vector BuildShapTestCases(Context const* ctx) { { // multi-class dense training DMatrix, medium depth bst_target_t n_classes{3}; - auto dmat = RandomDataGenerator(256, 8, 0.0) - .Classes(n_classes) - .Device(device) - .GenerateDMatrix(true); + auto dmat = + RandomDataGenerator(256, 8, 0.0).Classes(n_classes).Device(device).GenerateDMatrix(true); SetLabels(dmat.get(), n_classes); auto args = BaseParams(ctx, "multi:softprob", "4"); args.emplace_back("num_class", std::to_string(n_classes)); @@ -135,15 +183,18 @@ void CheckShapOutput(DMatrix* dmat, Args const& model_args) { learner->Predict(p_dmat, true, &margin_predt, 0, 0, false, false, false, false, false); size_t const n_outputs = margin_predt.HostVector().size() / kRows; + LearnerModelParam mparam; + auto gbtree = LoadGBTreeModel(learner.get(), dmat->Ctx(), &mparam); + HostDeviceVector shap_values; - learner->Predict(p_dmat, false, &shap_values, 0, 0, false, false, true, false, false); + interpretability::ShapValues(dmat->Ctx(), p_dmat.get(), &shap_values, gbtree, 0, nullptr, 0, 0); ASSERT_EQ(shap_values.HostVector().size(), kRows * (kCols + 1) * n_outputs); CheckShapAdditivity(kRows, kCols, shap_values, margin_predt); HostDeviceVector shap_interactions; - learner->Predict(p_dmat, false, &shap_interactions, 0, 0, false, false, false, false, true); - ASSERT_EQ(shap_interactions.HostVector().size(), - kRows * (kCols + 1) * (kCols + 1) * n_outputs); + interpretability::ShapInteractionValues(dmat->Ctx(), p_dmat.get(), &shap_interactions, gbtree, 0, + nullptr, false); + ASSERT_EQ(shap_interactions.HostVector().size(), kRows * (kCols + 1) * (kCols + 1) * n_outputs); CheckShapAdditivity(kRows, kCols, shap_interactions, margin_predt); } @@ -207,8 +258,12 @@ TEST(Predictor, ApproxContribsBasic) { HostDeviceVector margin_predt; learner->Predict(dmat, true, &margin_predt, 0, 0, false, false, false, false, false); + LearnerModelParam mparam; + auto gbtree = LoadGBTreeModel(learner.get(), dmat->Ctx(), &mparam); + HostDeviceVector approx_contribs; - learner->Predict(dmat, false, &approx_contribs, 0, 0, false, false, true, true, false); + interpretability::ApproxFeatureImportance(dmat->Ctx(), dmat.get(), &approx_contribs, gbtree, 0, + nullptr); auto const& h_margin = margin_predt.ConstHostVector(); auto const& h_contribs = approx_contribs.ConstHostVector(); From 5ef9c2e0608e13e743d383b91c90711a4f4b7219 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Wed, 4 Feb 2026 01:38:18 -0800 Subject: [PATCH 02/17] Refactor SHAP dispatch and GPU implementation --- src/interpretability/shap.cc | 22 +- src/interpretability/shap.cu | 537 +++++++++++++++++++++++++++++++ src/interpretability/shap.h | 83 ++++- src/predictor/gpu_predictor.cu | 156 ++------- tests/cpp/predictor/test_shap.cc | 44 ++- 5 files changed, 676 insertions(+), 166 deletions(-) create mode 100644 src/interpretability/shap.cu diff --git a/src/interpretability/shap.cc b/src/interpretability/shap.cc index 5b23543703e6..f2dee5eb8a48 100644 --- a/src/interpretability/shap.cc +++ b/src/interpretability/shap.cc @@ -136,9 +136,10 @@ class GHistIndexMatrixView { }; } // namespace -void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, - gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights, int condition, unsigned condition_feature) { +void ShapValuesCPU(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, + gbm::GBTreeModel const &model, bst_tree_t tree_end, + std::vector const *tree_weights, int condition, + unsigned condition_feature) { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict contribution" << MTNotImplemented(); CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict contribution support for column-wise data split is not yet implemented."; @@ -247,9 +248,10 @@ void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *ou } } -void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, - HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, std::vector const *tree_weights) { +void ApproxFeatureImportanceCPU(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, + gbm::GBTreeModel const &model, bst_tree_t tree_end, + std::vector const *tree_weights) { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict contribution" << MTNotImplemented(); CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict contribution support for column-wise data split is not yet implemented."; @@ -352,10 +354,10 @@ void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, } } -void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, - HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, std::vector const *tree_weights, - bool approximate) { +void ShapInteractionValuesCPU(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, + bst_tree_t tree_end, std::vector const *tree_weights, + bool approximate) { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict interaction contribution" << MTNotImplemented(); CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict interaction contribution support for " diff --git a/src/interpretability/shap.cu b/src/interpretability/shap.cu new file mode 100644 index 000000000000..1e295853224e --- /dev/null +++ b/src/interpretability/shap.cu @@ -0,0 +1,537 @@ +/** + * Copyright 2017-2026, XGBoost Contributors + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include // for proclaim_return_type +#include // for swap +#include +#include +#include +#include +#include +#include +#include + +#include "../common/categorical.h" +#include "../common/common.h" +#include "../common/cuda_context.cuh" // for CUDAContext +#include "../common/cuda_rt_utils.h" // for SetDevice +#include "../common/device_helpers.cuh" +#include "../common/error_msg.h" +#include "../common/nvtx_utils.h" +#include "../data/batch_utils.h" // for StaticBatch +#include "../data/cat_container.cuh" // for EncPolicy, MakeCatAccessor +#include "../data/cat_container.h" // for NoOpAccessor +#include "../data/ellpack_page.cuh" +#include "../gbm/gbtree_model.h" +#include "../predictor/gbtree_view.h" +#include "../predictor/predict_fn.h" // for GetTreeLimit +#include "../tree/tree_view.h" +#include "shap.h" +#include "xgboost/data.h" +#include "xgboost/host_device_vector.h" +#include "xgboost/linalg.h" // for UnravelIndex +#include "xgboost/logging.h" +#include "xgboost/multi_target_tree_model.h" // for MTNotImplemented + +namespace xgboost::interpretability::cuda_impl { +namespace { +using cuda_impl::StaticBatch; +using predictor::GBTreeModelView; + +struct SparsePageView { + common::Span d_data; + common::Span d_row_ptr; + bst_feature_t num_features; + + SparsePageView() = default; + explicit SparsePageView(Context const* ctx, SparsePage const& page, bst_feature_t n_features) + : d_data{[&] { + page.data.SetDevice(ctx->Device()); + return page.data.ConstDeviceSpan(); + }()}, + d_row_ptr{[&] { + page.offset.SetDevice(ctx->Device()); + return page.offset.ConstDeviceSpan(); + }()}, + num_features{n_features} {} + + [[nodiscard]] __device__ float GetElement(size_t ridx, size_t fidx) const { + auto begin_ptr = d_data.begin() + d_row_ptr[ridx]; + auto end_ptr = d_data.begin() + d_row_ptr[ridx + 1]; + if (end_ptr - begin_ptr == this->NumCols()) { + return d_data.data()[d_row_ptr[ridx] + fidx].fvalue; + } + common::Span::iterator previous_middle; + while (end_ptr != begin_ptr) { + auto middle = begin_ptr + (end_ptr - begin_ptr) / 2; + if (middle == previous_middle) { + break; + } else { + previous_middle = middle; + } + + if (middle->index == fidx) { + return middle->fvalue; + } else if (middle->index < fidx) { + begin_ptr = middle; + } else { + end_ptr = middle; + } + } + return std::numeric_limits::quiet_NaN(); + } + [[nodiscard]] XGBOOST_DEVICE size_t NumRows() const { return d_row_ptr.size() - 1; } + [[nodiscard]] XGBOOST_DEVICE size_t NumCols() const { return num_features; } +}; + +template +struct ShapSparsePageLoader { + public: + using SupportShmemLoad = std::false_type; + + SparsePageView data; + EncAccessor acc; + + template + [[nodiscard]] __device__ float GetElement(bst_idx_t ridx, Fidx fidx) const { + auto fvalue = data.GetElement(ridx, fidx); + return acc(fvalue, fidx); + } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return data.NumRows(); } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return data.NumCols(); } +}; + +template +struct EllpackLoader { + public: + using SupportShmemLoad = std::false_type; + + private: + EncAccessor acc_; + Accessor data_; + + public: + bool use_shared{false}; + bst_feature_t num_features; + bst_idx_t num_rows; + float missing; + + XGBOOST_DEVICE EllpackLoader(Accessor m, bool /*use_shared*/, bst_feature_t n_features, + bst_idx_t num_rows, float missing, EncAccessor&& acc) + : acc_{std::forward(acc)}, + data_{m}, + num_features{n_features}, + num_rows{num_rows}, + missing{missing} {} + + template + [[nodiscard]] __device__ float GetElement(bst_idx_t ridx, Fidx fidx) const { + auto value = data_.GetElement(ridx, fidx); + return acc_(value, fidx); + } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return num_rows; } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return num_features; } +}; + +using TreeViewVar = cuda::std::variant; + +struct CopyViews { + Context const* ctx; + explicit CopyViews(Context const* ctx) : ctx{ctx} {} + + void operator()(dh::DeviceUVector* p_dst, std::vector&& src) { + xgboost_NVTX_FN_RANGE(); + p_dst->resize(src.size()); + auto d_dst = dh::ToSpan(*p_dst); + dh::safe_cuda(cudaMemcpyAsync(d_dst.data(), src.data(), d_dst.size_bytes(), cudaMemcpyDefault, + ctx->CUDACtx()->Stream())); + } +}; + +using DeviceModel = GBTreeModelView; +std::mutex s_model_mu; + +struct ShapSplitCondition { + ShapSplitCondition() = default; + XGBOOST_DEVICE + ShapSplitCondition(float feature_lower_bound, float feature_upper_bound, bool is_missing_branch, + common::CatBitField cats) + : feature_lower_bound(feature_lower_bound), + feature_upper_bound(feature_upper_bound), + is_missing_branch(is_missing_branch), + categories{std::move(cats)} { + assert(feature_lower_bound <= feature_upper_bound); + } + + float feature_lower_bound; + float feature_upper_bound; + common::CatBitField categories; + bool is_missing_branch; + + [[nodiscard]] XGBOOST_DEVICE bool EvaluateSplit(float x) const { + if (isnan(x)) { + return is_missing_branch; + } + if (categories.Capacity() != 0) { + auto cat = static_cast(x); + return categories.Check(cat); + } else { + return x >= feature_lower_bound && x < feature_upper_bound; + } + } + + XGBOOST_DEVICE static common::CatBitField Intersect(common::CatBitField l, + common::CatBitField r) { + if (l.Data() == r.Data()) { + return l; + } + if (l.Capacity() > r.Capacity()) { + cuda::std::swap(l, r); + } + for (size_t i = 0; i < r.Bits().size(); ++i) { + l.Bits()[i] &= r.Bits()[i]; + } + return l; + } + + XGBOOST_DEVICE void Merge(ShapSplitCondition other) { + if (categories.Capacity() != 0 || other.categories.Capacity() != 0) { + categories = Intersect(categories, other.categories); + } else { + feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound); + feature_upper_bound = min(feature_upper_bound, other.feature_upper_bound); + } + is_missing_branch = is_missing_branch && other.is_missing_branch; + } +}; + +struct PathInfo { + std::size_t length; + bst_node_t nidx; + bst_tree_t tree_idx; + + [[nodiscard]] XGBOOST_DEVICE bool IsLeaf() const { return nidx != -1; } +}; +static_assert(sizeof(PathInfo) == 16); + +auto MakeTreeSegments(Context const* ctx, bst_tree_t tree_begin, bst_tree_t tree_end, + gbm::GBTreeModel const& model) { + auto tree_segments = HostDeviceVector({}, ctx->Device()); + auto& h_tree_segments = tree_segments.HostVector(); + h_tree_segments.reserve((tree_end - tree_begin) + 1); + std::size_t sum = 0; + h_tree_segments.push_back(sum); + for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { + auto const& p_tree = model.trees.at(tree_idx); + CHECK(!p_tree->IsMultiTarget()) << " SHAP " << MTNotImplemented(); + sum += p_tree->Size(); + h_tree_segments.push_back(sum); + } + return tree_segments; +} + +void ExtractPaths(Context const* ctx, + dh::device_vector>* paths, + gbm::GBTreeModel const& h_model, DeviceModel const& d_model, + dh::device_vector* path_categories) { + curt::SetDevice(ctx->Ordinal()); + + dh::caching_device_vector info(d_model.n_nodes); + auto d_trees = d_model.Trees(); + auto tree_segments = MakeTreeSegments(ctx, d_model.tree_begin, d_model.tree_end, h_model); + CHECK_EQ(tree_segments.ConstHostVector().back(), d_model.n_nodes); + auto d_tree_segments = tree_segments.ConstDeviceSpan(); + + auto path_it = dh::MakeIndexTransformIter( + cuda::proclaim_return_type([=] __device__(size_t idx) -> PathInfo { + bst_tree_t const tree_idx = dh::SegmentId(d_tree_segments, idx); + bst_node_t const nidx = idx - d_tree_segments[tree_idx]; + auto const& tree = cuda::std::get(d_trees[tree_idx]); + if (!tree.IsLeaf(nidx) || tree.IsDeleted(nidx)) { + return PathInfo{0, -1, 0}; + } + std::size_t path_length = 1; + auto iter_nidx = nidx; + while (!tree.IsRoot(iter_nidx)) { + iter_nidx = tree.Parent(iter_nidx); + path_length++; + } + return PathInfo{path_length, nidx, tree_idx}; + })); + auto end = thrust::copy_if( + ctx->CUDACtx()->CTP(), path_it, path_it + d_model.n_nodes, info.begin(), + cuda::proclaim_return_type([=] __device__(PathInfo const& e) { return e.IsLeaf(); })); + + info.resize(end - info.begin()); + using LenT = decltype(std::declval().length); + auto length_iterator = dh::MakeTransformIterator( + info.begin(), cuda::proclaim_return_type( + [=] __device__(PathInfo const& info) { return info.length; })); + dh::caching_device_vector path_segments(info.size() + 1); + thrust::exclusive_scan(ctx->CUDACtx()->CTP(), length_iterator, length_iterator + info.size() + 1, + path_segments.begin()); + + paths->resize(path_segments.back()); + + auto d_paths = dh::ToSpan(*paths); + auto d_info = info.data().get(); + auto d_tree_groups = d_model.tree_groups; + auto d_path_segments = path_segments.data().get(); + + std::size_t max_cat = 0; + if (std::any_of(h_model.trees.cbegin(), h_model.trees.cend(), + [](auto const& p_tree) { return p_tree->HasCategoricalSplit(); })) { + auto max_elem_it = dh::MakeIndexTransformIter([=] __device__(std::size_t i) -> std::size_t { + auto tree_idx = dh::SegmentId(d_tree_segments, i); + auto nidx = i - d_tree_segments[tree_idx]; + return cuda::std::get(d_trees[tree_idx]) + .GetCategoriesMatrix() + .node_ptr[nidx] + .size; + }); + auto max_cat_it = + thrust::max_element(ctx->CUDACtx()->CTP(), max_elem_it, max_elem_it + d_model.n_nodes); + dh::CachingDeviceUVector d_max_cat(1); + auto s_max_cat = dh::ToSpan(d_max_cat); + dh::LaunchN(1, ctx->CUDACtx()->Stream(), + [=] __device__(std::size_t) { s_max_cat[0] = *max_cat_it; }); + dh::safe_cuda( + cudaMemcpy(&max_cat, s_max_cat.data(), s_max_cat.size_bytes(), cudaMemcpyDeviceToHost)); + CHECK_GE(max_cat, 1); + path_categories->resize(max_cat * paths->size()); + } + + common::Span d_path_categories = dh::ToSpan(*path_categories); + + dh::LaunchN(info.size(), ctx->CUDACtx()->Stream(), [=] __device__(size_t idx) { + auto path_info = d_info[idx]; + auto tree = cuda::std::get(d_trees[path_info.tree_idx]); + std::int32_t group = d_tree_groups[path_info.tree_idx]; + auto child_nidx = path_info.nidx; + + float v = tree.LeafValue(child_nidx); + const float inf = std::numeric_limits::infinity(); + size_t output_position = d_path_segments[idx + 1] - 1; + + while (!tree.IsRoot(child_nidx)) { + auto parent_nidx = tree.Parent(child_nidx); + double child_cover = tree.SumHess(child_nidx); + double parent_cover = tree.SumHess(parent_nidx); + double zero_fraction = child_cover / parent_cover; + + bool is_left_path = tree.LeftChild(parent_nidx) == child_nidx; + bool is_missing_path = (!tree.DefaultLeft(parent_nidx) && !is_left_path) || + (tree.DefaultLeft(parent_nidx) && is_left_path); + + float lower_bound = -inf; + float upper_bound = inf; + common::CatBitField bits; + if (common::IsCat(tree.cats.split_type, tree.Parent(child_nidx))) { + auto path_cats = d_path_categories.subspan(max_cat * output_position, max_cat); + auto node_cats = tree.NodeCats(tree.Parent(child_nidx)); + SPAN_CHECK(path_cats.size() >= node_cats.size()); + for (size_t i = 0; i < node_cats.size(); ++i) { + path_cats[i] = is_left_path ? ~node_cats[i] : node_cats[i]; + } + bits = common::CatBitField{path_cats}; + } else { + lower_bound = is_left_path ? -inf : tree.SplitCond(parent_nidx); + upper_bound = is_left_path ? tree.SplitCond(parent_nidx) : inf; + } + d_paths[output_position--] = gpu_treeshap::PathElement{ + idx, tree.SplitIndex(parent_nidx), + group, ShapSplitCondition{lower_bound, upper_bound, is_missing_path, bits}, + zero_fraction, v}; + + child_nidx = parent_nidx; + } + d_paths[output_position] = {idx, -1, group, ShapSplitCondition{-inf, inf, false, {}}, 1.0, v}; + }); +} + +template +class LaunchConfig { + public: + using EncAccessorT = EncAccessorT; + + explicit LaunchConfig(Context const* ctx, bst_feature_t n_features) + : ctx_{ctx}, n_features_{n_features} {} + + template + void ForEachBatch(DMatrix* p_fmat, EncAccessorT&& acc, Fn&& fn) { + if (p_fmat->PageExists()) { + for (auto& page : p_fmat->GetBatches()) { + SparsePageView batch{ctx_, page, n_features_}; + auto loader = ShapSparsePageLoader{batch, acc}; + fn(std::move(loader), page.base_rowid); + } + } else { + p_fmat->Info().feature_types.SetDevice(ctx_->Device()); + auto feature_types = p_fmat->Info().feature_types.ConstDeviceSpan(); + + for (auto const& page : p_fmat->GetBatches(ctx_, StaticBatch(true))) { + page.Impl()->Visit(ctx_, feature_types, [&](auto&& batch) { + using Acc = std::remove_reference_t; + auto loader = EllpackLoader{batch, + /*use_shared=*/false, + this->n_features_, + batch.NumRows(), + std::numeric_limits::quiet_NaN(), + std::forward(acc)}; + fn(std::move(loader), batch.base_rowid); + }); + } + } + } + + private: + Context const* ctx_; + bst_feature_t n_features_; +}; + +template +void LaunchShap(Context const* ctx, enc::DeviceColumnsView const& new_enc, + gbm::GBTreeModel const& model, Kernel&& launch) { + if (model.Cats() && model.Cats()->HasCategorical() && new_enc.HasCategorical()) { + auto [acc, mapping] = cuda_impl::MakeCatAccessor(ctx, new_enc, model.Cats()); + auto cfg = + LaunchConfig{ctx, model.learner_model_param->num_feature}; + launch(std::move(cfg), std::move(acc)); + } else { + auto cfg = + LaunchConfig{ctx, model.learner_model_param->num_feature}; + launch(std::move(cfg), NoOpAccessor{}); + } +} +} // namespace + +void ShapValuesCUDA(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* out_contribs, + gbm::GBTreeModel const& model, bst_tree_t tree_end, + std::vector const* tree_weights, int, unsigned) { + xgboost_NVTX_FN_RANGE(); + StringView not_implemented{ + "contribution is not implemented in the GPU predictor, use CPU instead."}; + if (tree_weights != nullptr) { + LOG(FATAL) << "Dart booster feature " << not_implemented; + } + CHECK(!p_fmat->Info().IsColumnSplit()) + << "Predict contribution support for column-wise data split is not yet implemented."; + dh::safe_cuda(cudaSetDevice(ctx->Ordinal())); + out_contribs->SetDevice(ctx->Device()); + tree_end = predictor::GetTreeLimit(model.trees, tree_end); + + const int ngroup = model.learner_model_param->num_output_group; + CHECK_NE(ngroup, 0); + size_t contributions_columns = model.learner_model_param->num_feature + 1; + auto dim_size = contributions_columns * model.learner_model_param->num_output_group; + out_contribs->Resize(p_fmat->Info().num_row_ * dim_size); + out_contribs->Fill(0.0f); + auto phis = out_contribs->DeviceSpan(); + + dh::device_vector> device_paths; + DeviceModel d_model{ctx->Device(), model, true, 0, tree_end, &s_model_mu, CopyViews{ctx}}; + + auto new_enc = + p_fmat->Cats()->NeedRecode() ? p_fmat->Cats()->DeviceView(ctx) : enc::DeviceColumnsView{}; + + dh::device_vector categories; + ExtractPaths(ctx, &device_paths, model, d_model, &categories); + + LaunchShap(ctx, new_enc, model, [&](auto&& cfg, auto&& acc) { + using Config = common::GetValueT; + using EncAccessor = typename Config::EncAccessorT; + + cfg.ForEachBatch( + p_fmat, std::forward(acc), [&](auto&& loader, bst_idx_t base_rowid) { + auto begin = dh::tbegin(phis) + base_rowid * dim_size; + gpu_treeshap::GPUTreeShap>( + loader, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis)); + }); + }); + + p_fmat->Info().base_margin_.SetDevice(ctx->Device()); + const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); + + auto base_score = model.learner_model_param->BaseScore(ctx); + bst_idx_t n_samples = p_fmat->Info().num_row_; + dh::LaunchN(n_samples * ngroup, ctx->CUDACtx()->Stream(), [=] __device__(std::size_t idx) { + auto [_, gid] = linalg::UnravelIndex(idx, n_samples, ngroup); + phis[(idx + 1) * contributions_columns - 1] += margin.empty() ? base_score(gid) : margin[idx]; + }); +} + +void ShapInteractionValuesCUDA(Context const* ctx, DMatrix* p_fmat, + HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, + bst_tree_t tree_end, std::vector const* tree_weights, + bool approximate) { + xgboost_NVTX_FN_RANGE(); + std::string not_implemented{"contribution is not implemented in GPU predictor, use cpu instead."}; + if (approximate) { + LOG(FATAL) << "Approximated " << not_implemented; + } + if (tree_weights != nullptr) { + LOG(FATAL) << "Dart booster feature " << not_implemented; + } + dh::safe_cuda(cudaSetDevice(ctx->Ordinal())); + out_contribs->SetDevice(ctx->Device()); + tree_end = predictor::GetTreeLimit(model.trees, tree_end); + + const int ngroup = model.learner_model_param->num_output_group; + CHECK_NE(ngroup, 0); + size_t contributions_columns = model.learner_model_param->num_feature + 1; + auto dim_size = + contributions_columns * contributions_columns * model.learner_model_param->num_output_group; + out_contribs->Resize(p_fmat->Info().num_row_ * dim_size); + out_contribs->Fill(0.0f); + auto phis = out_contribs->DeviceSpan(); + + dh::device_vector> device_paths; + DeviceModel d_model{ctx->Device(), model, true, 0, tree_end, &s_model_mu, CopyViews{ctx}}; + + dh::device_vector categories; + ExtractPaths(ctx, &device_paths, model, d_model, &categories); + auto new_enc = + p_fmat->Cats()->NeedRecode() ? p_fmat->Cats()->DeviceView(ctx) : enc::DeviceColumnsView{}; + + LaunchShap(ctx, new_enc, model, [&](auto&& cfg, auto&& acc) { + using Config = common::GetValueT; + using EncAccessor = typename Config::EncAccessorT; + + cfg.ForEachBatch( + p_fmat, std::forward(acc), [&](auto&& loader, bst_idx_t base_rowid) { + auto begin = dh::tbegin(phis) + base_rowid * dim_size; + gpu_treeshap::GPUTreeShapInteractions>( + loader, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis)); + }); + }); + + p_fmat->Info().base_margin_.SetDevice(ctx->Device()); + const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); + + auto base_score = model.learner_model_param->BaseScore(ctx); + size_t n_features = model.learner_model_param->num_feature; + bst_idx_t n_samples = p_fmat->Info().num_row_; + dh::LaunchN(n_samples * ngroup, ctx->CUDACtx()->Stream(), [=] __device__(size_t idx) { + auto [ridx, gidx] = linalg::UnravelIndex(idx, n_samples, ngroup); + phis[gpu_treeshap::IndexPhiInteractions(ridx, ngroup, gidx, n_features, n_features, + n_features)] += + margin.empty() ? base_score(gidx) : margin[idx]; + }); +} + +void ApproxFeatureImportanceCUDA(Context const* ctx, DMatrix*, HostDeviceVector*, + gbm::GBTreeModel const&, bst_tree_t, std::vector const*) { + StringView not_implemented{ + "contribution is not implemented in the GPU predictor, use CPU instead."}; + LOG(FATAL) << "Approximated " << not_implemented; +} +} // namespace xgboost::interpretability::cuda_impl diff --git a/src/interpretability/shap.h b/src/interpretability/shap.h index aa842dd14ca6..43bf98a48c42 100644 --- a/src/interpretability/shap.h +++ b/src/interpretability/shap.h @@ -14,16 +14,75 @@ struct GBTreeModel; } // namespace xgboost::gbm namespace xgboost::interpretability { -void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, - gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights, int condition, unsigned condition_feature); - -void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, - HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, std::vector const *tree_weights); - -void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, - HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, std::vector const *tree_weights, - bool approximate); +void ShapValuesCPU(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, + gbm::GBTreeModel const &model, bst_tree_t tree_end, + std::vector const *tree_weights, int condition, + unsigned condition_feature); + +void ApproxFeatureImportanceCPU(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, + gbm::GBTreeModel const &model, bst_tree_t tree_end, + std::vector const *tree_weights); + +void ShapInteractionValuesCPU(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, + bst_tree_t tree_end, std::vector const *tree_weights, + bool approximate); + +#if defined(XGBOOST_USE_CUDA) +void ShapValuesCUDA(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, + gbm::GBTreeModel const &model, bst_tree_t tree_end, + std::vector const *tree_weights, int condition, + unsigned condition_feature); +void ApproxFeatureImportanceCUDA(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, + gbm::GBTreeModel const &model, bst_tree_t tree_end, + std::vector const *tree_weights); +void ShapInteractionValuesCUDA(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, + bst_tree_t tree_end, std::vector const *tree_weights, + bool approximate); +#endif // defined(XGBOOST_USE_CUDA) + +inline void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, + gbm::GBTreeModel const &model, bst_tree_t tree_end, + std::vector const *tree_weights, int condition, + unsigned condition_feature) { +#if defined(XGBOOST_USE_CUDA) + if (ctx->IsCUDA()) { + ShapValuesCUDA(ctx, p_fmat, out_contribs, model, tree_end, tree_weights, condition, + condition_feature); + return; + } +#endif // defined(XGBOOST_USE_CUDA) + ShapValuesCPU(ctx, p_fmat, out_contribs, model, tree_end, tree_weights, condition, + condition_feature); +} + +inline void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, + gbm::GBTreeModel const &model, bst_tree_t tree_end, + std::vector const *tree_weights) { +#if defined(XGBOOST_USE_CUDA) + if (ctx->IsCUDA()) { + ApproxFeatureImportanceCUDA(ctx, p_fmat, out_contribs, model, tree_end, tree_weights); + return; + } +#endif // defined(XGBOOST_USE_CUDA) + ApproxFeatureImportanceCPU(ctx, p_fmat, out_contribs, model, tree_end, tree_weights); +} + +inline void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, + gbm::GBTreeModel const &model, bst_tree_t tree_end, + std::vector const *tree_weights, bool approximate) { +#if defined(XGBOOST_USE_CUDA) + if (ctx->IsCUDA()) { + ShapInteractionValuesCUDA(ctx, p_fmat, out_contribs, model, tree_end, tree_weights, + approximate); + return; + } +#endif // defined(XGBOOST_USE_CUDA) + ShapInteractionValuesCPU(ctx, p_fmat, out_contribs, model, tree_end, tree_weights, approximate); +} } // namespace xgboost::interpretability diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index e6bbc3d0e650..370807fc8f10 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -26,6 +26,7 @@ #include "../data/proxy_dmatrix.cuh" // for DispatchAny #include "../data/proxy_dmatrix.h" #include "../gbm/gbtree_model.h" +#include "../interpretability/shap.h" #include "../tree/tree_view.h" #include "gbtree_view.h" // for GBTreeModelView #include "predict_fn.h" @@ -724,18 +725,23 @@ class ColumnSplitHelper { SparsePageView data{ctx_, batch, num_features}; auto const grid = static_cast(common::DivRoundUp(num_rows, kBlockThreads)); auto d_tree_groups = d_model.tree_groups; - dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes, ctx_->CUDACtx()->Stream()}( + // clang-format off + dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes, + ctx_->CUDACtx()->Stream()}( + // clang-format on MaskBitVectorKernel, data, d_model.Trees(), decision_bits, missing_bits, d_model.tree_begin, d_model.tree_end, num_features, num_nodes, use_shared, std::numeric_limits::quiet_NaN()); AllReduceBitVectors(&decision_storage, &missing_storage); - dh::LaunchKernel {grid, kBlockThreads, 0, ctx_->CUDACtx()->Stream()}( + // clang-format off + dh::LaunchKernel {grid, kBlockThreads, 0, + ctx_->CUDACtx()->Stream()}( + // clang-format on PredictByBitVectorKernel, d_model.Trees(), - out_preds->DeviceSpan().subspan(batch_offset), d_tree_groups, - decision_bits, missing_bits, d_model.tree_begin, d_model.tree_end, num_rows, num_nodes, - num_group); + out_preds->DeviceSpan().subspan(batch_offset), d_tree_groups, decision_bits, missing_bits, + d_model.tree_begin, d_model.tree_end, num_rows, num_nodes, num_group); batch_offset += batch.Size() * num_group; } @@ -858,8 +864,7 @@ class LaunchConfig { } public: - LaunchConfig(Context const* ctx, bst_feature_t n_features) - : ctx_{ctx}, n_features_{n_features} {} + LaunchConfig(Context const* ctx, bst_feature_t n_features) : ctx_{ctx}, n_features_{n_features} {} template void ForEachBatch(DMatrix* p_fmat, Fn&& fn) { @@ -1055,18 +1060,16 @@ class GPUPredictor : public xgboost::Predictor { } } - LaunchPredict(this->ctx_, false, enc::DeviceColumnsView{}, model, - [&](auto&& cfg, auto&& acc) { - using EncAccessor = std::remove_reference_t; - CHECK((std::is_same_v)); - using LoaderImpl = DeviceAdapterLoader; - using Loader = - typename common::GetValueT::template LoaderType; - cfg.template AllocShmem(); - cfg.template LaunchPredictKernel( - m->Value(), missing, n_features, d_model, acc, 0, &out_preds->predictions); - }); + LaunchPredict(this->ctx_, false, enc::DeviceColumnsView{}, model, [&](auto&& cfg, auto&& acc) { + using EncAccessor = std::remove_reference_t; + CHECK((std::is_same_v)); + using LoaderImpl = DeviceAdapterLoader; + using Loader = + typename common::GetValueT::template LoaderType; + cfg.template AllocShmem(); + cfg.template LaunchPredictKernel(m->Value(), missing, n_features, d_model, acc, 0, + &out_preds->predictions); + }); } [[nodiscard]] bool InplacePredict(std::shared_ptr p_m, gbm::GBTreeModel const& model, @@ -1091,62 +1094,11 @@ class GPUPredictor : public xgboost::Predictor { std::vector const* tree_weights, bool approximate, int, unsigned) const override { xgboost_NVTX_FN_RANGE(); - StringView not_implemented{ - "contribution is not implemented in the GPU predictor, use CPU instead."}; if (approximate) { - LOG(FATAL) << "Approximated " << not_implemented; - } - if (tree_weights != nullptr) { - LOG(FATAL) << "Dart booster feature " << not_implemented; + LOG(FATAL) << "Approximated contribution is not implemented in the GPU predictor, use CPU " + "instead."; } - CHECK(!p_fmat->Info().IsColumnSplit()) - << "Predict contribution support for column-wise data split is not yet implemented."; - dh::safe_cuda(cudaSetDevice(ctx_->Ordinal())); - out_contribs->SetDevice(ctx_->Device()); - tree_end = GetTreeLimit(model.trees, tree_end); - - const int ngroup = model.learner_model_param->num_output_group; - CHECK_NE(ngroup, 0); - // allocate space for (number of features + bias) times the number of rows - size_t contributions_columns = model.learner_model_param->num_feature + 1; // +1 for bias - auto dim_size = contributions_columns * model.learner_model_param->num_output_group; - // Output shape: [n_samples, n_classes, n_features + 1] - out_contribs->Resize(p_fmat->Info().num_row_ * dim_size); - out_contribs->Fill(0.0f); - auto phis = out_contribs->DeviceSpan(); - - dh::device_vector> device_paths; - DeviceModel d_model{this->ctx_->Device(), model, true, 0, tree_end, &this->model_mu_, - CopyViews{this->ctx_}}; - - auto new_enc = - p_fmat->Cats()->NeedRecode() ? p_fmat->Cats()->DeviceView(ctx_) : enc::DeviceColumnsView{}; - - dh::device_vector categories; - ExtractPaths(ctx_, &device_paths, model, d_model, &categories); - - LaunchShap(this->ctx_, new_enc, model, [&](auto&& cfg, auto&& acc) { - using Config = common::GetValueT; - using EncAccessor = typename Config::EncAccessorT; - - cfg.ForEachBatch( - p_fmat, std::forward(acc), [&](auto&& loader, bst_idx_t base_rowid) { - auto begin = dh::tbegin(phis) + base_rowid * dim_size; - gpu_treeshap::GPUTreeShap>( - loader, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis)); - }); - }); - - // Add the base margin term to last column - p_fmat->Info().base_margin_.SetDevice(ctx_->Device()); - const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); - - auto base_score = model.learner_model_param->BaseScore(ctx_); - bst_idx_t n_samples = p_fmat->Info().num_row_; - dh::LaunchN(n_samples * ngroup, ctx_->CUDACtx()->Stream(), [=] __device__(std::size_t idx) { - auto [_, gid] = linalg::UnravelIndex(idx, n_samples, ngroup); - phis[(idx + 1) * contributions_columns - 1] += margin.empty() ? base_score(gid) : margin[idx]; - }); + interpretability::ShapValues(ctx_, p_fmat, out_contribs, model, tree_end, tree_weights, 0, 0); } void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector* out_contribs, @@ -1154,62 +1106,12 @@ class GPUPredictor : public xgboost::Predictor { std::vector const* tree_weights, bool approximate) const override { xgboost_NVTX_FN_RANGE(); - std::string not_implemented{ - "contribution is not implemented in GPU predictor, use cpu instead."}; if (approximate) { - LOG(FATAL) << "Approximated " << not_implemented; - } - if (tree_weights != nullptr) { - LOG(FATAL) << "Dart booster feature " << not_implemented; + LOG(FATAL) << "Approximated contribution is not implemented in GPU predictor, use cpu " + "instead."; } - dh::safe_cuda(cudaSetDevice(ctx_->Ordinal())); - out_contribs->SetDevice(ctx_->Device()); - tree_end = GetTreeLimit(model.trees, tree_end); - - const int ngroup = model.learner_model_param->num_output_group; - CHECK_NE(ngroup, 0); - // allocate space for (number of features + bias) times the number of rows - size_t contributions_columns = model.learner_model_param->num_feature + 1; // +1 for bias - auto dim_size = - contributions_columns * contributions_columns * model.learner_model_param->num_output_group; - out_contribs->Resize(p_fmat->Info().num_row_ * dim_size); - out_contribs->Fill(0.0f); - auto phis = out_contribs->DeviceSpan(); - - dh::device_vector> device_paths; - DeviceModel d_model{this->ctx_->Device(), model, true, 0, tree_end, &this->model_mu_, - CopyViews{this->ctx_}}; - - dh::device_vector categories; - ExtractPaths(ctx_, &device_paths, model, d_model, &categories); - auto new_enc = - p_fmat->Cats()->NeedRecode() ? p_fmat->Cats()->DeviceView(ctx_) : enc::DeviceColumnsView{}; - - LaunchShap(this->ctx_, new_enc, model, [&](auto&& cfg, auto&& acc) { - using Config = common::GetValueT; - using EncAccessor = typename Config::EncAccessorT; - - cfg.ForEachBatch( - p_fmat, std::forward(acc), [&](auto&& loader, bst_idx_t base_rowid) { - auto begin = dh::tbegin(phis) + base_rowid * dim_size; - gpu_treeshap::GPUTreeShapInteractions>( - loader, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis)); - }); - }); - - // Add the base margin term to last column - p_fmat->Info().base_margin_.SetDevice(ctx_->Device()); - const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); - - auto base_score = model.learner_model_param->BaseScore(ctx_); - size_t n_features = model.learner_model_param->num_feature; - bst_idx_t n_samples = p_fmat->Info().num_row_; - dh::LaunchN(n_samples * ngroup, ctx_->CUDACtx()->Stream(), [=] __device__(size_t idx) { - auto [ridx, gidx] = linalg::UnravelIndex(idx, n_samples, ngroup); - phis[gpu_treeshap::IndexPhiInteractions(ridx, ngroup, gidx, n_features, n_features, - n_features)] += - margin.empty() ? base_score(gidx) : margin[idx]; - }); + interpretability::ShapInteractionValues(ctx_, p_fmat, out_contribs, model, tree_end, + tree_weights, approximate); } void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* predictions, diff --git a/tests/cpp/predictor/test_shap.cc b/tests/cpp/predictor/test_shap.cc index a13f88b3c6a1..8bb20f125755 100644 --- a/tests/cpp/predictor/test_shap.cc +++ b/tests/cpp/predictor/test_shap.cc @@ -10,6 +10,7 @@ #include // for Json #include // for Learner #include // for Vector +#include // for ObjFunction #include #include // for unique_ptr @@ -50,23 +51,23 @@ Args BaseParams(Context const* ctx, std::string objective, std::string max_depth {"device", ctx->IsSycl() ? "cpu" : ctx->DeviceName()}}; } -gbm::GBTreeModel LoadGBTreeModel(Learner* learner, Context const* ctx, +gbm::GBTreeModel LoadGBTreeModel(Learner* learner, Context const* ctx, Args const& model_args, LearnerModelParam* out_param) { Json model{Object{}}; learner->SaveModel(&model); - CHECK(IsA(model)) << model; auto const& model_obj = get(model); - auto learner_it = model_obj.find("learner"); - CHECK(learner_it != model_obj.cend()) << model; - CHECK(IsA(learner_it->second)) << model; - auto const& learner_obj = get(learner_it->second); - + auto const& learner_obj = get(model_obj.at("learner")); auto const& lmp = get(learner_obj.at("learner_model_param")); - auto const& num_feature = get(lmp.at("num_feature")); - auto const& num_class = get(lmp.at("num_class")); - auto const& num_target = get(lmp.at("num_target")); - auto const& base_score_str = get(lmp.at("base_score")); + + auto get_or = [&](char const* key, std::string dft) { + auto it = lmp.find(key); + return it == lmp.cend() ? dft : get(it->second); + }; + auto const& num_feature = get_or("num_feature", "0"); + auto const& num_class = get_or("num_class", "0"); + auto const& num_target = get_or("num_target", "1"); + auto const& base_score_str = get_or("base_score", "0"); common::ParamArray base_score_arr{"base_score"}; std::stringstream ss; @@ -78,6 +79,17 @@ gbm::GBTreeModel LoadGBTreeModel(Learner* learner, Context const* ctx, auto& h_base = base_score_vec.Data()->HostVector(); h_base.assign(base_score_arr.cbegin(), base_score_arr.cend()); + std::string objective{"reg:squarederror"}; + for (auto const& kv : model_args) { + if (kv.first == "objective") { + objective = kv.second; + break; + } + } + auto obj = std::unique_ptr(ObjFunction::Create(objective, ctx)); + obj->Configure(model_args); + obj->ProbToMargin(&base_score_vec); + auto n_features = static_cast(std::stol(num_feature)); auto n_classes = static_cast(std::stol(num_class)); auto n_targets = static_cast(std::stol(num_target)); @@ -86,10 +98,8 @@ gbm::GBTreeModel LoadGBTreeModel(Learner* learner, Context const* ctx, MultiStrategy::kOneOutputPerTree}; gbm::GBTreeModel gbtree{out_param, ctx}; - auto gbm_it = learner_obj.find("gradient_booster"); - CHECK(gbm_it != learner_obj.cend()) << model; - CHECK(IsA(gbm_it->second)) << model; - gbtree.LoadModel(gbm_it->second); + auto const& gbm_obj = get(learner_obj.at("gradient_booster")); + gbtree.LoadModel(gbm_obj.at("model")); return gbtree; } } // namespace @@ -184,7 +194,7 @@ void CheckShapOutput(DMatrix* dmat, Args const& model_args) { size_t const n_outputs = margin_predt.HostVector().size() / kRows; LearnerModelParam mparam; - auto gbtree = LoadGBTreeModel(learner.get(), dmat->Ctx(), &mparam); + auto gbtree = LoadGBTreeModel(learner.get(), dmat->Ctx(), model_args, &mparam); HostDeviceVector shap_values; interpretability::ShapValues(dmat->Ctx(), p_dmat.get(), &shap_values, gbtree, 0, nullptr, 0, 0); @@ -259,7 +269,7 @@ TEST(Predictor, ApproxContribsBasic) { learner->Predict(dmat, true, &margin_predt, 0, 0, false, false, false, false, false); LearnerModelParam mparam; - auto gbtree = LoadGBTreeModel(learner.get(), dmat->Ctx(), &mparam); + auto gbtree = LoadGBTreeModel(learner.get(), dmat->Ctx(), args, &mparam); HostDeviceVector approx_contribs; interpretability::ApproxFeatureImportance(dmat->Ctx(), dmat.get(), &approx_contribs, gbtree, 0, From e36195865247cc60c1215f62527b077f968d5154 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Wed, 4 Feb 2026 08:46:42 -0800 Subject: [PATCH 03/17] Remove some duplication --- src/predictor/gpu_predictor.cu | 276 --------------------------------- 1 file changed, 276 deletions(-) diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 370807fc8f10..88d5557345ba 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -1,7 +1,6 @@ /** * Copyright 2017-2026, XGBoost Contributors */ -#include #include #include #include @@ -341,221 +340,6 @@ struct CopyViews { using DeviceModel = GBTreeModelView; } // namespace -struct ShapSplitCondition { - ShapSplitCondition() = default; - XGBOOST_DEVICE - ShapSplitCondition(float feature_lower_bound, float feature_upper_bound, bool is_missing_branch, - common::CatBitField cats) - : feature_lower_bound(feature_lower_bound), - feature_upper_bound(feature_upper_bound), - is_missing_branch(is_missing_branch), - categories{std::move(cats)} { - assert(feature_lower_bound <= feature_upper_bound); - } - - /*! Feature values >= lower and < upper flow down this path. */ - float feature_lower_bound; - float feature_upper_bound; - /*! Feature value set to true flow down this path. */ - common::CatBitField categories; - /*! Do missing values flow down this path? */ - bool is_missing_branch; - - // Does this instance flow down this path? - [[nodiscard]] XGBOOST_DEVICE bool EvaluateSplit(float x) const { - // is nan - if (isnan(x)) { - return is_missing_branch; - } - if (categories.Capacity() != 0) { - auto cat = static_cast(x); - return categories.Check(cat); - } else { - return x >= feature_lower_bound && x < feature_upper_bound; - } - } - - // the &= op in bitfiled is per cuda thread, this one loops over the entire - // bitfield. - XGBOOST_DEVICE static common::CatBitField Intersect(common::CatBitField l, - common::CatBitField r) { - if (l.Data() == r.Data()) { - return l; - } - if (l.Capacity() > r.Capacity()) { - cuda::std::swap(l, r); - } - for (size_t i = 0; i < r.Bits().size(); ++i) { - l.Bits()[i] &= r.Bits()[i]; - } - return l; - } - - // Combine two split conditions on the same feature - XGBOOST_DEVICE void Merge(ShapSplitCondition other) { - // Combine duplicate features - if (categories.Capacity() != 0 || other.categories.Capacity() != 0) { - categories = Intersect(categories, other.categories); - } else { - feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound); - feature_upper_bound = min(feature_upper_bound, other.feature_upper_bound); - } - is_missing_branch = is_missing_branch && other.is_missing_branch; - } -}; - -struct PathInfo { - std::size_t length; - // Node index in tree. - // -1 if not a leaf (internal split node) - bst_node_t nidx; - bst_tree_t tree_idx; - - [[nodiscard]] XGBOOST_DEVICE bool IsLeaf() const { return nidx != -1; } -}; -static_assert(sizeof(PathInfo) == 16); - -auto MakeTreeSegments(Context const* ctx, bst_tree_t tree_begin, bst_tree_t tree_end, - gbm::GBTreeModel const& model) { - // Copy decision trees to device - auto tree_segments = HostDeviceVector({}, ctx->Device()); - auto& h_tree_segments = tree_segments.HostVector(); - h_tree_segments.reserve((tree_end - tree_begin) + 1); - std::size_t sum = 0; - h_tree_segments.push_back(sum); - for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { - auto const& p_tree = model.trees.at(tree_idx); - CHECK(!p_tree->IsMultiTarget()) << " SHAP " << MTNotImplemented(); - sum += p_tree->Size(); - h_tree_segments.push_back(sum); - } - return tree_segments; -} - -// Transform model into path element form for GPUTreeShap -void ExtractPaths(Context const* ctx, - dh::device_vector>* paths, - gbm::GBTreeModel const& h_model, DeviceModel const& d_model, - dh::device_vector* path_categories) { - curt::SetDevice(ctx->Ordinal()); - - // Path length and tree index for all leaf nodes - dh::caching_device_vector info(d_model.n_nodes); - auto d_trees = d_model.Trees(); // subset of trees - auto tree_segments = MakeTreeSegments(ctx, d_model.tree_begin, d_model.tree_end, h_model); - CHECK_EQ(tree_segments.ConstHostVector().back(), d_model.n_nodes); - auto d_tree_segments = tree_segments.ConstDeviceSpan(); - - auto path_it = dh::MakeIndexTransformIter( - cuda::proclaim_return_type([=] __device__(size_t idx) -> PathInfo { - bst_tree_t const tree_idx = dh::SegmentId(d_tree_segments, idx); - bst_node_t const nidx = idx - d_tree_segments[tree_idx]; - auto const& tree = cuda::std::get(d_trees[tree_idx]); - if (!tree.IsLeaf(nidx) || tree.IsDeleted(nidx)) { - // -1 if it's an internal split node - return PathInfo{0, -1, 0}; - } - // Get the path length for leaf - std::size_t path_length = 1; - auto iter_nidx = nidx; - while (!tree.IsRoot(iter_nidx)) { - iter_nidx = tree.Parent(iter_nidx); - path_length++; - } - return PathInfo{path_length, nidx, tree_idx}; - })); - auto end = thrust::copy_if( - ctx->CUDACtx()->CTP(), path_it, path_it + d_model.n_nodes, info.begin(), - cuda::proclaim_return_type([=] __device__(PathInfo const& e) { return e.IsLeaf(); })); - - info.resize(end - info.begin()); - using LenT = decltype(std::declval().length); - auto length_iterator = dh::MakeTransformIterator( - info.begin(), cuda::proclaim_return_type( - [=] __device__(PathInfo const& info) { return info.length; })); - dh::caching_device_vector path_segments(info.size() + 1); - thrust::exclusive_scan(ctx->CUDACtx()->CTP(), length_iterator, length_iterator + info.size() + 1, - path_segments.begin()); - - paths->resize(path_segments.back()); - - auto d_paths = dh::ToSpan(*paths); - auto d_info = info.data().get(); - auto d_tree_groups = d_model.tree_groups; - auto d_path_segments = path_segments.data().get(); - - std::size_t max_cat = 0; - if (std::any_of(h_model.trees.cbegin(), h_model.trees.cend(), - [](auto const& p_tree) { return p_tree->HasCategoricalSplit(); })) { - auto max_elem_it = dh::MakeIndexTransformIter([=] __device__(std::size_t i) -> std::size_t { - auto tree_idx = dh::SegmentId(d_tree_segments, i); - auto nidx = i - d_tree_segments[tree_idx]; - return cuda::std::get(d_trees[tree_idx]) - .GetCategoriesMatrix() - .node_ptr[nidx] - .size; - }); - auto max_cat_it = - thrust::max_element(ctx->CUDACtx()->CTP(), max_elem_it, max_elem_it + d_model.n_nodes); - dh::CachingDeviceUVector d_max_cat(1); - auto s_max_cat = dh::ToSpan(d_max_cat); - dh::LaunchN(1, ctx->CUDACtx()->Stream(), - [=] __device__(std::size_t) { s_max_cat[0] = *max_cat_it; }); - dh::safe_cuda( - cudaMemcpy(&max_cat, s_max_cat.data(), s_max_cat.size_bytes(), cudaMemcpyDeviceToHost)); - CHECK_GE(max_cat, 1); - path_categories->resize(max_cat * paths->size()); - } - - common::Span d_path_categories = dh::ToSpan(*path_categories); - - dh::LaunchN(info.size(), ctx->CUDACtx()->Stream(), [=] __device__(size_t idx) { - auto path_info = d_info[idx]; - auto tree = cuda::std::get(d_trees[path_info.tree_idx]); - std::int32_t group = d_tree_groups[path_info.tree_idx]; - auto child_nidx = path_info.nidx; - - float v = tree.LeafValue(child_nidx); - const float inf = std::numeric_limits::infinity(); - size_t output_position = d_path_segments[idx + 1] - 1; - - while (!tree.IsRoot(child_nidx)) { - auto parent_nidx = tree.Parent(child_nidx); - double child_cover = tree.SumHess(child_nidx); - double parent_cover = tree.SumHess(parent_nidx); - double zero_fraction = child_cover / parent_cover; - - bool is_left_path = tree.LeftChild(parent_nidx) == child_nidx; - bool is_missing_path = (!tree.DefaultLeft(parent_nidx) && !is_left_path) || - (tree.DefaultLeft(parent_nidx) && is_left_path); - - float lower_bound = -inf; - float upper_bound = inf; - common::CatBitField bits; - if (common::IsCat(tree.cats.split_type, tree.Parent(child_nidx))) { - auto path_cats = d_path_categories.subspan(max_cat * output_position, max_cat); - auto node_cats = tree.NodeCats(tree.Parent(child_nidx)); - SPAN_CHECK(path_cats.size() >= node_cats.size()); - for (size_t i = 0; i < node_cats.size(); ++i) { - path_cats[i] = is_left_path ? ~node_cats[i] : node_cats[i]; - } - bits = common::CatBitField{path_cats}; - } else { - lower_bound = is_left_path ? -inf : tree.SplitCond(parent_nidx); - upper_bound = is_left_path ? tree.SplitCond(parent_nidx) : inf; - } - d_paths[output_position--] = gpu_treeshap::PathElement{ - idx, tree.SplitIndex(parent_nidx), - group, ShapSplitCondition{lower_bound, upper_bound, is_missing_path, bits}, - zero_fraction, v}; - - child_nidx = parent_nidx; - } - // Root node has feature -1 - d_paths[output_position] = {idx, -1, group, ShapSplitCondition{-inf, inf, false, {}}, 1.0, v}; - }); -} - namespace { template [[nodiscard]] std::size_t SharedMemoryBytes(std::size_t n_features, std::size_t max_shmem_bytes) { @@ -782,23 +566,6 @@ class ColumnSplitHelper { using cuda_impl::MakeCatAccessor; -template -struct ShapSparsePageLoader { - public: - using SupportShmemLoad = std::false_type; - - SparsePageView data; - EncAccessor acc; - - template - [[nodiscard]] __device__ float GetElement(bst_idx_t ridx, Fidx fidx) const { - auto fvalue = data.GetElement(ridx, fidx); - return acc(fvalue, fidx); - } - [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return data.NumRows(); } - [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return data.NumCols(); } -}; - // Provide configuration for launching the predict kernel. template class LaunchConfig { @@ -893,35 +660,6 @@ class LaunchConfig { } } } - // Used by the SHAP methods. - template - void ForEachBatch(DMatrix* p_fmat, EncAccessor&& acc, Fn&& fn) { - if (p_fmat->PageExists()) { - for (auto& page : p_fmat->GetBatches()) { - // Shap kernel doesn't use shared memory to stage data. - SparsePageView batch{ctx_, page, n_features_}; - auto loader = ShapSparsePageLoader{batch, acc}; - fn(std::move(loader), page.base_rowid); - } - } else { - p_fmat->Info().feature_types.SetDevice(ctx_->Device()); - auto feature_types = p_fmat->Info().feature_types.ConstDeviceSpan(); - - for (auto const& page : p_fmat->GetBatches(ctx_, StaticBatch(true))) { - page.Impl()->Visit(ctx_, feature_types, [&](auto&& batch) { - using Acc = std::remove_reference_t; - // No shared memory use for ellpack - auto loader = EllpackLoader{batch, - /*use_shared=*/false, - this->n_features_, - batch.NumRows(), - std::numeric_limits::quiet_NaN(), - std::forward(acc)}; - fn(std::move(loader), batch.base_rowid); - }); - } - } - } }; template @@ -952,20 +690,6 @@ void LaunchPredict(Context const* ctx, bool is_dense, enc::DeviceColumnsView con } } -template -void LaunchShap(Context const* ctx, enc::DeviceColumnsView const& new_enc, - gbm::GBTreeModel const& model, Kernel&& launch) { - if (model.Cats() && model.Cats()->HasCategorical() && new_enc.HasCategorical()) { - auto [acc, mapping] = MakeCatAccessor(ctx, new_enc, model.Cats()); - auto cfg = - LaunchConfig{ctx, model.learner_model_param->num_feature}; - launch(std::move(cfg), std::move(acc)); - } else { - auto cfg = - LaunchConfig{ctx, model.learner_model_param->num_feature}; - launch(std::move(cfg), NoOpAccessor{}); - } -} } // anonymous namespace class GPUPredictor : public xgboost::Predictor { From 71480865efe06991d200ead3e1614f59fa2e62b2 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Mon, 9 Feb 2026 06:30:56 -0800 Subject: [PATCH 04/17] Fix tests --- src/interpretability/shap.cc | 24 +++++------ src/interpretability/shap.cu | 68 ++++++++++++++++---------------- src/interpretability/shap.h | 67 +++++++++++++++---------------- tests/cpp/predictor/test_shap.cc | 10 ++++- 4 files changed, 87 insertions(+), 82 deletions(-) diff --git a/src/interpretability/shap.cc b/src/interpretability/shap.cc index f2dee5eb8a48..619b9c89d95e 100644 --- a/src/interpretability/shap.cc +++ b/src/interpretability/shap.cc @@ -136,10 +136,10 @@ class GHistIndexMatrixView { }; } // namespace -void ShapValuesCPU(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, - gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights, int condition, - unsigned condition_feature) { +namespace cpu_impl { +void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, + gbm::GBTreeModel const &model, bst_tree_t tree_end, + std::vector const *tree_weights, int condition, unsigned condition_feature) { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict contribution" << MTNotImplemented(); CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict contribution support for column-wise data split is not yet implemented."; @@ -248,10 +248,9 @@ void ShapValuesCPU(Context const *ctx, DMatrix *p_fmat, HostDeviceVector } } -void ApproxFeatureImportanceCPU(Context const *ctx, DMatrix *p_fmat, - HostDeviceVector *out_contribs, - gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights) { +void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, + bst_tree_t tree_end, std::vector const *tree_weights) { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict contribution" << MTNotImplemented(); CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict contribution support for column-wise data split is not yet implemented."; @@ -354,10 +353,10 @@ void ApproxFeatureImportanceCPU(Context const *ctx, DMatrix *p_fmat, } } -void ShapInteractionValuesCPU(Context const *ctx, DMatrix *p_fmat, - HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, std::vector const *tree_weights, - bool approximate) { +void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, + bst_tree_t tree_end, std::vector const *tree_weights, + bool approximate) { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict interaction contribution" << MTNotImplemented(); CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict interaction contribution support for " @@ -414,4 +413,5 @@ void ShapInteractionValuesCPU(Context const *ctx, DMatrix *p_fmat, } } } +} // namespace cpu_impl } // namespace xgboost::interpretability diff --git a/src/interpretability/shap.cu b/src/interpretability/shap.cu index 1e295853224e..0dc6dccf7687 100644 --- a/src/interpretability/shap.cu +++ b/src/interpretability/shap.cu @@ -5,8 +5,8 @@ #include #include #include +#include #include -#include #include #include @@ -44,8 +44,8 @@ namespace xgboost::interpretability::cuda_impl { namespace { -using cuda_impl::StaticBatch; using predictor::GBTreeModelView; +using ::xgboost::cuda_impl::StaticBatch; struct SparsePageView { common::Span d_data; @@ -115,31 +115,29 @@ struct EllpackLoader { public: using SupportShmemLoad = std::false_type; - private: - EncAccessor acc_; - Accessor data_; - - public: - bool use_shared{false}; - bst_feature_t num_features; - bst_idx_t num_rows; - float missing; + Accessor matrix; + EncAccessor acc; XGBOOST_DEVICE EllpackLoader(Accessor m, bool /*use_shared*/, bst_feature_t n_features, bst_idx_t num_rows, float missing, EncAccessor&& acc) - : acc_{std::forward(acc)}, - data_{m}, - num_features{n_features}, - num_rows{num_rows}, - missing{missing} {} + : matrix{std::move(m)}, acc{std::forward(acc)} {} template [[nodiscard]] __device__ float GetElement(bst_idx_t ridx, Fidx fidx) const { - auto value = data_.GetElement(ridx, fidx); - return acc_(value, fidx); + auto gidx = matrix.template GetBinIndex(ridx, fidx); + if (gidx == -1) { + return std::numeric_limits::quiet_NaN(); + } + if (common::IsCat(matrix.feature_types, fidx)) { + return this->acc(matrix.gidx_fvalue_map[gidx], fidx); + } + if (gidx == matrix.feature_segments[fidx]) { + return matrix.min_fvalue[fidx]; + } + return matrix.gidx_fvalue_map[gidx - 1]; } - [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return num_rows; } - [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return num_features; } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return matrix.n_rows; } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return matrix.NumFeatures(); } }; using TreeViewVar = cuda::std::variant; @@ -358,20 +356,20 @@ void ExtractPaths(Context const* ctx, }); } -template +template class LaunchConfig { public: - using EncAccessorT = EncAccessorT; + using EncAccessorT = EncAccessor; explicit LaunchConfig(Context const* ctx, bst_feature_t n_features) : ctx_{ctx}, n_features_{n_features} {} template - void ForEachBatch(DMatrix* p_fmat, EncAccessorT&& acc, Fn&& fn) { + void ForEachBatch(DMatrix* p_fmat, EncAccessor&& acc, Fn&& fn) { if (p_fmat->PageExists()) { for (auto& page : p_fmat->GetBatches()) { SparsePageView batch{ctx_, page, n_features_}; - auto loader = ShapSparsePageLoader{batch, acc}; + auto loader = ShapSparsePageLoader{batch, acc}; fn(std::move(loader), page.base_rowid); } } else { @@ -386,7 +384,7 @@ class LaunchConfig { this->n_features_, batch.NumRows(), std::numeric_limits::quiet_NaN(), - std::forward(acc)}; + std::forward(acc)}; fn(std::move(loader), batch.base_rowid); }); } @@ -402,7 +400,7 @@ template void LaunchShap(Context const* ctx, enc::DeviceColumnsView const& new_enc, gbm::GBTreeModel const& model, Kernel&& launch) { if (model.Cats() && model.Cats()->HasCategorical() && new_enc.HasCategorical()) { - auto [acc, mapping] = cuda_impl::MakeCatAccessor(ctx, new_enc, model.Cats()); + auto [acc, mapping] = ::xgboost::cuda_impl::MakeCatAccessor(ctx, new_enc, model.Cats()); auto cfg = LaunchConfig{ctx, model.learner_model_param->num_feature}; launch(std::move(cfg), std::move(acc)); @@ -414,9 +412,9 @@ void LaunchShap(Context const* ctx, enc::DeviceColumnsView const& new_enc, } } // namespace -void ShapValuesCUDA(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* out_contribs, - gbm::GBTreeModel const& model, bst_tree_t tree_end, - std::vector const* tree_weights, int, unsigned) { +void ShapValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* out_contribs, + gbm::GBTreeModel const& model, bst_tree_t tree_end, + std::vector const* tree_weights, int, unsigned) { xgboost_NVTX_FN_RANGE(); StringView not_implemented{ "contribution is not implemented in the GPU predictor, use CPU instead."}; @@ -469,10 +467,10 @@ void ShapValuesCUDA(Context const* ctx, DMatrix* p_fmat, HostDeviceVector }); } -void ShapInteractionValuesCUDA(Context const* ctx, DMatrix* p_fmat, - HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, - bst_tree_t tree_end, std::vector const* tree_weights, - bool approximate) { +void ShapInteractionValues(Context const* ctx, DMatrix* p_fmat, + HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, + bst_tree_t tree_end, std::vector const* tree_weights, + bool approximate) { xgboost_NVTX_FN_RANGE(); std::string not_implemented{"contribution is not implemented in GPU predictor, use cpu instead."}; if (approximate) { @@ -528,8 +526,8 @@ void ShapInteractionValuesCUDA(Context const* ctx, DMatrix* p_fmat, }); } -void ApproxFeatureImportanceCUDA(Context const* ctx, DMatrix*, HostDeviceVector*, - gbm::GBTreeModel const&, bst_tree_t, std::vector const*) { +void ApproxFeatureImportance(Context const* ctx, DMatrix*, HostDeviceVector*, + gbm::GBTreeModel const&, bst_tree_t, std::vector const*) { StringView not_implemented{ "contribution is not implemented in the GPU predictor, use CPU instead."}; LOG(FATAL) << "Approximated " << not_implemented; diff --git a/src/interpretability/shap.h b/src/interpretability/shap.h index 43bf98a48c42..2c7c71455e60 100644 --- a/src/interpretability/shap.h +++ b/src/interpretability/shap.h @@ -14,34 +14,34 @@ struct GBTreeModel; } // namespace xgboost::gbm namespace xgboost::interpretability { -void ShapValuesCPU(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, - gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights, int condition, - unsigned condition_feature); +namespace cpu_impl { +void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, + gbm::GBTreeModel const &model, bst_tree_t tree_end, + std::vector const *tree_weights, int condition, unsigned condition_feature); -void ApproxFeatureImportanceCPU(Context const *ctx, DMatrix *p_fmat, - HostDeviceVector *out_contribs, - gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights); +void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, + bst_tree_t tree_end, std::vector const *tree_weights); -void ShapInteractionValuesCPU(Context const *ctx, DMatrix *p_fmat, - HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, std::vector const *tree_weights, - bool approximate); +void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, + bst_tree_t tree_end, std::vector const *tree_weights, + bool approximate); +} // namespace cpu_impl #if defined(XGBOOST_USE_CUDA) -void ShapValuesCUDA(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, - gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights, int condition, - unsigned condition_feature); -void ApproxFeatureImportanceCUDA(Context const *ctx, DMatrix *p_fmat, - HostDeviceVector *out_contribs, - gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights); -void ShapInteractionValuesCUDA(Context const *ctx, DMatrix *p_fmat, - HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, std::vector const *tree_weights, - bool approximate); +namespace cuda_impl { +void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, + gbm::GBTreeModel const &model, bst_tree_t tree_end, + std::vector const *tree_weights, int condition, unsigned condition_feature); +void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, + bst_tree_t tree_end, std::vector const *tree_weights); +void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, + bst_tree_t tree_end, std::vector const *tree_weights, + bool approximate); +} // namespace cuda_impl #endif // defined(XGBOOST_USE_CUDA) inline void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, @@ -50,13 +50,13 @@ inline void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVectorIsCUDA()) { - ShapValuesCUDA(ctx, p_fmat, out_contribs, model, tree_end, tree_weights, condition, - condition_feature); + cuda_impl::ShapValues(ctx, p_fmat, out_contribs, model, tree_end, tree_weights, condition, + condition_feature); return; } #endif // defined(XGBOOST_USE_CUDA) - ShapValuesCPU(ctx, p_fmat, out_contribs, model, tree_end, tree_weights, condition, - condition_feature); + cpu_impl::ShapValues(ctx, p_fmat, out_contribs, model, tree_end, tree_weights, condition, + condition_feature); } inline void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, @@ -65,11 +65,11 @@ inline void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, std::vector const *tree_weights) { #if defined(XGBOOST_USE_CUDA) if (ctx->IsCUDA()) { - ApproxFeatureImportanceCUDA(ctx, p_fmat, out_contribs, model, tree_end, tree_weights); + cuda_impl::ApproxFeatureImportance(ctx, p_fmat, out_contribs, model, tree_end, tree_weights); return; } #endif // defined(XGBOOST_USE_CUDA) - ApproxFeatureImportanceCPU(ctx, p_fmat, out_contribs, model, tree_end, tree_weights); + cpu_impl::ApproxFeatureImportance(ctx, p_fmat, out_contribs, model, tree_end, tree_weights); } inline void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, @@ -78,11 +78,12 @@ inline void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, std::vector const *tree_weights, bool approximate) { #if defined(XGBOOST_USE_CUDA) if (ctx->IsCUDA()) { - ShapInteractionValuesCUDA(ctx, p_fmat, out_contribs, model, tree_end, tree_weights, - approximate); + cuda_impl::ShapInteractionValues(ctx, p_fmat, out_contribs, model, tree_end, tree_weights, + approximate); return; } #endif // defined(XGBOOST_USE_CUDA) - ShapInteractionValuesCPU(ctx, p_fmat, out_contribs, model, tree_end, tree_weights, approximate); + cpu_impl::ShapInteractionValues(ctx, p_fmat, out_contribs, model, tree_end, tree_weights, + approximate); } } // namespace xgboost::interpretability diff --git a/tests/cpp/predictor/test_shap.cc b/tests/cpp/predictor/test_shap.cc index 8bb20f125755..1a36e90732eb 100644 --- a/tests/cpp/predictor/test_shap.cc +++ b/tests/cpp/predictor/test_shap.cc @@ -89,13 +89,19 @@ gbm::GBTreeModel LoadGBTreeModel(Learner* learner, Context const* ctx, Args cons auto obj = std::unique_ptr(ObjFunction::Create(objective, ctx)); obj->Configure(model_args); obj->ProbToMargin(&base_score_vec); + // Keep both host/device views readable, matching LearnerModelParam invariants. + std::as_const(base_score_vec).HostView(); + if (!ctx->Device().IsCPU()) { + std::as_const(base_score_vec).View(ctx->Device()); + } auto n_features = static_cast(std::stol(num_feature)); auto n_classes = static_cast(std::stol(num_class)); auto n_targets = static_cast(std::stol(num_target)); auto n_groups = static_cast(std::max(n_classes, n_targets)); - *out_param = LearnerModelParam{n_features, std::move(base_score_vec), n_groups, n_targets, - MultiStrategy::kOneOutputPerTree}; + LearnerModelParam tmp{n_features, std::move(base_score_vec), n_groups, n_targets, + MultiStrategy::kOneOutputPerTree}; + out_param->Copy(tmp); gbm::GBTreeModel gbtree{out_param, ctx}; auto const& gbm_obj = get(learner_obj.at("gradient_booster")); From 8e75d36fe948f5f3ab390087780b499879b28325 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 10 Feb 2026 01:07:25 -0800 Subject: [PATCH 05/17] Faster tests --- tests/cpp/predictor/test_shap.cc | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/cpp/predictor/test_shap.cc b/tests/cpp/predictor/test_shap.cc index 1a36e90732eb..cf4358bef610 100644 --- a/tests/cpp/predictor/test_shap.cc +++ b/tests/cpp/predictor/test_shap.cc @@ -116,31 +116,31 @@ std::vector BuildShapTestCases(Context const* ctx) { { // small dense, shallow tree - auto dmat = RandomDataGenerator(32, 6, 0.0).Device(device).GenerateDMatrix(); + auto dmat = RandomDataGenerator(16, 4, 0.0).Device(device).GenerateDMatrix(); SetLabels(dmat.get(), 1); cases.emplace_back(dmat, BaseParams(ctx, "reg:squarederror", "2")); } { // medium dense training DMatrix, moderate depth - auto dmat = RandomDataGenerator(512, 10, 0.0).Device(device).GenerateDMatrix(true); + auto dmat = RandomDataGenerator(64, 6, 0.0).Device(device).GenerateDMatrix(true); SetLabels(dmat.get(), 1); - cases.emplace_back(dmat, BaseParams(ctx, "reg:squarederror", "6")); + cases.emplace_back(dmat, BaseParams(ctx, "reg:squarederror", "4")); } { // quantile DMatrix with explicit bins, deeper tree auto dmat = - RandomDataGenerator(2048, 12, 0.0).Bins(64).Device(device).GenerateQuantileDMatrix(false); + RandomDataGenerator(128, 8, 0.0).Bins(32).Device(device).GenerateQuantileDMatrix(false); SetLabels(dmat.get(), 1); - auto args = BaseParams(ctx, "reg:squarederror", "8"); - args.emplace_back("max_bin", "64"); + auto args = BaseParams(ctx, "reg:squarederror", "5"); + args.emplace_back("max_bin", "32"); cases.emplace_back(dmat, std::move(args)); } { // external memory quantile DMatrix, moderate depth - bst_bin_t max_bin{64}; + bst_bin_t max_bin{32}; auto dmat = RandomDataGenerator(4096, 10, 0.0) .Batches(2) .Bins(max_bin) @@ -154,30 +154,30 @@ std::vector BuildShapTestCases(Context const* ctx) { { // external memory sparse page DMatrix, moderate depth - auto dmat = RandomDataGenerator(4096, 10, 0.0) + auto dmat = RandomDataGenerator(256, 8, 0.0) .Batches(2) .Device(device) .GenerateSparsePageDMatrix("shap_extmem", true); SetLabels(dmat.get(), 1); - cases.emplace_back(dmat, BaseParams(ctx, "reg:squarederror", "6")); + cases.emplace_back(dmat, BaseParams(ctx, "reg:squarederror", "4")); } { // multi-class dense training DMatrix, medium depth bst_target_t n_classes{3}; auto dmat = - RandomDataGenerator(256, 8, 0.0).Classes(n_classes).Device(device).GenerateDMatrix(true); + RandomDataGenerator(64, 6, 0.0).Classes(n_classes).Device(device).GenerateDMatrix(true); SetLabels(dmat.get(), n_classes); - auto args = BaseParams(ctx, "multi:softprob", "4"); + auto args = BaseParams(ctx, "multi:softprob", "3"); args.emplace_back("num_class", std::to_string(n_classes)); cases.emplace_back(dmat, std::move(args)); } { - // large dense, deeper tree and classification objective - auto dmat = RandomDataGenerator(10000, 12, 0.0).Device(device).GenerateDMatrix(); + // compact dense classification case to keep runtime bounded + auto dmat = RandomDataGenerator(256, 8, 0.0).Device(device).GenerateDMatrix(); SetLabels(dmat.get(), 1); - cases.emplace_back(dmat, BaseParams(ctx, "binary:logistic", "10")); + cases.emplace_back(dmat, BaseParams(ctx, "binary:logistic", "4")); } return cases; @@ -191,7 +191,7 @@ void CheckShapOutput(DMatrix* dmat, Args const& model_args) { std::unique_ptr learner{Learner::Create({p_dmat})}; learner->SetParams(model_args); learner->Configure(); - for (size_t i = 0; i < 5; ++i) { + for (size_t i = 0; i < 2; ++i) { learner->UpdateOneIter(i, p_dmat); } From 9630ca19b1a2bcb9db958246bbd2d8fc7bb2effa Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Wed, 11 Feb 2026 03:13:13 -0800 Subject: [PATCH 06/17] Use span --- include/xgboost/predictor.h | 12 +++++----- src/gbm/gbtree.cc | 38 ++++++++++++++------------------ src/gbm/gbtree.h | 36 ++++++++++++++---------------- src/interpretability/shap.cc | 23 +++++++++++++------ src/interpretability/shap.cu | 10 ++++----- src/interpretability/shap.h | 19 ++++++++-------- src/predictor/cpu_predictor.cc | 6 ++--- src/predictor/gpu_predictor.cu | 4 ++-- tests/cpp/predictor/test_shap.cc | 6 ++--- 9 files changed, 78 insertions(+), 76 deletions(-) diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index 020e0a59d1e8..695abca5c35c 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -12,6 +12,7 @@ #include #include #include +#include #include // for function #include // for shared_ptr @@ -156,14 +157,14 @@ class Predictor { virtual void PredictContribution(DMatrix* dmat, HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, bst_tree_t tree_end = 0, - std::vector const* tree_weights = nullptr, + common::Span tree_weights = {}, bool approximate = false, int condition = 0, unsigned condition_feature = 0) const = 0; virtual void PredictInteractionContributions(DMatrix* dmat, HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, bst_tree_t tree_end = 0, - std::vector const* tree_weights = nullptr, + common::Span tree_weights = {}, bool approximate = false) const = 0; /** @@ -181,8 +182,7 @@ class Predictor { struct PredictorReg : public dmlc::FunctionRegEntryBase> {}; -#define XGBOOST_REGISTER_PREDICTOR(UniqueId, Name) \ - static DMLC_ATTRIBUTE_UNUSED ::xgboost::PredictorReg& \ - __make_##PredictorReg##_##UniqueId##__ = \ - ::dmlc::Registry<::xgboost::PredictorReg>::Get()->__REGISTER__(Name) +#define XGBOOST_REGISTER_PREDICTOR(UniqueId, Name) \ + static DMLC_ATTRIBUTE_UNUSED ::xgboost::PredictorReg& __make_##PredictorReg##_##UniqueId##__ = \ + ::dmlc::Registry<::xgboost::PredictorReg>::Get()->__REGISTER__(Name) } // namespace xgboost diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 06f89242b1d0..f621c6a4e20c 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -112,8 +112,7 @@ void GBTree::Configure(Args const& cfg) { #if defined(XGBOOST_USE_SYCL) if (!sycl_predictor_) { - sycl_predictor_ = - std::unique_ptr(Predictor::Create("sycl_predictor", this->ctx_)); + sycl_predictor_ = std::unique_ptr(Predictor::Create("sycl_predictor", this->ctx_)); } sycl_predictor_->Configure(cfg); #endif // defined(XGBOOST_USE_SYCL) @@ -639,9 +638,9 @@ void GBTree::InplacePredict(std::shared_ptr p_m, float missing, return gpu_predictor_; } else { #if defined(XGBOOST_USE_SYCL) - common::AssertSYCLSupport(); - CHECK(sycl_predictor_); - return sycl_predictor_; + common::AssertSYCLSupport(); + CHECK(sycl_predictor_); + return sycl_predictor_; #endif // defined(XGBOOST_USE_SYCL) } @@ -676,7 +675,6 @@ void GPUDartInplacePredictInc(common::Span /*out_predts*/, common::Spanmodel_.learner_model_param->IsVectorLeaf()) << "dart" << MTNotImplemented(); auto& predictor = this->GetPredictor(training, &p_out_preds->predictions, p_fmat); CHECK(predictor); - predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, - model_); + predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, model_); p_out_preds->version = 0; auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); auto n_groups = model_.learner_model_param->num_output_group; @@ -784,8 +781,8 @@ class Dart : public GBTree { GPUDartPredictInc(p_out_preds->predictions.DeviceSpan(), predts.predictions.DeviceSpan(), w, n_rows, n_groups, grp_idx); } else { - auto &h_out_predts = p_out_preds->predictions.HostVector(); - auto &h_predts = predts.predictions.ConstHostVector(); + auto& h_out_predts = p_out_preds->predictions.HostVector(); + auto& h_predts = predts.predictions.ConstHostVector(); common::ParallelFor(p_fmat->Info().num_row_, ctx_->Threads(), [&](auto ridx) { const size_t offset = ridx * n_groups + grp_idx; h_out_predts[offset] += (h_predts[offset] * w); @@ -890,8 +887,8 @@ class Dart : public GBTree { bst_layer_t layer_begin, bst_layer_t layer_end, bool approximate) override { auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); - cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_, tree_end, &weight_drop_, - approximate); + cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_, tree_end, + common::Span{weight_drop_}, approximate); } void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector* out_contribs, @@ -899,7 +896,8 @@ class Dart : public GBTree { bool approximate) override { auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); cpu_predictor_->PredictInteractionContributions(p_fmat, out_contribs, model_, tree_end, - &weight_drop_, approximate); + common::Span{weight_drop_}, + approximate); } protected: @@ -942,10 +940,8 @@ class Dart : public GBTree { // size_t i = std::discrete_distribution(weight_drop.begin(), // weight_drop.end())(rnd); size_t i = std::discrete_distribution( - weight_drop_.size(), 0., static_cast(weight_drop_.size()), - [this](double x) -> double { - return weight_drop_[static_cast(x)]; - })(rnd); + weight_drop_.size(), 0., static_cast(weight_drop_.size()), + [this](double x) -> double { return weight_drop_[static_cast(x)]; })(rnd); idx_drop_.push_back(i); } } else { diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 739d196769f9..024a0d33d3f6 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -40,10 +40,7 @@ enum class TreeMethod : int { }; // boosting process types -enum class TreeProcessType : int { - kDefault = 0, - kUpdate = 1 -}; +enum class TreeProcessType : int { kDefault = 0, kUpdate = 1 }; // Sampling type for dart weights. enum class DartSampleType : std::int32_t { @@ -72,15 +69,16 @@ struct GBTreeTrainParam : public XGBoostParameter { .set_default(TreeProcessType::kDefault) .add_enum("default", TreeProcessType::kDefault) .add_enum("update", TreeProcessType::kUpdate) - .describe("Whether to run the normal boosting process that creates new trees,"\ - " or to update the trees in an existing model."); + .describe( + "Whether to run the normal boosting process that creates new trees," + " or to update the trees in an existing model."); DMLC_DECLARE_ALIAS(updater_seq, updater); DMLC_DECLARE_FIELD(tree_method) .set_default(TreeMethod::kAuto) - .add_enum("auto", TreeMethod::kAuto) - .add_enum("approx", TreeMethod::kApprox) - .add_enum("exact", TreeMethod::kExact) - .add_enum("hist", TreeMethod::kHist) + .add_enum("auto", TreeMethod::kAuto) + .add_enum("approx", TreeMethod::kApprox) + .add_enum("exact", TreeMethod::kExact) + .add_enum("hist", TreeMethod::kHist) .describe("Choice of tree construction method."); } }; @@ -268,10 +266,9 @@ class GBTree : public GradientBooster { } }); } else { - LOG(FATAL) - << "Unknown feature importance type, expected one of: " - << R"({"weight", "total_gain", "total_cover", "gain", "cover"}, got: )" - << importance_type; + LOG(FATAL) << "Unknown feature importance type, expected one of: " + << R"({"weight", "total_gain", "total_cover", "gain", "cover"}, got: )" + << importance_type; } if (importance_type == "gain" || importance_type == "cover") { for (size_t i = 0; i < gain_map.size(); ++i) { @@ -291,9 +288,8 @@ class GBTree : public GradientBooster { [[nodiscard]] CatContainer const* Cats() const override { return this->model_.Cats(); } - void PredictLeaf(DMatrix* p_fmat, - HostDeviceVector* out_preds, - uint32_t layer_begin, uint32_t layer_end) override { + void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* out_preds, uint32_t layer_begin, + uint32_t layer_end) override { auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); CHECK_EQ(tree_begin, 0) << "Predict leaf supports only iteration end: [0, " "n_iteration), use model slicing instead."; @@ -306,7 +302,7 @@ class GBTree : public GradientBooster { auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); CHECK_EQ(tree_begin, 0) << "Predict contribution supports only iteration end: [0, " "n_iteration), using model slicing instead."; - this->GetPredictor(false)->PredictContribution(p_fmat, out_contribs, model_, tree_end, nullptr, + this->GetPredictor(false)->PredictContribution(p_fmat, out_contribs, model_, tree_end, {}, approximate); } @@ -317,7 +313,7 @@ class GBTree : public GradientBooster { CHECK_EQ(tree_begin, 0) << "Predict interaction contribution supports only iteration end: [0, " "n_iteration), using model slicing instead."; this->GetPredictor(false)->PredictInteractionContributions(p_fmat, out_contribs, model_, - tree_end, nullptr, approximate); + tree_end, {}, approximate); } [[nodiscard]] std::vector DumpModel(const FeatureMap& fmap, bool with_stats, @@ -345,7 +341,7 @@ class GBTree : public GradientBooster { GBTreeTrainParam tparam_; // Tree training parameter tree::TrainParam tree_param_; - bool specified_updater_ {false}; + bool specified_updater_{false}; // the updaters that can be applied to each of tree std::vector> updaters_; // Predictors diff --git a/src/interpretability/shap.cc b/src/interpretability/shap.cc index 619b9c89d95e..2f4633ac7bdd 100644 --- a/src/interpretability/shap.cc +++ b/src/interpretability/shap.cc @@ -25,6 +25,13 @@ namespace xgboost::interpretability { namespace { +void ValidateTreeWeights(common::Span tree_weights, bst_tree_t tree_end) { + if (tree_weights.empty()) { + return; + } + CHECK_GE(tree_weights.size(), static_cast(tree_end)); +} + float FillNodeMeanValues(tree::ScalarTreeView const &tree, bst_node_t nidx, std::vector *mean_values) { float result; @@ -139,13 +146,14 @@ class GHistIndexMatrixView { namespace cpu_impl { void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights, int condition, unsigned condition_feature) { + common::Span tree_weights, int condition, unsigned condition_feature) { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict contribution" << MTNotImplemented(); CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict contribution support for column-wise data split is not yet implemented."; MetaInfo const &info = p_fmat->Info(); // number of valid trees tree_end = predictor::GetTreeLimit(model.trees, tree_end); + ValidateTreeWeights(tree_weights, tree_end); size_t const ncolumns = model.learner_model_param->num_feature + 1; // allocate space for (number of features + bias) times the number of rows std::vector &contribs = out_contribs->HostVector(); @@ -193,7 +201,7 @@ void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *ou condition, condition_feature); for (size_t ci = 0; ci < ncolumns; ++ci) { p_contribs[ci] += - this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); + this_tree_contribs[ci] * (tree_weights.empty() ? 1 : tree_weights[j]); } } if (base_margin.Size() != 0) { @@ -232,7 +240,7 @@ void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *ou condition, condition_feature); for (size_t ci = 0; ci < ncolumns; ++ci) { p_contribs[ci] += - this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); + this_tree_contribs[ci] * (tree_weights.empty() ? 1 : tree_weights[j]); } } if (base_margin.Size() != 0) { @@ -250,12 +258,13 @@ void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *ou void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, std::vector const *tree_weights) { + bst_tree_t tree_end, common::Span tree_weights) { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict contribution" << MTNotImplemented(); CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict contribution support for column-wise data split is not yet implemented."; MetaInfo const &info = p_fmat->Info(); tree_end = predictor::GetTreeLimit(model.trees, tree_end); + ValidateTreeWeights(tree_weights, tree_end); size_t const ncolumns = model.learner_model_param->num_feature + 1; std::vector &contribs = out_contribs->HostVector(); contribs.resize(info.num_row_ * ncolumns * model.learner_model_param->num_output_group); @@ -299,7 +308,7 @@ void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, CalculateApproxContributions(sc_tree, feats, &mean_values[j], &this_tree_contribs); for (size_t ci = 0; ci < ncolumns; ++ci) { p_contribs[ci] += - this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); + this_tree_contribs[ci] * (tree_weights.empty() ? 1 : tree_weights[j]); } } if (base_margin.Size() != 0) { @@ -337,7 +346,7 @@ void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, CalculateApproxContributions(sc_tree, feats, &mean_values[j], &this_tree_contribs); for (size_t ci = 0; ci < ncolumns; ++ci) { p_contribs[ci] += - this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); + this_tree_contribs[ci] * (tree_weights.empty() ? 1 : tree_weights[j]); } } if (base_margin.Size() != 0) { @@ -355,7 +364,7 @@ void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, std::vector const *tree_weights, + bst_tree_t tree_end, common::Span tree_weights, bool approximate) { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict interaction contribution" << MTNotImplemented(); diff --git a/src/interpretability/shap.cu b/src/interpretability/shap.cu index 0dc6dccf7687..2f02234205c8 100644 --- a/src/interpretability/shap.cu +++ b/src/interpretability/shap.cu @@ -414,11 +414,11 @@ void LaunchShap(Context const* ctx, enc::DeviceColumnsView const& new_enc, void ShapValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, bst_tree_t tree_end, - std::vector const* tree_weights, int, unsigned) { + common::Span tree_weights, int, unsigned) { xgboost_NVTX_FN_RANGE(); StringView not_implemented{ "contribution is not implemented in the GPU predictor, use CPU instead."}; - if (tree_weights != nullptr) { + if (!tree_weights.empty()) { LOG(FATAL) << "Dart booster feature " << not_implemented; } CHECK(!p_fmat->Info().IsColumnSplit()) @@ -469,14 +469,14 @@ void ShapValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* ou void ShapInteractionValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, - bst_tree_t tree_end, std::vector const* tree_weights, + bst_tree_t tree_end, common::Span tree_weights, bool approximate) { xgboost_NVTX_FN_RANGE(); std::string not_implemented{"contribution is not implemented in GPU predictor, use cpu instead."}; if (approximate) { LOG(FATAL) << "Approximated " << not_implemented; } - if (tree_weights != nullptr) { + if (!tree_weights.empty()) { LOG(FATAL) << "Dart booster feature " << not_implemented; } dh::safe_cuda(cudaSetDevice(ctx->Ordinal())); @@ -527,7 +527,7 @@ void ShapInteractionValues(Context const* ctx, DMatrix* p_fmat, } void ApproxFeatureImportance(Context const* ctx, DMatrix*, HostDeviceVector*, - gbm::GBTreeModel const&, bst_tree_t, std::vector const*) { + gbm::GBTreeModel const&, bst_tree_t, common::Span) { StringView not_implemented{ "contribution is not implemented in the GPU predictor, use CPU instead."}; LOG(FATAL) << "Approximated " << not_implemented; diff --git a/src/interpretability/shap.h b/src/interpretability/shap.h index 2c7c71455e60..3de35afedba2 100644 --- a/src/interpretability/shap.h +++ b/src/interpretability/shap.h @@ -8,6 +8,7 @@ #include "xgboost/context.h" // for Context #include "xgboost/data.h" // for DMatrix, MetaInfo #include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/span.h" // for Span namespace xgboost::gbm { struct GBTreeModel; @@ -17,15 +18,15 @@ namespace xgboost::interpretability { namespace cpu_impl { void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights, int condition, unsigned condition_feature); + common::Span tree_weights, int condition, unsigned condition_feature); void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, std::vector const *tree_weights); + bst_tree_t tree_end, common::Span tree_weights); void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, std::vector const *tree_weights, + bst_tree_t tree_end, common::Span tree_weights, bool approximate); } // namespace cpu_impl @@ -33,20 +34,20 @@ void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, namespace cuda_impl { void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights, int condition, unsigned condition_feature); + common::Span tree_weights, int condition, unsigned condition_feature); void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, std::vector const *tree_weights); + bst_tree_t tree_end, common::Span tree_weights); void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, std::vector const *tree_weights, + bst_tree_t tree_end, common::Span tree_weights, bool approximate); } // namespace cuda_impl #endif // defined(XGBOOST_USE_CUDA) inline void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights, int condition, + common::Span tree_weights, int condition, unsigned condition_feature) { #if defined(XGBOOST_USE_CUDA) if (ctx->IsCUDA()) { @@ -62,7 +63,7 @@ inline void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights) { + common::Span tree_weights) { #if defined(XGBOOST_USE_CUDA) if (ctx->IsCUDA()) { cuda_impl::ApproxFeatureImportance(ctx, p_fmat, out_contribs, model, tree_end, tree_weights); @@ -75,7 +76,7 @@ inline void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, inline void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights, bool approximate) { + common::Span tree_weights, bool approximate) { #if defined(XGBOOST_USE_CUDA) if (ctx->IsCUDA()) { cuda_impl::ShapInteractionValues(ctx, p_fmat, out_contribs, model, tree_end, tree_weights, diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 45387542a82a..8c3ef6d322c5 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -1000,8 +1000,8 @@ class CPUPredictor : public Predictor { void PredictContribution(DMatrix *p_fmat, HostDeviceVector *out_contribs, const gbm::GBTreeModel &model, bst_tree_t ntree_limit, - std::vector const *tree_weights, bool approximate, - int condition, unsigned condition_feature) const override { + common::Span tree_weights, bool approximate, int condition, + unsigned condition_feature) const override { if (approximate) { interpretability::ApproxFeatureImportance(this->ctx_, p_fmat, out_contribs, model, ntree_limit, tree_weights); @@ -1013,7 +1013,7 @@ class CPUPredictor : public Predictor { void PredictInteractionContributions(DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, bst_tree_t ntree_limit, - std::vector const *tree_weights, + common::Span tree_weights, bool approximate) const override { interpretability::ShapInteractionValues(this->ctx_, p_fmat, out_contribs, model, ntree_limit, tree_weights, approximate); diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 88d5557345ba..0a6e8329acbf 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -815,7 +815,7 @@ class GPUPredictor : public xgboost::Predictor { void PredictContribution(DMatrix* p_fmat, HostDeviceVector* out_contribs, const gbm::GBTreeModel& model, bst_tree_t tree_end, - std::vector const* tree_weights, bool approximate, int, + common::Span tree_weights, bool approximate, int, unsigned) const override { xgboost_NVTX_FN_RANGE(); if (approximate) { @@ -827,7 +827,7 @@ class GPUPredictor : public xgboost::Predictor { void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, bst_tree_t tree_end, - std::vector const* tree_weights, + common::Span tree_weights, bool approximate) const override { xgboost_NVTX_FN_RANGE(); if (approximate) { diff --git a/tests/cpp/predictor/test_shap.cc b/tests/cpp/predictor/test_shap.cc index cf4358bef610..d3bea9d0c4fe 100644 --- a/tests/cpp/predictor/test_shap.cc +++ b/tests/cpp/predictor/test_shap.cc @@ -203,13 +203,13 @@ void CheckShapOutput(DMatrix* dmat, Args const& model_args) { auto gbtree = LoadGBTreeModel(learner.get(), dmat->Ctx(), model_args, &mparam); HostDeviceVector shap_values; - interpretability::ShapValues(dmat->Ctx(), p_dmat.get(), &shap_values, gbtree, 0, nullptr, 0, 0); + interpretability::ShapValues(dmat->Ctx(), p_dmat.get(), &shap_values, gbtree, 0, {}, 0, 0); ASSERT_EQ(shap_values.HostVector().size(), kRows * (kCols + 1) * n_outputs); CheckShapAdditivity(kRows, kCols, shap_values, margin_predt); HostDeviceVector shap_interactions; interpretability::ShapInteractionValues(dmat->Ctx(), p_dmat.get(), &shap_interactions, gbtree, 0, - nullptr, false); + {}, false); ASSERT_EQ(shap_interactions.HostVector().size(), kRows * (kCols + 1) * (kCols + 1) * n_outputs); CheckShapAdditivity(kRows, kCols, shap_interactions, margin_predt); } @@ -279,7 +279,7 @@ TEST(Predictor, ApproxContribsBasic) { HostDeviceVector approx_contribs; interpretability::ApproxFeatureImportance(dmat->Ctx(), dmat.get(), &approx_contribs, gbtree, 0, - nullptr); + {}); auto const& h_margin = margin_predt.ConstHostVector(); auto const& h_contribs = approx_contribs.ConstHostVector(); From ea777318abbcba797c05ca15bcc1c2d084a96ae3 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Sun, 15 Feb 2026 09:06:34 -0800 Subject: [PATCH 07/17] Move shap stuff back into predictor folder --- src/predictor/cpu_predictor.cc | 169 +---------------- src/predictor/data_accessor.h | 190 +++++++++++++++++++ src/predictor/gpu_data_accessor.cuh | 113 +++++++++++ src/predictor/gpu_predictor.cu | 82 +------- src/{ => predictor}/interpretability/shap.cc | 98 +--------- src/{ => predictor}/interpretability/shap.cu | 129 ++----------- src/{ => predictor}/interpretability/shap.h | 0 tests/cpp/predictor/test_shap.cc | 2 +- 8 files changed, 336 insertions(+), 447 deletions(-) create mode 100644 src/predictor/data_accessor.h create mode 100644 src/predictor/gpu_data_accessor.cuh rename src/{ => predictor}/interpretability/shap.cc (82%) rename src/{ => predictor}/interpretability/shap.cu (81%) rename src/{ => predictor}/interpretability/shap.h (100%) diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 8c3ef6d322c5..b2fb8a13adf6 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -20,10 +20,11 @@ #include "../data/gradient_index.h" // for GHistIndexMatrix #include "../data/proxy_dmatrix.h" // for DMatrixProxy #include "../gbm/gbtree_model.h" // for GBTreeModel, GBTreeModelParam -#include "../interpretability/shap.h" // for PredictContribution #include "array_tree_layout.h" // for ProcessArrayTree +#include "data_accessor.h" // for GHistIndexMatrixView, SparsePageView #include "dmlc/registry.h" // for DMLC_REGISTRY_FILE_TAG #include "gbtree_view.h" // for GBTreeModelView +#include "interpretability/shap.h" // for PredictContribution #include "predict_fn.h" // for GetNextNode, GetNextNodeMulti #include "utils.h" // for CheckProxyDMatrix #include "xgboost/base.h" // for bst_float, bst_node_t, bst_omp_uint, bst_fe... @@ -233,172 +234,6 @@ bool ShouldUseBlock(DMatrix *p_fmat) { using cpu_impl::MakeCatAccessor; -// Convert a single sample in batch view to FVec -template -struct DataToFeatVec { - void Fill(bst_idx_t ridx, RegTree::FVec *p_feats) const { - auto &feats = *p_feats; - auto n_valid = static_cast(this)->DoFill(ridx, feats.Data().data()); - feats.HasMissing(n_valid != feats.Size()); - } - - // Fill the data into the feature vector. - void FVecFill(common::Range1d const &block, bst_feature_t n_features, - common::Span s_feats_vec) const { - auto feats_vec = s_feats_vec.data(); - for (std::size_t i = 0; i < block.Size(); ++i) { - RegTree::FVec &feats = feats_vec[i]; - if (feats.Size() == 0) { - feats.Init(n_features); - } - this->Fill(block.begin() + i, &feats); - } - } - // Clear the feature vector. - static void FVecDrop(common::Span s_feats) { - auto p_feats = s_feats.data(); - for (size_t i = 0, n = s_feats.size(); i < n; ++i) { - p_feats[i].Drop(); - } - } -}; - -template -class SparsePageView : public DataToFeatVec> { - EncAccessor acc_; - HostSparsePageView const view_; - - public: - bst_idx_t const base_rowid; - - SparsePageView(HostSparsePageView const p, bst_idx_t base_rowid, EncAccessor acc) - : acc_{std::move(acc)}, view_{p}, base_rowid{base_rowid} {} - [[nodiscard]] std::size_t Size() const { return view_.Size(); } - - [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float *out) const { - auto p_data = view_[ridx].data(); - - for (std::size_t i = 0, n = view_[ridx].size(); i < n; ++i) { - auto const &entry = p_data[i]; - out[entry.index] = acc_(entry); - } - - return view_[ridx].size(); - } -}; - -template -class GHistIndexMatrixView : public DataToFeatVec> { - private: - GHistIndexMatrix const &page_; - EncAccessor acc_; - common::Span ft_; - - std::vector const &ptrs_; - std::vector const &mins_; - std::vector const &values_; - common::ColumnMatrix const &columns_; - - public: - bst_idx_t const base_rowid; - - public: - GHistIndexMatrixView(GHistIndexMatrix const &_page, EncAccessor acc, - common::Span ft) - : page_{_page}, - acc_{std::move(acc)}, - ft_{ft}, - ptrs_{_page.cut.Ptrs()}, - mins_{_page.cut.MinValues()}, - values_{_page.cut.Values()}, - columns_{page_.Transpose()}, - base_rowid{_page.base_rowid} {} - - [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float *out) const { - auto gridx = ridx + this->base_rowid; - auto n_features = page_.Features(); - - bst_idx_t n_non_missings = 0; - if (page_.IsDense()) { - common::DispatchBinType(page_.index.GetBinTypeSize(), [&](auto t) { - using T = decltype(t); - auto ptr = this->page_.index.template data(); - auto rbeg = this->page_.row_ptr[ridx]; - for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) { - bst_bin_t bin_idx; - float fvalue; - if (common::IsCat(ft_, fidx)) { - bin_idx = page_.GetGindex(gridx, fidx); - fvalue = this->values_[bin_idx]; - } else { - bin_idx = ptr[rbeg + fidx] + page_.index.Offset()[fidx]; - fvalue = - common::HistogramCuts::NumericBinValue(this->ptrs_, values_, mins_, fidx, bin_idx); - } - out[fidx] = acc_(fvalue, fidx); - } - }); - n_non_missings += n_features; - } else { - for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) { - float fvalue = std::numeric_limits::quiet_NaN(); - bool is_cat = common::IsCat(ft_, fidx); - if (columns_.GetColumnType(fidx) == common::kSparseColumn) { - // Special handling for extremely sparse data. Just binary search. - auto bin_idx = page_.GetGindex(gridx, fidx); - if (bin_idx != -1) { - if (is_cat) { - fvalue = values_[bin_idx]; - } else { - fvalue = common::HistogramCuts::NumericBinValue(this->ptrs_, values_, mins_, fidx, - bin_idx); - } - } - } else { - fvalue = page_.GetFvalue(ptrs_, values_, mins_, gridx, fidx, is_cat); - } - if (!common::CheckNAN(fvalue)) { - out[fidx] = acc_(fvalue, fidx); - n_non_missings++; - } - } - } - return n_non_missings; - } - - [[nodiscard]] bst_idx_t Size() const { return page_.Size(); } -}; - -template -class AdapterView : public DataToFeatVec> { - Adapter const *adapter_; - float missing_; - EncAccessor acc_; - - public: - explicit AdapterView(Adapter const *adapter, float missing, EncAccessor acc) - : adapter_{adapter}, missing_{missing}, acc_{std::move(acc)} {} - - [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float *out) const { - auto const &batch = adapter_->Value(); - auto row = batch.GetLine(ridx); - bst_idx_t n_non_missings = 0; - for (size_t c = 0; c < row.Size(); ++c) { - auto e = row.GetElement(c); - if (missing_ != e.value && !common::CheckNAN(e.value)) { - auto fvalue = this->acc_(e); - out[e.column_idx] = fvalue; - n_non_missings++; - } - } - return n_non_missings; - } - - [[nodiscard]] bst_idx_t Size() const { return adapter_->NumRows(); } - - bst_idx_t const static base_rowid = 0; // NOLINT -}; - // Ordinal re-coder. struct EncAccessorPolicy { private: diff --git a/src/predictor/data_accessor.h b/src/predictor/data_accessor.h new file mode 100644 index 000000000000..7c07da7b1fb5 --- /dev/null +++ b/src/predictor/data_accessor.h @@ -0,0 +1,190 @@ +/** + * Copyright 2017-2026, XGBoost Contributors + */ +#pragma once + +#include +#include +#include +#include +#include + +#include "../common/categorical.h" // for IsCat +#include "../common/column_matrix.h" // for ColumnMatrix +#include "../common/common.h" // for Range1d +#include "../common/hist_util.h" // for DispatchBinType, HistogramCuts +#include "../common/math.h" // for CheckNAN +#include "../data/cat_container.h" // for NoOpAccessor +#include "../data/gradient_index.h" // for GHistIndexMatrix +#include "xgboost/data.h" // for HostSparsePageView +#include "xgboost/span.h" // for Span +#include "xgboost/tree_model.h" // for RegTree::FVec + +namespace xgboost::predictor { +// Convert a single sample in batch view to FVec. +template +struct DataToFeatVec { + void Fill(bst_idx_t ridx, RegTree::FVec* p_feats) const { + auto& feats = *p_feats; + auto n_valid = static_cast(this)->DoFill(ridx, feats.Data().data()); + feats.HasMissing(n_valid != feats.Size()); + } + + // Fill the data into the feature vector. + void FVecFill(common::Range1d const& block, bst_feature_t n_features, + common::Span s_feats_vec) const { + auto feats_vec = s_feats_vec.data(); + for (std::size_t i = 0; i < block.Size(); ++i) { + RegTree::FVec& feats = feats_vec[i]; + if (feats.Size() == 0) { + feats.Init(n_features); + } + this->Fill(block.begin() + i, &feats); + } + } + // Clear the feature vector. + static void FVecDrop(common::Span s_feats) { + auto p_feats = s_feats.data(); + for (size_t i = 0, n = s_feats.size(); i < n; ++i) { + p_feats[i].Drop(); + } + } +}; + +template +class SparsePageView : public DataToFeatVec> { + EncAccessor acc_; + HostSparsePageView const view_; + + public: + bst_idx_t const base_rowid; + + SparsePageView(HostSparsePageView const p, bst_idx_t base_rowid, EncAccessor acc) + : acc_{std::move(acc)}, view_{p}, base_rowid{base_rowid} {} + + [[nodiscard]] std::size_t Size() const { return view_.Size(); } + + [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float* out) const { + auto p_data = view_[ridx].data(); + + for (std::size_t i = 0, n = view_[ridx].size(); i < n; ++i) { + auto const& entry = p_data[i]; + out[entry.index] = acc_(entry); + } + + return view_[ridx].size(); + } +}; + +template +class GHistIndexMatrixView : public DataToFeatVec> { + private: + GHistIndexMatrix const& page_; + EncAccessor acc_; + common::Span ft_; + + std::vector const& ptrs_; + std::vector const& mins_; + std::vector const& values_; + common::ColumnMatrix const& columns_; + + public: + bst_idx_t const base_rowid; + + public: + GHistIndexMatrixView(GHistIndexMatrix const& page, EncAccessor acc, + common::Span ft) + : page_{page}, + acc_{std::move(acc)}, + ft_{ft}, + ptrs_{page.cut.Ptrs()}, + mins_{page.cut.MinValues()}, + values_{page.cut.Values()}, + columns_{page.Transpose()}, + base_rowid{page.base_rowid} {} + + [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float* out) const { + auto gridx = ridx + this->base_rowid; + auto n_features = page_.Features(); + + bst_idx_t n_non_missings = 0; + if (page_.IsDense()) { + common::DispatchBinType(page_.index.GetBinTypeSize(), [&](auto t) { + using T = decltype(t); + auto ptr = this->page_.index.template data(); + auto rbeg = this->page_.row_ptr[ridx]; + for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) { + bst_bin_t bin_idx; + float fvalue; + if (common::IsCat(ft_, fidx)) { + bin_idx = page_.GetGindex(gridx, fidx); + fvalue = this->values_[bin_idx]; + } else { + bin_idx = ptr[rbeg + fidx] + page_.index.Offset()[fidx]; + fvalue = + common::HistogramCuts::NumericBinValue(this->ptrs_, values_, mins_, fidx, bin_idx); + } + out[fidx] = acc_(fvalue, fidx); + } + }); + n_non_missings += n_features; + } else { + for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) { + float fvalue = std::numeric_limits::quiet_NaN(); + bool is_cat = common::IsCat(ft_, fidx); + if (columns_.GetColumnType(fidx) == common::kSparseColumn) { + // Special handling for extremely sparse data. Just binary search. + auto bin_idx = page_.GetGindex(gridx, fidx); + if (bin_idx != -1) { + if (is_cat) { + fvalue = values_[bin_idx]; + } else { + fvalue = common::HistogramCuts::NumericBinValue(this->ptrs_, values_, mins_, fidx, + bin_idx); + } + } + } else { + fvalue = page_.GetFvalue(ptrs_, values_, mins_, gridx, fidx, is_cat); + } + if (!common::CheckNAN(fvalue)) { + out[fidx] = acc_(fvalue, fidx); + n_non_missings++; + } + } + } + return n_non_missings; + } + + [[nodiscard]] bst_idx_t Size() const { return page_.Size(); } +}; + +template +class AdapterView : public DataToFeatVec> { + Adapter const* adapter_; + float missing_; + EncAccessor acc_; + + public: + explicit AdapterView(Adapter const* adapter, float missing, EncAccessor acc) + : adapter_{adapter}, missing_{missing}, acc_{std::move(acc)} {} + + [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float* out) const { + auto const& batch = adapter_->Value(); + auto row = batch.GetLine(ridx); + bst_idx_t n_non_missings = 0; + for (size_t c = 0; c < row.Size(); ++c) { + auto e = row.GetElement(c); + if (missing_ != e.value && !common::CheckNAN(e.value)) { + auto fvalue = this->acc_(e); + out[e.column_idx] = fvalue; + n_non_missings++; + } + } + return n_non_missings; + } + + [[nodiscard]] bst_idx_t Size() const { return adapter_->NumRows(); } + + bst_idx_t const static base_rowid = 0; // NOLINT +}; +} // namespace xgboost::predictor diff --git a/src/predictor/gpu_data_accessor.cuh b/src/predictor/gpu_data_accessor.cuh new file mode 100644 index 000000000000..ab068d54164c --- /dev/null +++ b/src/predictor/gpu_data_accessor.cuh @@ -0,0 +1,113 @@ +/** + * Copyright 2017-2026, XGBoost Contributors + */ +#pragma once + +#include +#include +#include + +#include "../common/categorical.h" // for IsCat +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for Entry, SparsePage +#include "xgboost/span.h" // for Span + +namespace xgboost::predictor { +struct SparsePageView { + common::Span d_data; + common::Span d_row_ptr; + bst_feature_t num_features; + + SparsePageView() = default; + explicit SparsePageView(Context const* ctx, SparsePage const& page, bst_feature_t n_features) + : d_data{[&] { + page.data.SetDevice(ctx->Device()); + return page.data.ConstDeviceSpan(); + }()}, + d_row_ptr{[&] { + page.offset.SetDevice(ctx->Device()); + return page.offset.ConstDeviceSpan(); + }()}, + num_features{n_features} {} + + [[nodiscard]] __device__ float GetElement(size_t ridx, size_t fidx) const { + // Binary search + auto begin_ptr = d_data.begin() + d_row_ptr[ridx]; + auto end_ptr = d_data.begin() + d_row_ptr[ridx + 1]; + if (end_ptr - begin_ptr == this->NumCols()) { + // Bypass span check for dense data + return d_data.data()[d_row_ptr[ridx] + fidx].fvalue; + } + common::Span::iterator previous_middle; + while (end_ptr != begin_ptr) { + auto middle = begin_ptr + (end_ptr - begin_ptr) / 2; + if (middle == previous_middle) { + break; + } else { + previous_middle = middle; + } + + if (middle->index == fidx) { + return middle->fvalue; + } else if (middle->index < fidx) { + begin_ptr = middle; + } else { + end_ptr = middle; + } + } + // Value is missing + return std::numeric_limits::quiet_NaN(); + } + + [[nodiscard]] XGBOOST_DEVICE size_t NumRows() const { return d_row_ptr.size() - 1; } + [[nodiscard]] XGBOOST_DEVICE size_t NumCols() const { return num_features; } +}; + +template +struct SparsePageLoaderNoShared { + public: + using SupportShmemLoad = std::false_type; + + SparsePageView data; + EncAccessor acc; + + template + [[nodiscard]] __device__ float GetElement(bst_idx_t ridx, Fidx fidx) const { + return acc(data.GetElement(ridx, fidx), fidx); + } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return data.NumRows(); } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return data.NumCols(); } +}; + +template +struct EllpackLoader { + public: + using SupportShmemLoad = std::false_type; + + Accessor matrix; + EncAccessor acc; + + XGBOOST_DEVICE EllpackLoader(Accessor m, bool /*use_shared*/, bst_feature_t /*n_features*/, + bst_idx_t /*n_samples*/, float /*missing*/, EncAccessor&& acc) + : matrix{std::move(m)}, acc{std::forward(acc)} {} + + [[nodiscard]] XGBOOST_DEV_INLINE float GetElement(size_t ridx, size_t fidx) const { + auto gidx = matrix.template GetBinIndex(ridx, fidx); + if (gidx == -1) { + return std::numeric_limits::quiet_NaN(); + } + if (common::IsCat(matrix.feature_types, fidx)) { + return this->acc(matrix.gidx_fvalue_map[gidx], fidx); + } + // The gradient index needs to be shifted by one as min values are not included in the + // cuts. + if (gidx == matrix.feature_segments[fidx]) { + return matrix.min_fvalue[fidx]; + } + return matrix.gidx_fvalue_map[gidx - 1]; + } + + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return this->matrix.NumFeatures(); } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return this->matrix.n_rows; } +}; +} // namespace xgboost::predictor diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 0a6e8329acbf..81a515532404 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -25,9 +25,10 @@ #include "../data/proxy_dmatrix.cuh" // for DispatchAny #include "../data/proxy_dmatrix.h" #include "../gbm/gbtree_model.h" -#include "../interpretability/shap.h" #include "../tree/tree_view.h" #include "gbtree_view.h" // for GBTreeModelView +#include "gpu_data_accessor.cuh" +#include "interpretability/shap.h" #include "predict_fn.h" #include "utils.h" // for CheckProxyDMatrix #include "xgboost/data.h" @@ -42,55 +43,6 @@ DMLC_REGISTRY_FILE_TAG(gpu_predictor); using cuda_impl::StaticBatch; -struct SparsePageView { - common::Span d_data; - common::Span d_row_ptr; - bst_feature_t num_features; - - SparsePageView() = default; - explicit SparsePageView(Context const* ctx, SparsePage const& page, bst_feature_t n_features) - : d_data{[&] { - page.data.SetDevice(ctx->Device()); - return page.data.ConstDeviceSpan(); - }()}, - d_row_ptr{[&] { - page.offset.SetDevice(ctx->Device()); - return page.offset.ConstDeviceSpan(); - }()}, - num_features{n_features} {} - - [[nodiscard]] __device__ float GetElement(size_t ridx, size_t fidx) const { - // Binary search - auto begin_ptr = d_data.begin() + d_row_ptr[ridx]; - auto end_ptr = d_data.begin() + d_row_ptr[ridx + 1]; - if (end_ptr - begin_ptr == this->NumCols()) { - // Bypass span check for dense data - return d_data.data()[d_row_ptr[ridx] + fidx].fvalue; - } - common::Span::iterator previous_middle; - while (end_ptr != begin_ptr) { - auto middle = begin_ptr + (end_ptr - begin_ptr) / 2; - if (middle == previous_middle) { - break; - } else { - previous_middle = middle; - } - - if (middle->index == fidx) { - return middle->fvalue; - } else if (middle->index < fidx) { - begin_ptr = middle; - } else { - end_ptr = middle; - } - } - // Value is missing - return std::numeric_limits::quiet_NaN(); - } - [[nodiscard]] XGBOOST_DEVICE size_t NumRows() const { return d_row_ptr.size() - 1; } - [[nodiscard]] XGBOOST_DEVICE size_t NumCols() const { return num_features; } -}; - template struct SparsePageLoader { public: @@ -135,36 +87,6 @@ struct SparsePageLoader { } }; -template -struct EllpackLoader { - public: - using SupportShmemLoad = std::false_type; - - Accessor matrix; - EncAccessor acc; - - XGBOOST_DEVICE EllpackLoader(Accessor m, bool /*use_shared*/, bst_feature_t /*n_features*/, - bst_idx_t /*n_samples*/, float /*missing*/, EncAccessor&& acc) - : matrix{std::move(m)}, acc{std::forward(acc)} {} - [[nodiscard]] XGBOOST_DEV_INLINE float GetElement(size_t ridx, size_t fidx) const { - auto gidx = matrix.template GetBinIndex(ridx, fidx); - if (gidx == -1) { - return std::numeric_limits::quiet_NaN(); - } - if (common::IsCat(matrix.feature_types, fidx)) { - return this->acc(matrix.gidx_fvalue_map[gidx], fidx); - } - // The gradient index needs to be shifted by one as min values are not included in the - // cuts. - if (gidx == matrix.feature_segments[fidx]) { - return matrix.min_fvalue[fidx]; - } - return matrix.gidx_fvalue_map[gidx - 1]; - } - [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return this->matrix.NumFeatures(); } - [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return this->matrix.n_rows; } -}; - /** * @brief Use for in-place predict. */ diff --git a/src/interpretability/shap.cc b/src/predictor/interpretability/shap.cc similarity index 82% rename from src/interpretability/shap.cc rename to src/predictor/interpretability/shap.cc index 2f4633ac7bdd..9b9cee6a5ee6 100644 --- a/src/interpretability/shap.cc +++ b/src/predictor/interpretability/shap.cc @@ -4,20 +4,15 @@ #include "shap.h" #include // for fill -#include // for numeric_limits #include // for remove_const_t #include // for vector -#include "../common/categorical.h" // for IsCat -#include "../common/column_matrix.h" // for ColumnMatrix -#include "../common/hist_util.h" // for DispatchBinType, HistogramCuts -#include "../common/math.h" // for CheckNAN -#include "../common/threading_utils.h" // for ParallelFor -#include "../data/gradient_index.h" // for GHistIndexMatrix -#include "../gbm/gbtree_model.h" // for GBTreeModel -#include "../predictor/predict_fn.h" // for GetTreeLimit -#include "../predictor/treeshap.h" // for CalculateContributions -#include "../tree/tree_view.h" // for ScalarTreeView +#include "../../common/threading_utils.h" // for ParallelFor +#include "../../gbm/gbtree_model.h" // for GBTreeModel +#include "../../tree/tree_view.h" // for ScalarTreeView +#include "../data_accessor.h" // for GHistIndexMatrixView +#include "../predict_fn.h" // for GetTreeLimit +#include "../treeshap.h" // for CalculateContributions #include "dmlc/omp.h" // for omp_get_thread_num #include "xgboost/base.h" // for bst_omp_uint #include "xgboost/logging.h" // for CHECK @@ -64,83 +59,6 @@ void CalculateApproxContributions(tree::ScalarTreeView const &tree, RegTree::FVe CHECK_EQ(out_contribs->size(), feats.Size() + 1); CalculateContributionsApprox(tree, feats, mean_values, out_contribs->data()); } - -class GHistIndexMatrixView { - private: - GHistIndexMatrix const &page_; - common::Span ft_; - - std::vector const &ptrs_; - std::vector const &mins_; - std::vector const &values_; - common::ColumnMatrix const &columns_; - - public: - bst_idx_t const base_rowid; - - public: - GHistIndexMatrixView(GHistIndexMatrix const &page, common::Span ft) - : page_{page}, - ft_{ft}, - ptrs_{page.cut.Ptrs()}, - mins_{page.cut.MinValues()}, - values_{page.cut.Values()}, - columns_{page.Transpose()}, - base_rowid{page.base_rowid} {} - - [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float *out) const { - auto gridx = ridx + this->base_rowid; - auto n_features = page_.Features(); - - bst_idx_t n_non_missings = 0; - if (page_.IsDense()) { - common::DispatchBinType(page_.index.GetBinTypeSize(), [&](auto t) { - using T = decltype(t); - auto ptr = this->page_.index.template data(); - auto rbeg = this->page_.row_ptr[ridx]; - for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) { - bst_bin_t bin_idx; - float fvalue; - if (common::IsCat(ft_, fidx)) { - bin_idx = page_.GetGindex(gridx, fidx); - fvalue = this->values_[bin_idx]; - } else { - bin_idx = ptr[rbeg + fidx] + page_.index.Offset()[fidx]; - fvalue = - common::HistogramCuts::NumericBinValue(this->ptrs_, values_, mins_, fidx, bin_idx); - } - out[fidx] = fvalue; - } - }); - n_non_missings += n_features; - } else { - for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) { - float fvalue = std::numeric_limits::quiet_NaN(); - bool is_cat = common::IsCat(ft_, fidx); - if (columns_.GetColumnType(fidx) == common::kSparseColumn) { - auto bin_idx = page_.GetGindex(gridx, fidx); - if (bin_idx != -1) { - if (is_cat) { - fvalue = values_[bin_idx]; - } else { - fvalue = common::HistogramCuts::NumericBinValue(this->ptrs_, values_, mins_, fidx, - bin_idx); - } - } - } else { - fvalue = page_.GetFvalue(ptrs_, values_, mins_, gridx, fidx, is_cat); - } - if (!common::CheckNAN(fvalue)) { - out[fidx] = fvalue; - n_non_missings++; - } - } - } - return n_non_missings; - } - - [[nodiscard]] bst_idx_t Size() const { return page_.Size(); } -}; } // namespace namespace cpu_impl { @@ -217,7 +135,7 @@ void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *ou } else { auto ft = p_fmat->Info().feature_types.ConstHostVector(); for (auto const &page : p_fmat->GetBatches(ctx, {})) { - GHistIndexMatrixView view{page, ft}; + predictor::GHistIndexMatrixView view{page, NoOpAccessor{}, ft}; common::ParallelFor(view.Size(), n_threads, [&](auto i) { auto tid = omp_get_thread_num(); auto &feats = feats_tloc[tid]; @@ -324,7 +242,7 @@ void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, } else { auto ft = p_fmat->Info().feature_types.ConstHostVector(); for (auto const &page : p_fmat->GetBatches(ctx, {})) { - GHistIndexMatrixView view{page, ft}; + predictor::GHistIndexMatrixView view{page, NoOpAccessor{}, ft}; common::ParallelFor(view.Size(), n_threads, [&](auto i) { auto tid = omp_get_thread_num(); auto &feats = feats_tloc[tid]; diff --git a/src/interpretability/shap.cu b/src/predictor/interpretability/shap.cu similarity index 81% rename from src/interpretability/shap.cu rename to src/predictor/interpretability/shap.cu index 2f02234205c8..5123a7c0e3e5 100644 --- a/src/interpretability/shap.cu +++ b/src/predictor/interpretability/shap.cu @@ -20,21 +20,22 @@ #include #include -#include "../common/categorical.h" -#include "../common/common.h" -#include "../common/cuda_context.cuh" // for CUDAContext -#include "../common/cuda_rt_utils.h" // for SetDevice -#include "../common/device_helpers.cuh" -#include "../common/error_msg.h" -#include "../common/nvtx_utils.h" -#include "../data/batch_utils.h" // for StaticBatch -#include "../data/cat_container.cuh" // for EncPolicy, MakeCatAccessor -#include "../data/cat_container.h" // for NoOpAccessor -#include "../data/ellpack_page.cuh" -#include "../gbm/gbtree_model.h" -#include "../predictor/gbtree_view.h" -#include "../predictor/predict_fn.h" // for GetTreeLimit -#include "../tree/tree_view.h" +#include "../../common/categorical.h" +#include "../../common/common.h" +#include "../../common/cuda_context.cuh" // for CUDAContext +#include "../../common/cuda_rt_utils.h" // for SetDevice +#include "../../common/device_helpers.cuh" +#include "../../common/error_msg.h" +#include "../../common/nvtx_utils.h" +#include "../../data/batch_utils.h" // for StaticBatch +#include "../../data/cat_container.cuh" // for EncPolicy, MakeCatAccessor +#include "../../data/cat_container.h" // for NoOpAccessor +#include "../../data/ellpack_page.cuh" +#include "../../gbm/gbtree_model.h" +#include "../../tree/tree_view.h" +#include "../gbtree_view.h" +#include "../gpu_data_accessor.cuh" +#include "../predict_fn.h" // for GetTreeLimit #include "shap.h" #include "xgboost/data.h" #include "xgboost/host_device_vector.h" @@ -44,102 +45,12 @@ namespace xgboost::interpretability::cuda_impl { namespace { +using predictor::EllpackLoader; using predictor::GBTreeModelView; +using predictor::SparsePageLoaderNoShared; +using predictor::SparsePageView; using ::xgboost::cuda_impl::StaticBatch; -struct SparsePageView { - common::Span d_data; - common::Span d_row_ptr; - bst_feature_t num_features; - - SparsePageView() = default; - explicit SparsePageView(Context const* ctx, SparsePage const& page, bst_feature_t n_features) - : d_data{[&] { - page.data.SetDevice(ctx->Device()); - return page.data.ConstDeviceSpan(); - }()}, - d_row_ptr{[&] { - page.offset.SetDevice(ctx->Device()); - return page.offset.ConstDeviceSpan(); - }()}, - num_features{n_features} {} - - [[nodiscard]] __device__ float GetElement(size_t ridx, size_t fidx) const { - auto begin_ptr = d_data.begin() + d_row_ptr[ridx]; - auto end_ptr = d_data.begin() + d_row_ptr[ridx + 1]; - if (end_ptr - begin_ptr == this->NumCols()) { - return d_data.data()[d_row_ptr[ridx] + fidx].fvalue; - } - common::Span::iterator previous_middle; - while (end_ptr != begin_ptr) { - auto middle = begin_ptr + (end_ptr - begin_ptr) / 2; - if (middle == previous_middle) { - break; - } else { - previous_middle = middle; - } - - if (middle->index == fidx) { - return middle->fvalue; - } else if (middle->index < fidx) { - begin_ptr = middle; - } else { - end_ptr = middle; - } - } - return std::numeric_limits::quiet_NaN(); - } - [[nodiscard]] XGBOOST_DEVICE size_t NumRows() const { return d_row_ptr.size() - 1; } - [[nodiscard]] XGBOOST_DEVICE size_t NumCols() const { return num_features; } -}; - -template -struct ShapSparsePageLoader { - public: - using SupportShmemLoad = std::false_type; - - SparsePageView data; - EncAccessor acc; - - template - [[nodiscard]] __device__ float GetElement(bst_idx_t ridx, Fidx fidx) const { - auto fvalue = data.GetElement(ridx, fidx); - return acc(fvalue, fidx); - } - [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return data.NumRows(); } - [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return data.NumCols(); } -}; - -template -struct EllpackLoader { - public: - using SupportShmemLoad = std::false_type; - - Accessor matrix; - EncAccessor acc; - - XGBOOST_DEVICE EllpackLoader(Accessor m, bool /*use_shared*/, bst_feature_t n_features, - bst_idx_t num_rows, float missing, EncAccessor&& acc) - : matrix{std::move(m)}, acc{std::forward(acc)} {} - - template - [[nodiscard]] __device__ float GetElement(bst_idx_t ridx, Fidx fidx) const { - auto gidx = matrix.template GetBinIndex(ridx, fidx); - if (gidx == -1) { - return std::numeric_limits::quiet_NaN(); - } - if (common::IsCat(matrix.feature_types, fidx)) { - return this->acc(matrix.gidx_fvalue_map[gidx], fidx); - } - if (gidx == matrix.feature_segments[fidx]) { - return matrix.min_fvalue[fidx]; - } - return matrix.gidx_fvalue_map[gidx - 1]; - } - [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return matrix.n_rows; } - [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return matrix.NumFeatures(); } -}; - using TreeViewVar = cuda::std::variant; struct CopyViews { @@ -369,7 +280,7 @@ class LaunchConfig { if (p_fmat->PageExists()) { for (auto& page : p_fmat->GetBatches()) { SparsePageView batch{ctx_, page, n_features_}; - auto loader = ShapSparsePageLoader{batch, acc}; + auto loader = SparsePageLoaderNoShared{batch, acc}; fn(std::move(loader), page.base_rowid); } } else { diff --git a/src/interpretability/shap.h b/src/predictor/interpretability/shap.h similarity index 100% rename from src/interpretability/shap.h rename to src/predictor/interpretability/shap.h diff --git a/tests/cpp/predictor/test_shap.cc b/tests/cpp/predictor/test_shap.cc index d3bea9d0c4fe..1a516b01490e 100644 --- a/tests/cpp/predictor/test_shap.cc +++ b/tests/cpp/predictor/test_shap.cc @@ -19,7 +19,7 @@ #include "../../../src/common/param_array.h" #include "../../../src/gbm/gbtree_model.h" -#include "../../../src/interpretability/shap.h" +#include "../../../src/predictor/interpretability/shap.h" #include "../helpers.h" namespace xgboost { From 4deae843da55bd9d4073abf98c7b95c5b498d4c0 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Sun, 15 Feb 2026 09:30:16 -0800 Subject: [PATCH 08/17] Undo use span for tree_weights --- include/xgboost/predictor.h | 5 ++--- src/gbm/gbtree.cc | 7 +++---- src/gbm/gbtree.h | 4 ++-- src/predictor/cpu_predictor.cc | 4 ++-- src/predictor/gpu_predictor.cu | 4 ++-- src/predictor/interpretability/shap.cc | 20 ++++++++++---------- src/predictor/interpretability/shap.cu | 10 +++++----- src/predictor/interpretability/shap.h | 19 +++++++++---------- 8 files changed, 35 insertions(+), 38 deletions(-) diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index 695abca5c35c..a4c490dc0848 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -12,7 +12,6 @@ #include #include #include -#include #include // for function #include // for shared_ptr @@ -157,14 +156,14 @@ class Predictor { virtual void PredictContribution(DMatrix* dmat, HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, bst_tree_t tree_end = 0, - common::Span tree_weights = {}, + std::vector const* tree_weights = nullptr, bool approximate = false, int condition = 0, unsigned condition_feature = 0) const = 0; virtual void PredictInteractionContributions(DMatrix* dmat, HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, bst_tree_t tree_end = 0, - common::Span tree_weights = {}, + std::vector const* tree_weights = nullptr, bool approximate = false) const = 0; /** diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index f621c6a4e20c..594458dbbb93 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -887,8 +887,8 @@ class Dart : public GBTree { bst_layer_t layer_begin, bst_layer_t layer_end, bool approximate) override { auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); - cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_, tree_end, - common::Span{weight_drop_}, approximate); + cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_, tree_end, &weight_drop_, + approximate); } void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector* out_contribs, @@ -896,8 +896,7 @@ class Dart : public GBTree { bool approximate) override { auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); cpu_predictor_->PredictInteractionContributions(p_fmat, out_contribs, model_, tree_end, - common::Span{weight_drop_}, - approximate); + &weight_drop_, approximate); } protected: diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 024a0d33d3f6..7d5c425ad6f2 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -302,7 +302,7 @@ class GBTree : public GradientBooster { auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); CHECK_EQ(tree_begin, 0) << "Predict contribution supports only iteration end: [0, " "n_iteration), using model slicing instead."; - this->GetPredictor(false)->PredictContribution(p_fmat, out_contribs, model_, tree_end, {}, + this->GetPredictor(false)->PredictContribution(p_fmat, out_contribs, model_, tree_end, nullptr, approximate); } @@ -313,7 +313,7 @@ class GBTree : public GradientBooster { CHECK_EQ(tree_begin, 0) << "Predict interaction contribution supports only iteration end: [0, " "n_iteration), using model slicing instead."; this->GetPredictor(false)->PredictInteractionContributions(p_fmat, out_contribs, model_, - tree_end, {}, approximate); + tree_end, nullptr, approximate); } [[nodiscard]] std::vector DumpModel(const FeatureMap& fmap, bool with_stats, diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index b2fb8a13adf6..8be90dc5d2f3 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -835,7 +835,7 @@ class CPUPredictor : public Predictor { void PredictContribution(DMatrix *p_fmat, HostDeviceVector *out_contribs, const gbm::GBTreeModel &model, bst_tree_t ntree_limit, - common::Span tree_weights, bool approximate, int condition, + std::vector const *tree_weights, bool approximate, int condition, unsigned condition_feature) const override { if (approximate) { interpretability::ApproxFeatureImportance(this->ctx_, p_fmat, out_contribs, model, @@ -848,7 +848,7 @@ class CPUPredictor : public Predictor { void PredictInteractionContributions(DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, bst_tree_t ntree_limit, - common::Span tree_weights, + std::vector const *tree_weights, bool approximate) const override { interpretability::ShapInteractionValues(this->ctx_, p_fmat, out_contribs, model, ntree_limit, tree_weights, approximate); diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 81a515532404..2875976393d3 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -737,7 +737,7 @@ class GPUPredictor : public xgboost::Predictor { void PredictContribution(DMatrix* p_fmat, HostDeviceVector* out_contribs, const gbm::GBTreeModel& model, bst_tree_t tree_end, - common::Span tree_weights, bool approximate, int, + std::vector const* tree_weights, bool approximate, int, unsigned) const override { xgboost_NVTX_FN_RANGE(); if (approximate) { @@ -749,7 +749,7 @@ class GPUPredictor : public xgboost::Predictor { void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, bst_tree_t tree_end, - common::Span tree_weights, + std::vector const* tree_weights, bool approximate) const override { xgboost_NVTX_FN_RANGE(); if (approximate) { diff --git a/src/predictor/interpretability/shap.cc b/src/predictor/interpretability/shap.cc index 9b9cee6a5ee6..b8c0559dbca2 100644 --- a/src/predictor/interpretability/shap.cc +++ b/src/predictor/interpretability/shap.cc @@ -20,11 +20,11 @@ namespace xgboost::interpretability { namespace { -void ValidateTreeWeights(common::Span tree_weights, bst_tree_t tree_end) { - if (tree_weights.empty()) { +void ValidateTreeWeights(std::vector const *tree_weights, bst_tree_t tree_end) { + if (tree_weights == nullptr) { return; } - CHECK_GE(tree_weights.size(), static_cast(tree_end)); + CHECK_GE(tree_weights->size(), static_cast(tree_end)); } float FillNodeMeanValues(tree::ScalarTreeView const &tree, bst_node_t nidx, @@ -64,7 +64,7 @@ void CalculateApproxContributions(tree::ScalarTreeView const &tree, RegTree::FVe namespace cpu_impl { void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, bst_tree_t tree_end, - common::Span tree_weights, int condition, unsigned condition_feature) { + std::vector const *tree_weights, int condition, unsigned condition_feature) { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict contribution" << MTNotImplemented(); CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict contribution support for column-wise data split is not yet implemented."; @@ -119,7 +119,7 @@ void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *ou condition, condition_feature); for (size_t ci = 0; ci < ncolumns; ++ci) { p_contribs[ci] += - this_tree_contribs[ci] * (tree_weights.empty() ? 1 : tree_weights[j]); + this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); } } if (base_margin.Size() != 0) { @@ -158,7 +158,7 @@ void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *ou condition, condition_feature); for (size_t ci = 0; ci < ncolumns; ++ci) { p_contribs[ci] += - this_tree_contribs[ci] * (tree_weights.empty() ? 1 : tree_weights[j]); + this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); } } if (base_margin.Size() != 0) { @@ -176,7 +176,7 @@ void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *ou void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, common::Span tree_weights) { + bst_tree_t tree_end, std::vector const *tree_weights) { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict contribution" << MTNotImplemented(); CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict contribution support for column-wise data split is not yet implemented."; @@ -226,7 +226,7 @@ void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, CalculateApproxContributions(sc_tree, feats, &mean_values[j], &this_tree_contribs); for (size_t ci = 0; ci < ncolumns; ++ci) { p_contribs[ci] += - this_tree_contribs[ci] * (tree_weights.empty() ? 1 : tree_weights[j]); + this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); } } if (base_margin.Size() != 0) { @@ -264,7 +264,7 @@ void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, CalculateApproxContributions(sc_tree, feats, &mean_values[j], &this_tree_contribs); for (size_t ci = 0; ci < ncolumns; ++ci) { p_contribs[ci] += - this_tree_contribs[ci] * (tree_weights.empty() ? 1 : tree_weights[j]); + this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); } } if (base_margin.Size() != 0) { @@ -282,7 +282,7 @@ void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, common::Span tree_weights, + bst_tree_t tree_end, std::vector const *tree_weights, bool approximate) { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict interaction contribution" << MTNotImplemented(); diff --git a/src/predictor/interpretability/shap.cu b/src/predictor/interpretability/shap.cu index 5123a7c0e3e5..c78a6719ff45 100644 --- a/src/predictor/interpretability/shap.cu +++ b/src/predictor/interpretability/shap.cu @@ -325,11 +325,11 @@ void LaunchShap(Context const* ctx, enc::DeviceColumnsView const& new_enc, void ShapValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, bst_tree_t tree_end, - common::Span tree_weights, int, unsigned) { + std::vector const* tree_weights, int, unsigned) { xgboost_NVTX_FN_RANGE(); StringView not_implemented{ "contribution is not implemented in the GPU predictor, use CPU instead."}; - if (!tree_weights.empty()) { + if (tree_weights != nullptr) { LOG(FATAL) << "Dart booster feature " << not_implemented; } CHECK(!p_fmat->Info().IsColumnSplit()) @@ -380,14 +380,14 @@ void ShapValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* ou void ShapInteractionValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, - bst_tree_t tree_end, common::Span tree_weights, + bst_tree_t tree_end, std::vector const* tree_weights, bool approximate) { xgboost_NVTX_FN_RANGE(); std::string not_implemented{"contribution is not implemented in GPU predictor, use cpu instead."}; if (approximate) { LOG(FATAL) << "Approximated " << not_implemented; } - if (!tree_weights.empty()) { + if (tree_weights != nullptr) { LOG(FATAL) << "Dart booster feature " << not_implemented; } dh::safe_cuda(cudaSetDevice(ctx->Ordinal())); @@ -438,7 +438,7 @@ void ShapInteractionValues(Context const* ctx, DMatrix* p_fmat, } void ApproxFeatureImportance(Context const* ctx, DMatrix*, HostDeviceVector*, - gbm::GBTreeModel const&, bst_tree_t, common::Span) { + gbm::GBTreeModel const&, bst_tree_t, std::vector const*) { StringView not_implemented{ "contribution is not implemented in the GPU predictor, use CPU instead."}; LOG(FATAL) << "Approximated " << not_implemented; diff --git a/src/predictor/interpretability/shap.h b/src/predictor/interpretability/shap.h index 3de35afedba2..2c7c71455e60 100644 --- a/src/predictor/interpretability/shap.h +++ b/src/predictor/interpretability/shap.h @@ -8,7 +8,6 @@ #include "xgboost/context.h" // for Context #include "xgboost/data.h" // for DMatrix, MetaInfo #include "xgboost/host_device_vector.h" // for HostDeviceVector -#include "xgboost/span.h" // for Span namespace xgboost::gbm { struct GBTreeModel; @@ -18,15 +17,15 @@ namespace xgboost::interpretability { namespace cpu_impl { void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, bst_tree_t tree_end, - common::Span tree_weights, int condition, unsigned condition_feature); + std::vector const *tree_weights, int condition, unsigned condition_feature); void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, common::Span tree_weights); + bst_tree_t tree_end, std::vector const *tree_weights); void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, common::Span tree_weights, + bst_tree_t tree_end, std::vector const *tree_weights, bool approximate); } // namespace cpu_impl @@ -34,20 +33,20 @@ void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, namespace cuda_impl { void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, bst_tree_t tree_end, - common::Span tree_weights, int condition, unsigned condition_feature); + std::vector const *tree_weights, int condition, unsigned condition_feature); void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, common::Span tree_weights); + bst_tree_t tree_end, std::vector const *tree_weights); void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, common::Span tree_weights, + bst_tree_t tree_end, std::vector const *tree_weights, bool approximate); } // namespace cuda_impl #endif // defined(XGBOOST_USE_CUDA) inline void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, bst_tree_t tree_end, - common::Span tree_weights, int condition, + std::vector const *tree_weights, int condition, unsigned condition_feature) { #if defined(XGBOOST_USE_CUDA) if (ctx->IsCUDA()) { @@ -63,7 +62,7 @@ inline void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, bst_tree_t tree_end, - common::Span tree_weights) { + std::vector const *tree_weights) { #if defined(XGBOOST_USE_CUDA) if (ctx->IsCUDA()) { cuda_impl::ApproxFeatureImportance(ctx, p_fmat, out_contribs, model, tree_end, tree_weights); @@ -76,7 +75,7 @@ inline void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, inline void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, bst_tree_t tree_end, - common::Span tree_weights, bool approximate) { + std::vector const *tree_weights, bool approximate) { #if defined(XGBOOST_USE_CUDA) if (ctx->IsCUDA()) { cuda_impl::ShapInteractionValues(ctx, p_fmat, out_contribs, model, tree_end, tree_weights, From 7bca3deb62211ccfc9eea9226fb1b24f8a3d126d Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Sun, 15 Feb 2026 09:43:34 -0800 Subject: [PATCH 09/17] Signed comparison error --- src/predictor/interpretability/shap.cc | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/predictor/interpretability/shap.cc b/src/predictor/interpretability/shap.cc index b8c0559dbca2..1f06a5b4773e 100644 --- a/src/predictor/interpretability/shap.cc +++ b/src/predictor/interpretability/shap.cc @@ -72,6 +72,8 @@ void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *ou // number of valid trees tree_end = predictor::GetTreeLimit(model.trees, tree_end); ValidateTreeWeights(tree_weights, tree_end); + CHECK_GE(tree_end, 0); + auto const n_trees = static_cast(tree_end); size_t const ncolumns = model.learner_model_param->num_feature + 1; // allocate space for (number of features + bias) times the number of rows std::vector &contribs = out_contribs->HostVector(); @@ -79,8 +81,8 @@ void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *ou // make sure contributions is zeroed, we could be reusing a previously allocated one std::fill(contribs.begin(), contribs.end(), 0); // initialize tree node mean values - std::vector> mean_values(tree_end); - for (bst_omp_uint i = 0; i < tree_end; ++i) { + std::vector> mean_values(n_trees); + for (bst_omp_uint i = 0; i < n_trees; ++i) { FillNodeMeanValues(model.trees[i]->HostScView(), &(mean_values[i])); } @@ -183,12 +185,14 @@ void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, MetaInfo const &info = p_fmat->Info(); tree_end = predictor::GetTreeLimit(model.trees, tree_end); ValidateTreeWeights(tree_weights, tree_end); + CHECK_GE(tree_end, 0); + auto const n_trees = static_cast(tree_end); size_t const ncolumns = model.learner_model_param->num_feature + 1; std::vector &contribs = out_contribs->HostVector(); contribs.resize(info.num_row_ * ncolumns * model.learner_model_param->num_output_group); std::fill(contribs.begin(), contribs.end(), 0); - std::vector> mean_values(tree_end); - for (bst_omp_uint i = 0; i < tree_end; ++i) { + std::vector> mean_values(n_trees); + for (bst_omp_uint i = 0; i < n_trees; ++i) { FillNodeMeanValues(model.trees[i]->HostScView(), &(mean_values[i])); } From 37d7924704566ec4f5ceae7277bc7337806579a2 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Sun, 15 Feb 2026 09:56:33 -0800 Subject: [PATCH 10/17] R package build --- R-package/src/Makevars.in | 1 + R-package/src/Makevars.win.in | 1 + 2 files changed, 2 insertions(+) diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index ecf186176465..81cbb8434737 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -85,6 +85,7 @@ OBJECTS= \ $(PKGROOT)/src/data/iterative_dmatrix.o \ $(PKGROOT)/src/predictor/predictor.o \ $(PKGROOT)/src/predictor/cpu_predictor.o \ + $(PKGROOT)/src/predictor/interpretability/shap.o \ $(PKGROOT)/src/predictor/treeshap.o \ $(PKGROOT)/src/tree/constraints.o \ $(PKGROOT)/src/tree/param.o \ diff --git a/R-package/src/Makevars.win.in b/R-package/src/Makevars.win.in index c25eb5f4212a..352c8a45922a 100644 --- a/R-package/src/Makevars.win.in +++ b/R-package/src/Makevars.win.in @@ -84,6 +84,7 @@ OBJECTS= \ $(PKGROOT)/src/data/iterative_dmatrix.o \ $(PKGROOT)/src/predictor/predictor.o \ $(PKGROOT)/src/predictor/cpu_predictor.o \ + $(PKGROOT)/src/predictor/interpretability/shap.o \ $(PKGROOT)/src/predictor/treeshap.o \ $(PKGROOT)/src/tree/constraints.o \ $(PKGROOT)/src/tree/param.o \ From a8cf661b3416b322211e2bb78c9a27abd3680fcb Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Mon, 16 Feb 2026 02:30:04 -0800 Subject: [PATCH 11/17] Copilot review --- src/predictor/cpu_predictor.cc | 44 +++++++++---------- src/predictor/interpretability/shap.cc | 20 ++++----- src/predictor/interpretability/shap.cu | 13 ++++-- src/predictor/interpretability/shap.h | 58 +++++++++++++------------- 4 files changed, 70 insertions(+), 65 deletions(-) diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 8be90dc5d2f3..ba8b7035afd4 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -8,28 +8,28 @@ #include // for unique_ptr, shared_ptr #include // for vector -#include "../collective/allreduce.h" // for Allreduce -#include "../collective/communicator-inl.h" // for IsDistributed -#include "../common/bitfield.h" // for RBitField8 -#include "../common/column_matrix.h" // for ColumnMatrix -#include "../common/error_msg.h" // for InplacePredictProxy -#include "../common/math.h" // for CheckNAN -#include "../common/threading_utils.h" // for ParallelFor -#include "../data/adapter.h" // for ArrayAdapter, CSRAdapter, CSRArrayAdapter -#include "../data/cat_container.h" // for CatContainer -#include "../data/gradient_index.h" // for GHistIndexMatrix -#include "../data/proxy_dmatrix.h" // for DMatrixProxy -#include "../gbm/gbtree_model.h" // for GBTreeModel, GBTreeModelParam -#include "array_tree_layout.h" // for ProcessArrayTree -#include "data_accessor.h" // for GHistIndexMatrixView, SparsePageView -#include "dmlc/registry.h" // for DMLC_REGISTRY_FILE_TAG -#include "gbtree_view.h" // for GBTreeModelView -#include "interpretability/shap.h" // for PredictContribution -#include "predict_fn.h" // for GetNextNode, GetNextNodeMulti -#include "utils.h" // for CheckProxyDMatrix -#include "xgboost/base.h" // for bst_float, bst_node_t, bst_omp_uint, bst_fe... -#include "xgboost/context.h" // for Context -#include "xgboost/data.h" // for Entry, DMatrix, MetaInfo, SparsePage, Batch... +#include "../collective/allreduce.h" // for Allreduce +#include "../collective/communicator-inl.h" // for IsDistributed +#include "../common/bitfield.h" // for RBitField8 +#include "../common/column_matrix.h" // for ColumnMatrix +#include "../common/error_msg.h" // for InplacePredictProxy +#include "../common/math.h" // for CheckNAN +#include "../common/threading_utils.h" // for ParallelFor +#include "../data/adapter.h" // for ArrayAdapter, CSRAdapter, CSRArrayAdapter +#include "../data/cat_container.h" // for CatContainer +#include "../data/gradient_index.h" // for GHistIndexMatrix +#include "../data/proxy_dmatrix.h" // for DMatrixProxy +#include "../gbm/gbtree_model.h" // for GBTreeModel, GBTreeModelParam +#include "array_tree_layout.h" // for ProcessArrayTree +#include "data_accessor.h" // for GHistIndexMatrixView, SparsePageView +#include "dmlc/registry.h" // for DMLC_REGISTRY_FILE_TAG +#include "gbtree_view.h" // for GBTreeModelView +#include "interpretability/shap.h" // for ShapValues, ApproxFeatureImportance, ShapInteractionValues +#include "predict_fn.h" // for GetNextNode, GetNextNodeMulti +#include "utils.h" // for CheckProxyDMatrix +#include "xgboost/base.h" // for bst_float, bst_node_t, bst_omp_uint, bst_fe... +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for Entry, DMatrix, MetaInfo, SparsePage, Batch... #include "xgboost/host_device_vector.h" // for HostDeviceVector #include "xgboost/learner.h" // for LearnerModelParam #include "xgboost/linalg.h" // for TensorView, All, VectorView, Tensor diff --git a/src/predictor/interpretability/shap.cc b/src/predictor/interpretability/shap.cc index 1f06a5b4773e..22cd5203e561 100644 --- a/src/predictor/interpretability/shap.cc +++ b/src/predictor/interpretability/shap.cc @@ -71,9 +71,10 @@ void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *ou MetaInfo const &info = p_fmat->Info(); // number of valid trees tree_end = predictor::GetTreeLimit(model.trees, tree_end); - ValidateTreeWeights(tree_weights, tree_end); CHECK_GE(tree_end, 0); - auto const n_trees = static_cast(tree_end); + ValidateTreeWeights(tree_weights, tree_end); + auto const n_trees = static_cast(tree_end); + auto const n_threads = ctx->Threads(); size_t const ncolumns = model.learner_model_param->num_feature + 1; // allocate space for (number of features + bias) times the number of rows std::vector &contribs = out_contribs->HostVector(); @@ -82,15 +83,14 @@ void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *ou std::fill(contribs.begin(), contribs.end(), 0); // initialize tree node mean values std::vector> mean_values(n_trees); - for (bst_omp_uint i = 0; i < n_trees; ++i) { + common::ParallelFor(n_trees, n_threads, [&](auto i) { FillNodeMeanValues(model.trees[i]->HostScView(), &(mean_values[i])); - } + }); auto const n_groups = model.learner_model_param->num_output_group; CHECK_NE(n_groups, 0); auto const base_score = model.learner_model_param->BaseScore(DeviceOrd::CPU()); auto const h_tree_groups = model.TreeGroups(DeviceOrd::CPU()); - auto const n_threads = ctx->Threads(); std::vector feats_tloc(n_threads); std::vector> contribs_tloc(n_threads, std::vector(ncolumns)); @@ -184,23 +184,23 @@ void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, << "Predict contribution support for column-wise data split is not yet implemented."; MetaInfo const &info = p_fmat->Info(); tree_end = predictor::GetTreeLimit(model.trees, tree_end); - ValidateTreeWeights(tree_weights, tree_end); CHECK_GE(tree_end, 0); - auto const n_trees = static_cast(tree_end); + ValidateTreeWeights(tree_weights, tree_end); + auto const n_trees = static_cast(tree_end); + auto const n_threads = ctx->Threads(); size_t const ncolumns = model.learner_model_param->num_feature + 1; std::vector &contribs = out_contribs->HostVector(); contribs.resize(info.num_row_ * ncolumns * model.learner_model_param->num_output_group); std::fill(contribs.begin(), contribs.end(), 0); std::vector> mean_values(n_trees); - for (bst_omp_uint i = 0; i < n_trees; ++i) { + common::ParallelFor(n_trees, n_threads, [&](auto i) { FillNodeMeanValues(model.trees[i]->HostScView(), &(mean_values[i])); - } + }); auto const n_groups = model.learner_model_param->num_output_group; CHECK_NE(n_groups, 0); auto const base_score = model.learner_model_param->BaseScore(DeviceOrd::CPU()); auto const h_tree_groups = model.TreeGroups(DeviceOrd::CPU()); - auto const n_threads = ctx->Threads(); std::vector feats_tloc(n_threads); std::vector> contribs_tloc(n_threads, std::vector(ncolumns)); diff --git a/src/predictor/interpretability/shap.cu b/src/predictor/interpretability/shap.cu index c78a6719ff45..12e70d3b11ca 100644 --- a/src/predictor/interpretability/shap.cu +++ b/src/predictor/interpretability/shap.cu @@ -12,6 +12,7 @@ #include #include // for proclaim_return_type #include // for swap +#include // for variant #include #include #include @@ -106,8 +107,11 @@ struct ShapSplitCondition { if (l.Capacity() > r.Capacity()) { cuda::std::swap(l, r); } - for (size_t i = 0; i < r.Bits().size(); ++i) { - l.Bits()[i] &= r.Bits()[i]; + auto l_bits = l.Bits(); + auto r_bits = r.Bits(); + auto n_bits = l_bits.size() < r_bits.size() ? l_bits.size() : r_bits.size(); + for (size_t i = 0; i < n_bits; ++i) { + l_bits[i] &= r_bits[i]; } return l; } @@ -186,8 +190,9 @@ void ExtractPaths(Context const* ctx, info.begin(), cuda::proclaim_return_type( [=] __device__(PathInfo const& info) { return info.length; })); dh::caching_device_vector path_segments(info.size() + 1); - thrust::exclusive_scan(ctx->CUDACtx()->CTP(), length_iterator, length_iterator + info.size() + 1, - path_segments.begin()); + thrust::fill_n(ctx->CUDACtx()->CTP(), path_segments.begin(), 1, std::size_t{0}); + thrust::inclusive_scan(ctx->CUDACtx()->CTP(), length_iterator, length_iterator + info.size(), + path_segments.begin() + 1); paths->resize(path_segments.back()); diff --git a/src/predictor/interpretability/shap.h b/src/predictor/interpretability/shap.h index 2c7c71455e60..2c32d1d84554 100644 --- a/src/predictor/interpretability/shap.h +++ b/src/predictor/interpretability/shap.h @@ -15,38 +15,38 @@ struct GBTreeModel; namespace xgboost::interpretability { namespace cpu_impl { -void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, - gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights, int condition, unsigned condition_feature); +void ShapValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* out_contribs, + gbm::GBTreeModel const& model, bst_tree_t tree_end, + std::vector const* tree_weights, int condition, unsigned condition_feature); -void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, - HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, std::vector const *tree_weights); +void ApproxFeatureImportance(Context const* ctx, DMatrix* p_fmat, + HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, + bst_tree_t tree_end, std::vector const* tree_weights); -void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, - HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, std::vector const *tree_weights, +void ShapInteractionValues(Context const* ctx, DMatrix* p_fmat, + HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, + bst_tree_t tree_end, std::vector const* tree_weights, bool approximate); } // namespace cpu_impl #if defined(XGBOOST_USE_CUDA) namespace cuda_impl { -void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, - gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights, int condition, unsigned condition_feature); -void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, - HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, std::vector const *tree_weights); -void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, - HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, - bst_tree_t tree_end, std::vector const *tree_weights, +void ShapValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* out_contribs, + gbm::GBTreeModel const& model, bst_tree_t tree_end, + std::vector const* tree_weights, int condition, unsigned condition_feature); +void ApproxFeatureImportance(Context const* ctx, DMatrix* p_fmat, + HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, + bst_tree_t tree_end, std::vector const* tree_weights); +void ShapInteractionValues(Context const* ctx, DMatrix* p_fmat, + HostDeviceVector* out_contribs, gbm::GBTreeModel const& model, + bst_tree_t tree_end, std::vector const* tree_weights, bool approximate); } // namespace cuda_impl #endif // defined(XGBOOST_USE_CUDA) -inline void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, - gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights, int condition, +inline void ShapValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* out_contribs, + gbm::GBTreeModel const& model, bst_tree_t tree_end, + std::vector const* tree_weights, int condition, unsigned condition_feature) { #if defined(XGBOOST_USE_CUDA) if (ctx->IsCUDA()) { @@ -59,10 +59,10 @@ inline void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, - gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights) { +inline void ApproxFeatureImportance(Context const* ctx, DMatrix* p_fmat, + HostDeviceVector* out_contribs, + gbm::GBTreeModel const& model, bst_tree_t tree_end, + std::vector const* tree_weights) { #if defined(XGBOOST_USE_CUDA) if (ctx->IsCUDA()) { cuda_impl::ApproxFeatureImportance(ctx, p_fmat, out_contribs, model, tree_end, tree_weights); @@ -72,10 +72,10 @@ inline void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, cpu_impl::ApproxFeatureImportance(ctx, p_fmat, out_contribs, model, tree_end, tree_weights); } -inline void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, - HostDeviceVector *out_contribs, - gbm::GBTreeModel const &model, bst_tree_t tree_end, - std::vector const *tree_weights, bool approximate) { +inline void ShapInteractionValues(Context const* ctx, DMatrix* p_fmat, + HostDeviceVector* out_contribs, + gbm::GBTreeModel const& model, bst_tree_t tree_end, + std::vector const* tree_weights, bool approximate) { #if defined(XGBOOST_USE_CUDA) if (ctx->IsCUDA()) { cuda_impl::ShapInteractionValues(ctx, p_fmat, out_contribs, model, tree_end, tree_weights, From aa6cdedcb8a38d35b0f8299b88bdd49981e038c5 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Mon, 16 Feb 2026 05:10:07 -0800 Subject: [PATCH 12/17] Categorical encoding --- src/predictor/interpretability/shap.cc | 290 +++++++++++++------------ 1 file changed, 156 insertions(+), 134 deletions(-) diff --git a/src/predictor/interpretability/shap.cc b/src/predictor/interpretability/shap.cc index 22cd5203e561..cec1c68b558d 100644 --- a/src/predictor/interpretability/shap.cc +++ b/src/predictor/interpretability/shap.cc @@ -96,83 +96,94 @@ void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *ou auto device = ctx->Device().IsSycl() ? DeviceOrd::CPU() : ctx->Device(); auto base_margin = info.base_margin_.View(device); + auto ft = p_fmat->Info().feature_types.ConstHostVector(); - if (p_fmat->PageExists()) { - for (auto const &page : p_fmat->GetBatches()) { - auto view = page.GetView(); - common::ParallelFor(view.Size(), n_threads, [&](auto i) { - auto tid = omp_get_thread_num(); - auto &feats = feats_tloc[tid]; - if (feats.Size() == 0) { - feats.Init(model.learner_model_param->num_feature); - } - auto &this_tree_contribs = contribs_tloc[tid]; - auto row_idx = page.base_rowid + i; - feats.Fill(view[i]); - for (bst_target_t gid = 0; gid < n_groups; ++gid) { - float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; - for (bst_tree_t j = 0; j < tree_end; ++j) { - if (h_tree_groups[j] != gid) { - continue; + auto process_batches = [&](auto acc) { + if (p_fmat->PageExists()) { + for (auto const &page : p_fmat->GetBatches()) { + predictor::SparsePageView view{page.GetView(), page.base_rowid, acc}; + common::ParallelFor(view.Size(), n_threads, [&](auto i) { + auto tid = omp_get_thread_num(); + auto &feats = feats_tloc[tid]; + if (feats.Size() == 0) { + feats.Init(model.learner_model_param->num_feature); + } + auto &this_tree_contribs = contribs_tloc[tid]; + auto row_idx = view.base_rowid + i; + auto n_valid = view.DoFill(i, feats.Data().data()); + feats.HasMissing(n_valid != feats.Size()); + for (bst_target_t gid = 0; gid < n_groups; ++gid) { + float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; + for (bst_tree_t j = 0; j < tree_end; ++j) { + if (h_tree_groups[j] != gid) { + continue; + } + std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); + auto const sc_tree = model.trees[j]->HostScView(); + CalculateContributions(sc_tree, feats, &mean_values[j], this_tree_contribs.data(), + condition, condition_feature); + for (size_t ci = 0; ci < ncolumns; ++ci) { + p_contribs[ci] += + this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); + } } - std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); - auto const sc_tree = model.trees[j]->HostScView(); - CalculateContributions(sc_tree, feats, &mean_values[j], this_tree_contribs.data(), - condition, condition_feature); - for (size_t ci = 0; ci < ncolumns; ++ci) { - p_contribs[ci] += - this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Shape(1), n_groups); + p_contribs[ncolumns - 1] += base_margin(row_idx, gid); + } else { + p_contribs[ncolumns - 1] += base_score(gid); } } - if (base_margin.Size() != 0) { - CHECK_EQ(base_margin.Shape(1), n_groups); - p_contribs[ncolumns - 1] += base_margin(row_idx, gid); - } else { - p_contribs[ncolumns - 1] += base_score(gid); + feats.Drop(); + }); + } + } else { + for (auto const &page : p_fmat->GetBatches(ctx, {})) { + predictor::GHistIndexMatrixView view{page, acc, ft}; + common::ParallelFor(view.Size(), n_threads, [&](auto i) { + auto tid = omp_get_thread_num(); + auto &feats = feats_tloc[tid]; + if (feats.Size() == 0) { + feats.Init(model.learner_model_param->num_feature); } - } - feats.Drop(); - }); - } - } else { - auto ft = p_fmat->Info().feature_types.ConstHostVector(); - for (auto const &page : p_fmat->GetBatches(ctx, {})) { - predictor::GHistIndexMatrixView view{page, NoOpAccessor{}, ft}; - common::ParallelFor(view.Size(), n_threads, [&](auto i) { - auto tid = omp_get_thread_num(); - auto &feats = feats_tloc[tid]; - if (feats.Size() == 0) { - feats.Init(model.learner_model_param->num_feature); - } - auto &this_tree_contribs = contribs_tloc[tid]; - auto row_idx = view.base_rowid + i; - auto n_valid = view.DoFill(i, feats.Data().data()); - feats.HasMissing(n_valid != feats.Size()); - for (bst_target_t gid = 0; gid < n_groups; ++gid) { - float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; - for (bst_tree_t j = 0; j < tree_end; ++j) { - if (h_tree_groups[j] != gid) { - continue; + auto &this_tree_contribs = contribs_tloc[tid]; + auto row_idx = view.base_rowid + i; + auto n_valid = view.DoFill(i, feats.Data().data()); + feats.HasMissing(n_valid != feats.Size()); + for (bst_target_t gid = 0; gid < n_groups; ++gid) { + float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; + for (bst_tree_t j = 0; j < tree_end; ++j) { + if (h_tree_groups[j] != gid) { + continue; + } + std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); + auto const sc_tree = model.trees[j]->HostScView(); + CalculateContributions(sc_tree, feats, &mean_values[j], this_tree_contribs.data(), + condition, condition_feature); + for (size_t ci = 0; ci < ncolumns; ++ci) { + p_contribs[ci] += + this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); + } } - std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); - auto const sc_tree = model.trees[j]->HostScView(); - CalculateContributions(sc_tree, feats, &mean_values[j], this_tree_contribs.data(), - condition, condition_feature); - for (size_t ci = 0; ci < ncolumns; ++ci) { - p_contribs[ci] += - this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Shape(1), n_groups); + p_contribs[ncolumns - 1] += base_margin(row_idx, gid); + } else { + p_contribs[ncolumns - 1] += base_score(gid); } } - if (base_margin.Size() != 0) { - CHECK_EQ(base_margin.Shape(1), n_groups); - p_contribs[ncolumns - 1] += base_margin(row_idx, gid); - } else { - p_contribs[ncolumns - 1] += base_score(gid); - } - } - feats.Drop(); - }); + feats.Drop(); + }); + } } + }; + + if (model.Cats()->HasCategorical() && p_fmat->Cats()->NeedRecode()) { + auto new_enc = p_fmat->Cats()->HostView(); + auto [acc, mapping] = ::xgboost::cpu_impl::MakeCatAccessor(ctx, new_enc, model.Cats()); + process_batches(acc); + } else { + process_batches(NoOpAccessor{}); } } @@ -206,81 +217,92 @@ void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, auto device = ctx->Device().IsSycl() ? DeviceOrd::CPU() : ctx->Device(); auto base_margin = info.base_margin_.View(device); + auto ft = p_fmat->Info().feature_types.ConstHostVector(); - if (p_fmat->PageExists()) { - for (auto const &page : p_fmat->GetBatches()) { - auto view = page.GetView(); - common::ParallelFor(view.Size(), n_threads, [&](auto i) { - auto tid = omp_get_thread_num(); - auto &feats = feats_tloc[tid]; - if (feats.Size() == 0) { - feats.Init(model.learner_model_param->num_feature); - } - auto &this_tree_contribs = contribs_tloc[tid]; - auto row_idx = page.base_rowid + i; - feats.Fill(view[i]); - for (bst_target_t gid = 0; gid < n_groups; ++gid) { - float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; - for (bst_tree_t j = 0; j < tree_end; ++j) { - if (h_tree_groups[j] != gid) { - continue; + auto process_batches = [&](auto acc) { + if (p_fmat->PageExists()) { + for (auto const &page : p_fmat->GetBatches()) { + predictor::SparsePageView view{page.GetView(), page.base_rowid, acc}; + common::ParallelFor(view.Size(), n_threads, [&](auto i) { + auto tid = omp_get_thread_num(); + auto &feats = feats_tloc[tid]; + if (feats.Size() == 0) { + feats.Init(model.learner_model_param->num_feature); + } + auto &this_tree_contribs = contribs_tloc[tid]; + auto row_idx = view.base_rowid + i; + auto n_valid = view.DoFill(i, feats.Data().data()); + feats.HasMissing(n_valid != feats.Size()); + for (bst_target_t gid = 0; gid < n_groups; ++gid) { + float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; + for (bst_tree_t j = 0; j < tree_end; ++j) { + if (h_tree_groups[j] != gid) { + continue; + } + std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); + auto const sc_tree = model.trees[j]->HostScView(); + CalculateApproxContributions(sc_tree, feats, &mean_values[j], &this_tree_contribs); + for (size_t ci = 0; ci < ncolumns; ++ci) { + p_contribs[ci] += + this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); + } } - std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); - auto const sc_tree = model.trees[j]->HostScView(); - CalculateApproxContributions(sc_tree, feats, &mean_values[j], &this_tree_contribs); - for (size_t ci = 0; ci < ncolumns; ++ci) { - p_contribs[ci] += - this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Shape(1), n_groups); + p_contribs[ncolumns - 1] += base_margin(row_idx, gid); + } else { + p_contribs[ncolumns - 1] += base_score(gid); } } - if (base_margin.Size() != 0) { - CHECK_EQ(base_margin.Shape(1), n_groups); - p_contribs[ncolumns - 1] += base_margin(row_idx, gid); - } else { - p_contribs[ncolumns - 1] += base_score(gid); + feats.Drop(); + }); + } + } else { + for (auto const &page : p_fmat->GetBatches(ctx, {})) { + predictor::GHistIndexMatrixView view{page, acc, ft}; + common::ParallelFor(view.Size(), n_threads, [&](auto i) { + auto tid = omp_get_thread_num(); + auto &feats = feats_tloc[tid]; + if (feats.Size() == 0) { + feats.Init(model.learner_model_param->num_feature); } - } - feats.Drop(); - }); - } - } else { - auto ft = p_fmat->Info().feature_types.ConstHostVector(); - for (auto const &page : p_fmat->GetBatches(ctx, {})) { - predictor::GHistIndexMatrixView view{page, NoOpAccessor{}, ft}; - common::ParallelFor(view.Size(), n_threads, [&](auto i) { - auto tid = omp_get_thread_num(); - auto &feats = feats_tloc[tid]; - if (feats.Size() == 0) { - feats.Init(model.learner_model_param->num_feature); - } - auto &this_tree_contribs = contribs_tloc[tid]; - auto row_idx = view.base_rowid + i; - auto n_valid = view.DoFill(i, feats.Data().data()); - feats.HasMissing(n_valid != feats.Size()); - for (bst_target_t gid = 0; gid < n_groups; ++gid) { - float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; - for (bst_tree_t j = 0; j < tree_end; ++j) { - if (h_tree_groups[j] != gid) { - continue; + auto &this_tree_contribs = contribs_tloc[tid]; + auto row_idx = view.base_rowid + i; + auto n_valid = view.DoFill(i, feats.Data().data()); + feats.HasMissing(n_valid != feats.Size()); + for (bst_target_t gid = 0; gid < n_groups; ++gid) { + float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; + for (bst_tree_t j = 0; j < tree_end; ++j) { + if (h_tree_groups[j] != gid) { + continue; + } + std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); + auto const sc_tree = model.trees[j]->HostScView(); + CalculateApproxContributions(sc_tree, feats, &mean_values[j], &this_tree_contribs); + for (size_t ci = 0; ci < ncolumns; ++ci) { + p_contribs[ci] += + this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); + } } - std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); - auto const sc_tree = model.trees[j]->HostScView(); - CalculateApproxContributions(sc_tree, feats, &mean_values[j], &this_tree_contribs); - for (size_t ci = 0; ci < ncolumns; ++ci) { - p_contribs[ci] += - this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Shape(1), n_groups); + p_contribs[ncolumns - 1] += base_margin(row_idx, gid); + } else { + p_contribs[ncolumns - 1] += base_score(gid); } } - if (base_margin.Size() != 0) { - CHECK_EQ(base_margin.Shape(1), n_groups); - p_contribs[ncolumns - 1] += base_margin(row_idx, gid); - } else { - p_contribs[ncolumns - 1] += base_score(gid); - } - } - feats.Drop(); - }); + feats.Drop(); + }); + } } + }; + + if (model.Cats()->HasCategorical() && p_fmat->Cats()->NeedRecode()) { + auto new_enc = p_fmat->Cats()->HostView(); + auto [acc, mapping] = ::xgboost::cpu_impl::MakeCatAccessor(ctx, new_enc, model.Cats()); + process_batches(acc); + } else { + process_batches(NoOpAccessor{}); } } From c7fc3741bd791e080fad5ceedd8b02cc5fd4928f Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 17 Feb 2026 03:02:14 -0800 Subject: [PATCH 13/17] Update test --- tests/cpp/predictor/test_shap.cc | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/cpp/predictor/test_shap.cc b/tests/cpp/predictor/test_shap.cc index 1a516b01490e..895f5ce4a8e4 100644 --- a/tests/cpp/predictor/test_shap.cc +++ b/tests/cpp/predictor/test_shap.cc @@ -51,8 +51,9 @@ Args BaseParams(Context const* ctx, std::string objective, std::string max_depth {"device", ctx->IsSycl() ? "cpu" : ctx->DeviceName()}}; } -gbm::GBTreeModel LoadGBTreeModel(Learner* learner, Context const* ctx, Args const& model_args, - LearnerModelParam* out_param) { +std::unique_ptr LoadGBTreeModel(Learner* learner, Context const* ctx, + Args const& model_args, + LearnerModelParam* out_param) { Json model{Object{}}; learner->SaveModel(&model); @@ -103,9 +104,9 @@ gbm::GBTreeModel LoadGBTreeModel(Learner* learner, Context const* ctx, Args cons MultiStrategy::kOneOutputPerTree}; out_param->Copy(tmp); - gbm::GBTreeModel gbtree{out_param, ctx}; + auto gbtree = std::make_unique(out_param, ctx); auto const& gbm_obj = get(learner_obj.at("gradient_booster")); - gbtree.LoadModel(gbm_obj.at("model")); + gbtree->LoadModel(gbm_obj.at("model")); return gbtree; } } // namespace @@ -203,12 +204,12 @@ void CheckShapOutput(DMatrix* dmat, Args const& model_args) { auto gbtree = LoadGBTreeModel(learner.get(), dmat->Ctx(), model_args, &mparam); HostDeviceVector shap_values; - interpretability::ShapValues(dmat->Ctx(), p_dmat.get(), &shap_values, gbtree, 0, {}, 0, 0); + interpretability::ShapValues(dmat->Ctx(), p_dmat.get(), &shap_values, *gbtree, 0, {}, 0, 0); ASSERT_EQ(shap_values.HostVector().size(), kRows * (kCols + 1) * n_outputs); CheckShapAdditivity(kRows, kCols, shap_values, margin_predt); HostDeviceVector shap_interactions; - interpretability::ShapInteractionValues(dmat->Ctx(), p_dmat.get(), &shap_interactions, gbtree, 0, + interpretability::ShapInteractionValues(dmat->Ctx(), p_dmat.get(), &shap_interactions, *gbtree, 0, {}, false); ASSERT_EQ(shap_interactions.HostVector().size(), kRows * (kCols + 1) * (kCols + 1) * n_outputs); CheckShapAdditivity(kRows, kCols, shap_interactions, margin_predt); @@ -278,7 +279,7 @@ TEST(Predictor, ApproxContribsBasic) { auto gbtree = LoadGBTreeModel(learner.get(), dmat->Ctx(), args, &mparam); HostDeviceVector approx_contribs; - interpretability::ApproxFeatureImportance(dmat->Ctx(), dmat.get(), &approx_contribs, gbtree, 0, + interpretability::ApproxFeatureImportance(dmat->Ctx(), dmat.get(), &approx_contribs, *gbtree, 0, {}); auto const& h_margin = margin_predt.ConstHostVector(); From 078e53bf5238fa1a6262b64186a075df3ff6fa0f Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 17 Feb 2026 04:10:31 -0800 Subject: [PATCH 14/17] Clang-tidy --- src/predictor/gpu_predictor.cu | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 20ab55baafe0..43d2faa47c23 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -180,11 +180,12 @@ __global__ void PredictLeafKernel(Data data, common::Span d_t common::Span d_out_predictions, bst_tree_t tree_begin, bst_tree_t tree_end, bst_feature_t num_features, bool use_shared, float missing, EncAccessor acc) { + auto n_rows = data.NumRows(); bst_idx_t ridx = blockDim.x * blockIdx.x + threadIdx.x; - if (ridx >= data.NumRows()) { + if (ridx >= n_rows) { return; } - Loader loader{std::move(data), use_shared, num_features, data.NumRows(), missing, std::move(acc)}; + Loader loader{std::move(data), use_shared, num_features, n_rows, missing, std::move(acc)}; for (bst_tree_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { auto const& d_tree = d_trees[tree_idx - tree_begin]; cuda::std::visit( @@ -207,9 +208,10 @@ __global__ void PredictKernel(Data data, common::Span d_trees common::Span d_tree_groups, bst_feature_t num_features, bool use_shared, bst_target_t n_groups, float missing, EncAccessor acc) { + auto n_rows = data.NumRows(); bst_idx_t global_idx = blockDim.x * blockIdx.x + threadIdx.x; - Loader loader{std::move(data), use_shared, num_features, data.NumRows(), missing, std::move(acc)}; - if (global_idx >= data.NumRows()) { + Loader loader{std::move(data), use_shared, num_features, n_rows, missing, std::move(acc)}; + if (global_idx >= n_rows) { return; } @@ -640,10 +642,11 @@ class GPUPredictor : public xgboost::Predictor { bst_idx_t batch_offset = 0; cfg.ForEachBatch(p_fmat, [&](auto&& loader_t, auto&& batch) { using Loader = typename common::GetValueT; + auto n_rows = batch.NumRows(); cfg.template LaunchPredictKernel(std::move(batch), std::numeric_limits::quiet_NaN(), n_features, d_model, acc, batch_offset, out_preds); - batch_offset += batch.NumRows() * model.learner_model_param->OutputLength(); + batch_offset += n_rows * model.learner_model_param->OutputLength(); }); }); } @@ -782,6 +785,7 @@ class GPUPredictor : public xgboost::Predictor { cfg.ForEachBatch(p_fmat, [&](auto&& loader_t, auto&& batch) { using Loader = typename common::GetValueT; using Config = common::GetValueT; + auto n_rows = batch.NumRows(); auto kernel = PredictLeafKernel, Config::HasMissing(), typename Config::EncAccessorT>; cfg.template Launch(kernel, std::move(batch), d_model.Trees(), @@ -790,7 +794,7 @@ class GPUPredictor : public xgboost::Predictor { cfg.UseShared(), std::numeric_limits::quiet_NaN(), std::forward(acc)); - batch_offset += batch.NumRows(); + batch_offset += n_rows; }); }); } From 18de69093a35e9daff8e5729ba7b2caedf689db6 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 17 Feb 2026 04:25:36 -0800 Subject: [PATCH 15/17] Windows build --- src/predictor/interpretability/shap.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/predictor/interpretability/shap.cc b/src/predictor/interpretability/shap.cc index cec1c68b558d..d5162c51ea85 100644 --- a/src/predictor/interpretability/shap.cc +++ b/src/predictor/interpretability/shap.cc @@ -99,9 +99,10 @@ void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *ou auto ft = p_fmat->Info().feature_types.ConstHostVector(); auto process_batches = [&](auto acc) { + using AccT = std::decay_t; if (p_fmat->PageExists()) { for (auto const &page : p_fmat->GetBatches()) { - predictor::SparsePageView view{page.GetView(), page.base_rowid, acc}; + predictor::SparsePageView view{page.GetView(), page.base_rowid, acc}; common::ParallelFor(view.Size(), n_threads, [&](auto i) { auto tid = omp_get_thread_num(); auto &feats = feats_tloc[tid]; @@ -139,7 +140,7 @@ void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *ou } } else { for (auto const &page : p_fmat->GetBatches(ctx, {})) { - predictor::GHistIndexMatrixView view{page, acc, ft}; + predictor::GHistIndexMatrixView view{page, acc, ft}; common::ParallelFor(view.Size(), n_threads, [&](auto i) { auto tid = omp_get_thread_num(); auto &feats = feats_tloc[tid]; @@ -220,9 +221,10 @@ void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, auto ft = p_fmat->Info().feature_types.ConstHostVector(); auto process_batches = [&](auto acc) { + using AccT = std::decay_t; if (p_fmat->PageExists()) { for (auto const &page : p_fmat->GetBatches()) { - predictor::SparsePageView view{page.GetView(), page.base_rowid, acc}; + predictor::SparsePageView view{page.GetView(), page.base_rowid, acc}; common::ParallelFor(view.Size(), n_threads, [&](auto i) { auto tid = omp_get_thread_num(); auto &feats = feats_tloc[tid]; @@ -259,7 +261,7 @@ void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, } } else { for (auto const &page : p_fmat->GetBatches(ctx, {})) { - predictor::GHistIndexMatrixView view{page, acc, ft}; + predictor::GHistIndexMatrixView view{page, acc, ft}; common::ParallelFor(view.Size(), n_threads, [&](auto i) { auto tid = omp_get_thread_num(); auto &feats = feats_tloc[tid]; From a72c391bf0c701afd16bce778ed68ae37696f47a Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 17 Feb 2026 05:36:02 -0800 Subject: [PATCH 16/17] Review comment --- tests/cpp/predictor/test_shap.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/predictor/test_shap.cc b/tests/cpp/predictor/test_shap.cc index 895f5ce4a8e4..55029871facd 100644 --- a/tests/cpp/predictor/test_shap.cc +++ b/tests/cpp/predictor/test_shap.cc @@ -204,7 +204,7 @@ void CheckShapOutput(DMatrix* dmat, Args const& model_args) { auto gbtree = LoadGBTreeModel(learner.get(), dmat->Ctx(), model_args, &mparam); HostDeviceVector shap_values; - interpretability::ShapValues(dmat->Ctx(), p_dmat.get(), &shap_values, *gbtree, 0, {}, 0, 0); + interpretability::ShapValues(dmat->Ctx(), p_dmat.get(), &shap_values, *gbtree, 0, nullptr, 0, 0); ASSERT_EQ(shap_values.HostVector().size(), kRows * (kCols + 1) * n_outputs); CheckShapAdditivity(kRows, kCols, shap_values, margin_predt); From b1df6ba9bf2bec19e45317bc34b65ddc4556f740 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Sun, 22 Feb 2026 08:36:45 -0800 Subject: [PATCH 17/17] Dispatch on data type --- src/predictor/interpretability/shap.cc | 255 +++++++++---------------- src/predictor/interpretability/shap.cu | 106 ++++------ 2 files changed, 131 insertions(+), 230 deletions(-) diff --git a/src/predictor/interpretability/shap.cc b/src/predictor/interpretability/shap.cc index d5162c51ea85..025b47f2abd7 100644 --- a/src/predictor/interpretability/shap.cc +++ b/src/predictor/interpretability/shap.cc @@ -59,6 +59,34 @@ void CalculateApproxContributions(tree::ScalarTreeView const &tree, RegTree::FVe CHECK_EQ(out_contribs->size(), feats.Size() + 1); CalculateContributionsApprox(tree, feats, mean_values, out_contribs->data()); } + +template +void DispatchByBatchView(Context const *ctx, DMatrix *p_fmat, EncAccessor acc, Fn &&fn) { + using AccT = std::decay_t; + if (p_fmat->PageExists()) { + for (auto const &page : p_fmat->GetBatches()) { + predictor::SparsePageView view{page.GetView(), page.base_rowid, acc}; + fn(view); + } + } else { + auto ft = p_fmat->Info().feature_types.ConstHostVector(); + for (auto const &page : p_fmat->GetBatches(ctx, {})) { + predictor::GHistIndexMatrixView view{page, acc, ft}; + fn(view); + } + } +} + +template +void LaunchShap(Context const *ctx, DMatrix *p_fmat, gbm::GBTreeModel const &model, Fn &&fn) { + if (model.Cats()->HasCategorical() && p_fmat->Cats()->NeedRecode()) { + auto new_enc = p_fmat->Cats()->HostView(); + auto [acc, mapping] = ::xgboost::cpu_impl::MakeCatAccessor(ctx, new_enc, model.Cats()); + DispatchByBatchView(ctx, p_fmat, acc, fn); + } else { + DispatchByBatchView(ctx, p_fmat, NoOpAccessor{}, fn); + } +} } // namespace namespace cpu_impl { @@ -96,96 +124,45 @@ void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *ou auto device = ctx->Device().IsSycl() ? DeviceOrd::CPU() : ctx->Device(); auto base_margin = info.base_margin_.View(device); - auto ft = p_fmat->Info().feature_types.ConstHostVector(); - auto process_batches = [&](auto acc) { - using AccT = std::decay_t; - if (p_fmat->PageExists()) { - for (auto const &page : p_fmat->GetBatches()) { - predictor::SparsePageView view{page.GetView(), page.base_rowid, acc}; - common::ParallelFor(view.Size(), n_threads, [&](auto i) { - auto tid = omp_get_thread_num(); - auto &feats = feats_tloc[tid]; - if (feats.Size() == 0) { - feats.Init(model.learner_model_param->num_feature); - } - auto &this_tree_contribs = contribs_tloc[tid]; - auto row_idx = view.base_rowid + i; - auto n_valid = view.DoFill(i, feats.Data().data()); - feats.HasMissing(n_valid != feats.Size()); - for (bst_target_t gid = 0; gid < n_groups; ++gid) { - float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; - for (bst_tree_t j = 0; j < tree_end; ++j) { - if (h_tree_groups[j] != gid) { - continue; - } - std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); - auto const sc_tree = model.trees[j]->HostScView(); - CalculateContributions(sc_tree, feats, &mean_values[j], this_tree_contribs.data(), - condition, condition_feature); - for (size_t ci = 0; ci < ncolumns; ++ci) { - p_contribs[ci] += - this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); - } - } - if (base_margin.Size() != 0) { - CHECK_EQ(base_margin.Shape(1), n_groups); - p_contribs[ncolumns - 1] += base_margin(row_idx, gid); - } else { - p_contribs[ncolumns - 1] += base_score(gid); - } - } - feats.Drop(); - }); + auto process_view = [&](auto &&view) { + common::ParallelFor(view.Size(), n_threads, [&](auto i) { + auto tid = omp_get_thread_num(); + auto &feats = feats_tloc[tid]; + if (feats.Size() == 0) { + feats.Init(model.learner_model_param->num_feature); } - } else { - for (auto const &page : p_fmat->GetBatches(ctx, {})) { - predictor::GHistIndexMatrixView view{page, acc, ft}; - common::ParallelFor(view.Size(), n_threads, [&](auto i) { - auto tid = omp_get_thread_num(); - auto &feats = feats_tloc[tid]; - if (feats.Size() == 0) { - feats.Init(model.learner_model_param->num_feature); + auto &this_tree_contribs = contribs_tloc[tid]; + auto row_idx = view.base_rowid + i; + auto n_valid = view.DoFill(i, feats.Data().data()); + feats.HasMissing(n_valid != feats.Size()); + for (bst_target_t gid = 0; gid < n_groups; ++gid) { + float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; + for (bst_tree_t j = 0; j < tree_end; ++j) { + if (h_tree_groups[j] != gid) { + continue; } - auto &this_tree_contribs = contribs_tloc[tid]; - auto row_idx = view.base_rowid + i; - auto n_valid = view.DoFill(i, feats.Data().data()); - feats.HasMissing(n_valid != feats.Size()); - for (bst_target_t gid = 0; gid < n_groups; ++gid) { - float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; - for (bst_tree_t j = 0; j < tree_end; ++j) { - if (h_tree_groups[j] != gid) { - continue; - } - std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); - auto const sc_tree = model.trees[j]->HostScView(); - CalculateContributions(sc_tree, feats, &mean_values[j], this_tree_contribs.data(), - condition, condition_feature); - for (size_t ci = 0; ci < ncolumns; ++ci) { - p_contribs[ci] += - this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); - } - } - if (base_margin.Size() != 0) { - CHECK_EQ(base_margin.Shape(1), n_groups); - p_contribs[ncolumns - 1] += base_margin(row_idx, gid); - } else { - p_contribs[ncolumns - 1] += base_score(gid); - } + std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); + auto const sc_tree = model.trees[j]->HostScView(); + CalculateContributions(sc_tree, feats, &mean_values[j], this_tree_contribs.data(), + condition, condition_feature); + for (size_t ci = 0; ci < ncolumns; ++ci) { + p_contribs[ci] += + this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); } - feats.Drop(); - }); + } + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Shape(1), n_groups); + p_contribs[ncolumns - 1] += base_margin(row_idx, gid); + } else { + p_contribs[ncolumns - 1] += base_score(gid); + } } - } + feats.Drop(); + }); }; - if (model.Cats()->HasCategorical() && p_fmat->Cats()->NeedRecode()) { - auto new_enc = p_fmat->Cats()->HostView(); - auto [acc, mapping] = ::xgboost::cpu_impl::MakeCatAccessor(ctx, new_enc, model.Cats()); - process_batches(acc); - } else { - process_batches(NoOpAccessor{}); - } + LaunchShap(ctx, p_fmat, model, process_view); } void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, @@ -218,94 +195,44 @@ void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, auto device = ctx->Device().IsSycl() ? DeviceOrd::CPU() : ctx->Device(); auto base_margin = info.base_margin_.View(device); - auto ft = p_fmat->Info().feature_types.ConstHostVector(); - auto process_batches = [&](auto acc) { - using AccT = std::decay_t; - if (p_fmat->PageExists()) { - for (auto const &page : p_fmat->GetBatches()) { - predictor::SparsePageView view{page.GetView(), page.base_rowid, acc}; - common::ParallelFor(view.Size(), n_threads, [&](auto i) { - auto tid = omp_get_thread_num(); - auto &feats = feats_tloc[tid]; - if (feats.Size() == 0) { - feats.Init(model.learner_model_param->num_feature); - } - auto &this_tree_contribs = contribs_tloc[tid]; - auto row_idx = view.base_rowid + i; - auto n_valid = view.DoFill(i, feats.Data().data()); - feats.HasMissing(n_valid != feats.Size()); - for (bst_target_t gid = 0; gid < n_groups; ++gid) { - float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; - for (bst_tree_t j = 0; j < tree_end; ++j) { - if (h_tree_groups[j] != gid) { - continue; - } - std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); - auto const sc_tree = model.trees[j]->HostScView(); - CalculateApproxContributions(sc_tree, feats, &mean_values[j], &this_tree_contribs); - for (size_t ci = 0; ci < ncolumns; ++ci) { - p_contribs[ci] += - this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); - } - } - if (base_margin.Size() != 0) { - CHECK_EQ(base_margin.Shape(1), n_groups); - p_contribs[ncolumns - 1] += base_margin(row_idx, gid); - } else { - p_contribs[ncolumns - 1] += base_score(gid); - } - } - feats.Drop(); - }); + auto process_view = [&](auto &&view) { + common::ParallelFor(view.Size(), n_threads, [&](auto i) { + auto tid = omp_get_thread_num(); + auto &feats = feats_tloc[tid]; + if (feats.Size() == 0) { + feats.Init(model.learner_model_param->num_feature); } - } else { - for (auto const &page : p_fmat->GetBatches(ctx, {})) { - predictor::GHistIndexMatrixView view{page, acc, ft}; - common::ParallelFor(view.Size(), n_threads, [&](auto i) { - auto tid = omp_get_thread_num(); - auto &feats = feats_tloc[tid]; - if (feats.Size() == 0) { - feats.Init(model.learner_model_param->num_feature); + auto &this_tree_contribs = contribs_tloc[tid]; + auto row_idx = view.base_rowid + i; + auto n_valid = view.DoFill(i, feats.Data().data()); + feats.HasMissing(n_valid != feats.Size()); + for (bst_target_t gid = 0; gid < n_groups; ++gid) { + float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; + for (bst_tree_t j = 0; j < tree_end; ++j) { + if (h_tree_groups[j] != gid) { + continue; } - auto &this_tree_contribs = contribs_tloc[tid]; - auto row_idx = view.base_rowid + i; - auto n_valid = view.DoFill(i, feats.Data().data()); - feats.HasMissing(n_valid != feats.Size()); - for (bst_target_t gid = 0; gid < n_groups; ++gid) { - float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; - for (bst_tree_t j = 0; j < tree_end; ++j) { - if (h_tree_groups[j] != gid) { - continue; - } - std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); - auto const sc_tree = model.trees[j]->HostScView(); - CalculateApproxContributions(sc_tree, feats, &mean_values[j], &this_tree_contribs); - for (size_t ci = 0; ci < ncolumns; ++ci) { - p_contribs[ci] += - this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); - } - } - if (base_margin.Size() != 0) { - CHECK_EQ(base_margin.Shape(1), n_groups); - p_contribs[ncolumns - 1] += base_margin(row_idx, gid); - } else { - p_contribs[ncolumns - 1] += base_score(gid); - } + std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); + auto const sc_tree = model.trees[j]->HostScView(); + CalculateApproxContributions(sc_tree, feats, &mean_values[j], &this_tree_contribs); + for (size_t ci = 0; ci < ncolumns; ++ci) { + p_contribs[ci] += + this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]); } - feats.Drop(); - }); + } + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Shape(1), n_groups); + p_contribs[ncolumns - 1] += base_margin(row_idx, gid); + } else { + p_contribs[ncolumns - 1] += base_score(gid); + } } - } + feats.Drop(); + }); }; - if (model.Cats()->HasCategorical() && p_fmat->Cats()->NeedRecode()) { - auto new_enc = p_fmat->Cats()->HostView(); - auto [acc, mapping] = ::xgboost::cpu_impl::MakeCatAccessor(ctx, new_enc, model.Cats()); - process_batches(acc); - } else { - process_batches(NoOpAccessor{}); - } + LaunchShap(ctx, p_fmat, model, process_view); } void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat, diff --git a/src/predictor/interpretability/shap.cu b/src/predictor/interpretability/shap.cu index c404b29c6dac..e98702b03252 100644 --- a/src/predictor/interpretability/shap.cu +++ b/src/predictor/interpretability/shap.cu @@ -270,58 +270,44 @@ void ExtractPaths(Context const* ctx, }); } -template -class LaunchConfig { - public: - using EncAccessorT = EncAccessor; - - explicit LaunchConfig(Context const* ctx, bst_feature_t n_features) - : ctx_{ctx}, n_features_{n_features} {} - - template - void ForEachBatch(DMatrix* p_fmat, EncAccessor&& acc, Fn&& fn) { - if (p_fmat->PageExists()) { - for (auto& page : p_fmat->GetBatches()) { - SparsePageView batch{ctx_, page, n_features_}; - auto loader = SparsePageLoaderNoShared{batch, acc}; - fn(std::move(loader), page.base_rowid); - } - } else { - p_fmat->Info().feature_types.SetDevice(ctx_->Device()); - auto feature_types = p_fmat->Info().feature_types.ConstDeviceSpan(); - - for (auto const& page : p_fmat->GetBatches(ctx_, StaticBatch(true))) { - page.Impl()->Visit(ctx_, feature_types, [&](auto&& batch) { - using Acc = std::remove_reference_t; - auto loader = EllpackLoader{batch, - /*use_shared=*/false, - this->n_features_, - batch.NumRows(), - std::numeric_limits::quiet_NaN(), - std::forward(acc)}; - fn(std::move(loader), batch.base_rowid); - }); - } +template +void DispatchByBatchLoader(Context const* ctx, DMatrix* p_fmat, bst_feature_t n_features, + EncAccessor acc, Fn&& fn) { + using AccT = std::decay_t; + if (p_fmat->PageExists()) { + for (auto& page : p_fmat->GetBatches()) { + SparsePageView batch{ctx, page, n_features}; + auto loader = SparsePageLoaderNoShared{batch, acc}; + fn(std::move(loader), page.base_rowid); + } + } else { + p_fmat->Info().feature_types.SetDevice(ctx->Device()); + auto feature_types = p_fmat->Info().feature_types.ConstDeviceSpan(); + + for (auto const& page : p_fmat->GetBatches(ctx, StaticBatch(true))) { + page.Impl()->Visit(ctx, feature_types, [&](auto&& batch) { + using BatchT = std::remove_reference_t; + auto loader = EllpackLoader{batch, + /*use_shared=*/false, + n_features, + batch.NumRows(), + std::numeric_limits::quiet_NaN(), + AccT{acc}}; + fn(std::move(loader), batch.base_rowid); + }); } } +} - private: - Context const* ctx_; - bst_feature_t n_features_; -}; - -template -void LaunchShap(Context const* ctx, enc::DeviceColumnsView const& new_enc, - gbm::GBTreeModel const& model, Kernel&& launch) { +template +void LaunchShap(Context const* ctx, DMatrix* p_fmat, enc::DeviceColumnsView const& new_enc, + gbm::GBTreeModel const& model, Fn&& fn) { + auto n_features = model.learner_model_param->num_feature; if (model.Cats() && model.Cats()->HasCategorical() && new_enc.HasCategorical()) { auto [acc, mapping] = ::xgboost::cuda_impl::MakeCatAccessor(ctx, new_enc, model.Cats()); - auto cfg = - LaunchConfig{ctx, model.learner_model_param->num_feature}; - launch(std::move(cfg), std::move(acc)); + DispatchByBatchLoader(ctx, p_fmat, n_features, std::move(acc), fn); } else { - auto cfg = - LaunchConfig{ctx, model.learner_model_param->num_feature}; - launch(std::move(cfg), NoOpAccessor{}); + DispatchByBatchLoader(ctx, p_fmat, n_features, NoOpAccessor{}, fn); } } } // namespace @@ -358,16 +344,10 @@ void ShapValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector* ou dh::device_vector categories; ExtractPaths(ctx, &device_paths, model, d_model, &categories); - LaunchShap(ctx, new_enc, model, [&](auto&& cfg, auto&& acc) { - using Config = common::GetValueT; - using EncAccessor = typename Config::EncAccessorT; - - cfg.ForEachBatch( - p_fmat, std::forward(acc), [&](auto&& loader, bst_idx_t base_rowid) { - auto begin = dh::tbegin(phis) + base_rowid * dim_size; - gpu_treeshap::GPUTreeShap>( - loader, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis)); - }); + LaunchShap(ctx, p_fmat, new_enc, model, [&](auto&& loader, bst_idx_t base_rowid) { + auto begin = dh::tbegin(phis) + base_rowid * dim_size; + gpu_treeshap::GPUTreeShap>( + loader, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis)); }); p_fmat->Info().base_margin_.SetDevice(ctx->Device()); @@ -414,16 +394,10 @@ void ShapInteractionValues(Context const* ctx, DMatrix* p_fmat, auto new_enc = p_fmat->Cats()->NeedRecode() ? p_fmat->Cats()->DeviceView(ctx) : enc::DeviceColumnsView{}; - LaunchShap(ctx, new_enc, model, [&](auto&& cfg, auto&& acc) { - using Config = common::GetValueT; - using EncAccessor = typename Config::EncAccessorT; - - cfg.ForEachBatch( - p_fmat, std::forward(acc), [&](auto&& loader, bst_idx_t base_rowid) { - auto begin = dh::tbegin(phis) + base_rowid * dim_size; - gpu_treeshap::GPUTreeShapInteractions>( - loader, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis)); - }); + LaunchShap(ctx, p_fmat, new_enc, model, [&](auto&& loader, bst_idx_t base_rowid) { + auto begin = dh::tbegin(phis) + base_rowid * dim_size; + gpu_treeshap::GPUTreeShapInteractions>( + loader, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis)); }); p_fmat->Info().base_margin_.SetDevice(ctx->Device());