diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index ecf186176465..81cbb8434737 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -85,6 +85,7 @@ OBJECTS= \ $(PKGROOT)/src/data/iterative_dmatrix.o \ $(PKGROOT)/src/predictor/predictor.o \ $(PKGROOT)/src/predictor/cpu_predictor.o \ + $(PKGROOT)/src/predictor/interpretability/shap.o \ $(PKGROOT)/src/predictor/treeshap.o \ $(PKGROOT)/src/tree/constraints.o \ $(PKGROOT)/src/tree/param.o \ diff --git a/R-package/src/Makevars.win.in b/R-package/src/Makevars.win.in index c25eb5f4212a..352c8a45922a 100644 --- a/R-package/src/Makevars.win.in +++ b/R-package/src/Makevars.win.in @@ -84,6 +84,7 @@ OBJECTS= \ $(PKGROOT)/src/data/iterative_dmatrix.o \ $(PKGROOT)/src/predictor/predictor.o \ $(PKGROOT)/src/predictor/cpu_predictor.o \ + $(PKGROOT)/src/predictor/interpretability/shap.o \ $(PKGROOT)/src/predictor/treeshap.o \ $(PKGROOT)/src/tree/constraints.o \ $(PKGROOT)/src/tree/param.o \ diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index 020e0a59d1e8..a4c490dc0848 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -181,8 +181,7 @@ class Predictor { struct PredictorReg : public dmlc::FunctionRegEntryBase> {}; -#define XGBOOST_REGISTER_PREDICTOR(UniqueId, Name) \ - static DMLC_ATTRIBUTE_UNUSED ::xgboost::PredictorReg& \ - __make_##PredictorReg##_##UniqueId##__ = \ - ::dmlc::Registry<::xgboost::PredictorReg>::Get()->__REGISTER__(Name) +#define XGBOOST_REGISTER_PREDICTOR(UniqueId, Name) \ + static DMLC_ATTRIBUTE_UNUSED ::xgboost::PredictorReg& __make_##PredictorReg##_##UniqueId##__ = \ + ::dmlc::Registry<::xgboost::PredictorReg>::Get()->__REGISTER__(Name) } // namespace xgboost diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 06f89242b1d0..594458dbbb93 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -112,8 +112,7 @@ void GBTree::Configure(Args const& cfg) { #if defined(XGBOOST_USE_SYCL) if (!sycl_predictor_) { - 
sycl_predictor_ = - std::unique_ptr(Predictor::Create("sycl_predictor", this->ctx_)); + sycl_predictor_ = std::unique_ptr(Predictor::Create("sycl_predictor", this->ctx_)); } sycl_predictor_->Configure(cfg); #endif // defined(XGBOOST_USE_SYCL) @@ -639,9 +638,9 @@ void GBTree::InplacePredict(std::shared_ptr p_m, float missing, return gpu_predictor_; } else { #if defined(XGBOOST_USE_SYCL) - common::AssertSYCLSupport(); - CHECK(sycl_predictor_); - return sycl_predictor_; + common::AssertSYCLSupport(); + CHECK(sycl_predictor_); + return sycl_predictor_; #endif // defined(XGBOOST_USE_SYCL) } @@ -676,7 +675,6 @@ void GPUDartInplacePredictInc(common::Span /*out_predts*/, common::Spanmodel_.learner_model_param->IsVectorLeaf()) << "dart" << MTNotImplemented(); auto& predictor = this->GetPredictor(training, &p_out_preds->predictions, p_fmat); CHECK(predictor); - predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, - model_); + predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, model_); p_out_preds->version = 0; auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); auto n_groups = model_.learner_model_param->num_output_group; @@ -784,8 +781,8 @@ class Dart : public GBTree { GPUDartPredictInc(p_out_preds->predictions.DeviceSpan(), predts.predictions.DeviceSpan(), w, n_rows, n_groups, grp_idx); } else { - auto &h_out_predts = p_out_preds->predictions.HostVector(); - auto &h_predts = predts.predictions.ConstHostVector(); + auto& h_out_predts = p_out_preds->predictions.HostVector(); + auto& h_predts = predts.predictions.ConstHostVector(); common::ParallelFor(p_fmat->Info().num_row_, ctx_->Threads(), [&](auto ridx) { const size_t offset = ridx * n_groups + grp_idx; h_out_predts[offset] += (h_predts[offset] * w); @@ -942,10 +939,8 @@ class Dart : public GBTree { // size_t i = std::discrete_distribution(weight_drop.begin(), // weight_drop.end())(rnd); size_t i = std::discrete_distribution( - weight_drop_.size(), 
0., static_cast(weight_drop_.size()), - [this](double x) -> double { - return weight_drop_[static_cast(x)]; - })(rnd); + weight_drop_.size(), 0., static_cast(weight_drop_.size()), + [this](double x) -> double { return weight_drop_[static_cast(x)]; })(rnd); idx_drop_.push_back(i); } } else { diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 739d196769f9..7d5c425ad6f2 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -40,10 +40,7 @@ enum class TreeMethod : int { }; // boosting process types -enum class TreeProcessType : int { - kDefault = 0, - kUpdate = 1 -}; +enum class TreeProcessType : int { kDefault = 0, kUpdate = 1 }; // Sampling type for dart weights. enum class DartSampleType : std::int32_t { @@ -72,15 +69,16 @@ struct GBTreeTrainParam : public XGBoostParameter { .set_default(TreeProcessType::kDefault) .add_enum("default", TreeProcessType::kDefault) .add_enum("update", TreeProcessType::kUpdate) - .describe("Whether to run the normal boosting process that creates new trees,"\ - " or to update the trees in an existing model."); + .describe( + "Whether to run the normal boosting process that creates new trees," + " or to update the trees in an existing model."); DMLC_DECLARE_ALIAS(updater_seq, updater); DMLC_DECLARE_FIELD(tree_method) .set_default(TreeMethod::kAuto) - .add_enum("auto", TreeMethod::kAuto) - .add_enum("approx", TreeMethod::kApprox) - .add_enum("exact", TreeMethod::kExact) - .add_enum("hist", TreeMethod::kHist) + .add_enum("auto", TreeMethod::kAuto) + .add_enum("approx", TreeMethod::kApprox) + .add_enum("exact", TreeMethod::kExact) + .add_enum("hist", TreeMethod::kHist) .describe("Choice of tree construction method."); } }; @@ -268,10 +266,9 @@ class GBTree : public GradientBooster { } }); } else { - LOG(FATAL) - << "Unknown feature importance type, expected one of: " - << R"({"weight", "total_gain", "total_cover", "gain", "cover"}, got: )" - << importance_type; + LOG(FATAL) << "Unknown feature importance type, expected one of: " + << 
R"({"weight", "total_gain", "total_cover", "gain", "cover"}, got: )" + << importance_type; } if (importance_type == "gain" || importance_type == "cover") { for (size_t i = 0; i < gain_map.size(); ++i) { @@ -291,9 +288,8 @@ class GBTree : public GradientBooster { [[nodiscard]] CatContainer const* Cats() const override { return this->model_.Cats(); } - void PredictLeaf(DMatrix* p_fmat, - HostDeviceVector* out_preds, - uint32_t layer_begin, uint32_t layer_end) override { + void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* out_preds, uint32_t layer_begin, + uint32_t layer_end) override { auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); CHECK_EQ(tree_begin, 0) << "Predict leaf supports only iteration end: [0, " "n_iteration), use model slicing instead."; @@ -345,7 +341,7 @@ class GBTree : public GradientBooster { GBTreeTrainParam tparam_; // Tree training parameter tree::TrainParam tree_param_; - bool specified_updater_ {false}; + bool specified_updater_{false}; // the updaters that can be applied to each of tree std::vector> updaters_; // Predictors diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index bbe60941edee..0c2fd4baefc9 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -8,27 +8,28 @@ #include // for unique_ptr, shared_ptr #include // for vector -#include "../collective/allreduce.h" // for Allreduce -#include "../collective/communicator-inl.h" // for IsDistributed -#include "../common/bitfield.h" // for RBitField8 -#include "../common/column_matrix.h" // for ColumnMatrix -#include "../common/error_msg.h" // for InplacePredictProxy -#include "../common/math.h" // for CheckNAN -#include "../common/threading_utils.h" // for ParallelFor -#include "../data/adapter.h" // for ArrayAdapter, CSRAdapter, CSRArrayAdapter -#include "../data/cat_container.h" // for CatContainer -#include "../data/gradient_index.h" // for GHistIndexMatrix -#include "../data/proxy_dmatrix.h" // 
for DMatrixProxy -#include "../gbm/gbtree_model.h" // for GBTreeModel, GBTreeModelParam -#include "array_tree_layout.h" // for ProcessArrayTree -#include "dmlc/registry.h" // for DMLC_REGISTRY_FILE_TAG -#include "gbtree_view.h" // for GBTreeModelView -#include "predict_fn.h" // for GetNextNode, GetNextNodeMulti -#include "treeshap.h" // for CalculateContributions -#include "utils.h" // for CheckProxyDMatrix -#include "xgboost/base.h" // for bst_float, bst_node_t, bst_omp_uint, bst_fe... -#include "xgboost/context.h" // for Context -#include "xgboost/data.h" // for Entry, DMatrix, MetaInfo, SparsePage, Batch... +#include "../collective/allreduce.h" // for Allreduce +#include "../collective/communicator-inl.h" // for IsDistributed +#include "../common/bitfield.h" // for RBitField8 +#include "../common/column_matrix.h" // for ColumnMatrix +#include "../common/error_msg.h" // for InplacePredictProxy +#include "../common/math.h" // for CheckNAN +#include "../common/threading_utils.h" // for ParallelFor +#include "../data/adapter.h" // for ArrayAdapter, CSRAdapter, CSRArrayAdapter +#include "../data/cat_container.h" // for CatContainer +#include "../data/gradient_index.h" // for GHistIndexMatrix +#include "../data/proxy_dmatrix.h" // for DMatrixProxy +#include "../gbm/gbtree_model.h" // for GBTreeModel, GBTreeModelParam +#include "array_tree_layout.h" // for ProcessArrayTree +#include "data_accessor.h" // for GHistIndexMatrixView, SparsePageView +#include "dmlc/registry.h" // for DMLC_REGISTRY_FILE_TAG +#include "gbtree_view.h" // for GBTreeModelView +#include "interpretability/shap.h" // for ShapValues, ApproxFeatureImportance, ShapInteractionValues +#include "predict_fn.h" // for GetNextNode, GetNextNodeMulti +#include "utils.h" // for CheckProxyDMatrix +#include "xgboost/base.h" // for bst_float, bst_node_t, bst_omp_uint, bst_fe... +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for Entry, DMatrix, MetaInfo, SparsePage, Batch... 
#include "xgboost/host_device_vector.h" // for HostDeviceVector #include "xgboost/learner.h" // for LearnerModelParam #include "xgboost/linalg.h" // for TensorView, All, VectorView, Tensor @@ -233,172 +234,6 @@ bool ShouldUseBlock(DMatrix *p_fmat) { using cpu_impl::MakeCatAccessor; -// Convert a single sample in batch view to FVec -template -struct DataToFeatVec { - void Fill(bst_idx_t ridx, RegTree::FVec *p_feats) const { - auto &feats = *p_feats; - auto n_valid = static_cast(this)->DoFill(ridx, feats.Data().data()); - feats.HasMissing(n_valid != feats.Size()); - } - - // Fill the data into the feature vector. - void FVecFill(common::Range1d const &block, bst_feature_t n_features, - common::Span s_feats_vec) const { - auto feats_vec = s_feats_vec.data(); - for (std::size_t i = 0; i < block.Size(); ++i) { - RegTree::FVec &feats = feats_vec[i]; - if (feats.Size() == 0) { - feats.Init(n_features); - } - this->Fill(block.begin() + i, &feats); - } - } - // Clear the feature vector. - static void FVecDrop(common::Span s_feats) { - auto p_feats = s_feats.data(); - for (size_t i = 0, n = s_feats.size(); i < n; ++i) { - p_feats[i].Drop(); - } - } -}; - -template -class SparsePageView : public DataToFeatVec> { - EncAccessor acc_; - HostSparsePageView const view_; - - public: - bst_idx_t const base_rowid; - - SparsePageView(HostSparsePageView const p, bst_idx_t base_rowid, EncAccessor acc) - : acc_{std::move(acc)}, view_{p}, base_rowid{base_rowid} {} - [[nodiscard]] std::size_t Size() const { return view_.Size(); } - - [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float *out) const { - auto p_data = view_[ridx].data(); - - for (std::size_t i = 0, n = view_[ridx].size(); i < n; ++i) { - auto const &entry = p_data[i]; - out[entry.index] = acc_(entry); - } - - return view_[ridx].size(); - } -}; - -template -class GHistIndexMatrixView : public DataToFeatVec> { - private: - GHistIndexMatrix const &page_; - EncAccessor acc_; - common::Span ft_; - - std::vector const &ptrs_; - 
std::vector const &mins_; - std::vector const &values_; - common::ColumnMatrix const &columns_; - - public: - bst_idx_t const base_rowid; - - public: - GHistIndexMatrixView(GHistIndexMatrix const &_page, EncAccessor acc, - common::Span ft) - : page_{_page}, - acc_{std::move(acc)}, - ft_{ft}, - ptrs_{_page.cut.Ptrs()}, - mins_{_page.cut.MinValues()}, - values_{_page.cut.Values()}, - columns_{page_.Transpose()}, - base_rowid{_page.base_rowid} {} - - [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float *out) const { - auto gridx = ridx + this->base_rowid; - auto n_features = page_.Features(); - - bst_idx_t n_non_missings = 0; - if (page_.IsDense()) { - common::DispatchBinType(page_.index.GetBinTypeSize(), [&](auto t) { - using T = decltype(t); - auto ptr = this->page_.index.template data(); - auto rbeg = this->page_.row_ptr[ridx]; - for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) { - bst_bin_t bin_idx; - float fvalue; - if (common::IsCat(ft_, fidx)) { - bin_idx = page_.GetGindex(gridx, fidx); - fvalue = this->values_[bin_idx]; - } else { - bin_idx = ptr[rbeg + fidx] + page_.index.Offset()[fidx]; - fvalue = - common::HistogramCuts::NumericBinValue(this->ptrs_, values_, mins_, fidx, bin_idx); - } - out[fidx] = acc_(fvalue, fidx); - } - }); - n_non_missings += n_features; - } else { - for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) { - float fvalue = std::numeric_limits::quiet_NaN(); - bool is_cat = common::IsCat(ft_, fidx); - if (columns_.GetColumnType(fidx) == common::kSparseColumn) { - // Special handling for extremely sparse data. Just binary search. 
- auto bin_idx = page_.GetGindex(gridx, fidx); - if (bin_idx != -1) { - if (is_cat) { - fvalue = values_[bin_idx]; - } else { - fvalue = common::HistogramCuts::NumericBinValue(this->ptrs_, values_, mins_, fidx, - bin_idx); - } - } - } else { - fvalue = page_.GetFvalue(ptrs_, values_, mins_, gridx, fidx, is_cat); - } - if (!common::CheckNAN(fvalue)) { - out[fidx] = acc_(fvalue, fidx); - n_non_missings++; - } - } - } - return n_non_missings; - } - - [[nodiscard]] bst_idx_t Size() const { return page_.Size(); } -}; - -template -class AdapterView : public DataToFeatVec> { - Adapter const *adapter_; - float missing_; - EncAccessor acc_; - - public: - explicit AdapterView(Adapter const *adapter, float missing, EncAccessor acc) - : adapter_{adapter}, missing_{missing}, acc_{std::move(acc)} {} - - [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float *out) const { - auto const &batch = adapter_->Value(); - auto row = batch.GetLine(ridx); - bst_idx_t n_non_missings = 0; - for (size_t c = 0; c < row.Size(); ++c) { - auto e = row.GetElement(c); - if (missing_ != e.value && !common::CheckNAN(e.value)) { - auto fvalue = this->acc_(e); - out[e.column_idx] = fvalue; - n_non_missings++; - } - } - return n_non_missings; - } - - [[nodiscard]] bst_idx_t Size() const { return adapter_->NumRows(); } - - bst_idx_t const static base_rowid = 0; // NOLINT -}; - // Ordinal re-coder. 
struct EncAccessorPolicy { private: @@ -572,31 +407,6 @@ void PredictBatchByBlockKernel(DataView const &batch, HostModel const &model, }); } -float FillNodeMeanValues(tree::ScalarTreeView const &tree, bst_node_t nidx, - std::vector *mean_values) { - float result; - auto &node_mean_values = *mean_values; - if (tree.IsLeaf(nidx)) { - result = tree.LeafValue(nidx); - } else { - result = FillNodeMeanValues(tree, tree.LeftChild(nidx), mean_values) * - tree.Stat(tree.LeftChild(nidx)).sum_hess; - result += FillNodeMeanValues(tree, tree.RightChild(nidx), mean_values) * - tree.Stat(tree.RightChild(nidx)).sum_hess; - result /= tree.Stat(nidx).sum_hess; - } - node_mean_values[nidx] = result; - return result; -} - -void FillNodeMeanValues(tree::ScalarTreeView const &tree, std::vector *mean_values) { - auto n_nodes = tree.Size(); - if (static_cast(mean_values->size()) == n_nodes) { - return; - } - mean_values->resize(n_nodes); - FillNodeMeanValues(tree, 0, mean_values); -} } // anonymous namespace /** @@ -916,65 +726,6 @@ class CPUPredictor : public Predictor { }); } - template - void PredictContributionKernel(DataView batch, const MetaInfo &info, HostModel const &h_model, - linalg::VectorView base_score, - std::vector const *tree_weights, - std::vector> *mean_values, - ThreadTmp<1> *feat_vecs, std::vector *contribs, - bool approximate, int condition, - unsigned condition_feature) const { - const int num_feature = h_model.n_features; - const auto n_groups = h_model.n_groups; - CHECK_NE(n_groups, 0); - size_t const ncolumns = num_feature + 1; - CHECK_NE(ncolumns, 0); - auto device = ctx_->Device().IsSycl() ? 
DeviceOrd::CPU() : ctx_->Device(); - auto base_margin = info.base_margin_.View(device); - - // parallel over local batch - common::ParallelFor(batch.Size(), this->ctx_->Threads(), [&](auto i) { - auto row_idx = batch.base_rowid + i; - RegTree::FVec &feats = feat_vecs->ThreadBuffer(1).front(); - if (feats.Size() == 0) { - feats.Init(num_feature); - } - std::vector this_tree_contribs(ncolumns); - // loop over all classes - for (bst_target_t gid = 0; gid < n_groups; ++gid) { - float *p_contribs = &(*contribs)[(row_idx * n_groups + gid) * ncolumns]; - batch.Fill(i, &feats); - // calculate contributions - for (bst_tree_t j = 0; j < h_model.tree_end; ++j) { - auto *tree_mean_values = &mean_values->at(j); - std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); - if (h_model.tree_groups[j] != gid) { - continue; - } - auto sc_tree = std::get(h_model.Trees()[j]); - if (!approximate) { - CalculateContributions(sc_tree, feats, tree_mean_values, &this_tree_contribs[0], - condition, condition_feature); - } else { - CalculateContributionsApprox(sc_tree, feats, tree_mean_values, &this_tree_contribs[0]); - } - for (size_t ci = 0; ci < ncolumns; ++ci) { - p_contribs[ci] += - this_tree_contribs[ci] * (tree_weights == nullptr ? 
1 : (*tree_weights)[j]); - } - } - feats.Drop(); - // add base margin to BIAS - if (base_margin.Size() != 0) { - CHECK_EQ(base_margin.Shape(1), n_groups); - p_contribs[ncolumns - 1] += base_margin(row_idx, gid); - } else { - p_contribs[ncolumns - 1] += base_score(gid); - } - } - }); - } - public: explicit CPUPredictor(Context const *ctx) : Predictor::Predictor{ctx} {} @@ -1083,95 +834,23 @@ class CPUPredictor : public Predictor { void PredictContribution(DMatrix *p_fmat, HostDeviceVector *out_contribs, const gbm::GBTreeModel &model, bst_tree_t ntree_limit, - std::vector const *tree_weights, bool approximate, - int condition, unsigned condition_feature) const override { - CHECK(!model.learner_model_param->IsVectorLeaf()) - << "Predict contribution" << MTNotImplemented(); - CHECK(!p_fmat->Info().IsColumnSplit()) - << "Predict contribution support for column-wise data split is not yet implemented."; - auto const n_threads = this->ctx_->Threads(); - ThreadTmp<1> feat_vecs{n_threads}; - const MetaInfo &info = p_fmat->Info(); - // number of valid trees - ntree_limit = GetTreeLimit(model.trees, ntree_limit); - size_t const ncolumns = model.learner_model_param->num_feature + 1; - // allocate space for (number of features + bias) times the number of rows - std::vector &contribs = out_contribs->HostVector(); - contribs.resize(info.num_row_ * ncolumns * model.learner_model_param->num_output_group); - // make sure contributions is zeroed, we could be reusing a previously - // allocated one - std::fill(contribs.begin(), contribs.end(), 0); - // initialize tree node mean values - std::vector> mean_values(ntree_limit); - common::ParallelFor(ntree_limit, n_threads, [&](bst_omp_uint i) { - FillNodeMeanValues(model.trees[i]->HostScView(), &(mean_values[i])); - }); - - auto const h_model = HostModel{DeviceOrd::CPU(), model, true, 0, ntree_limit, CopyViews{}}; - LaunchPredict(this->ctx_, p_fmat, model, [&](auto &&policy) { - policy.ForEachBatch([&](auto &&batch) { - 
PredictContributionKernel(batch, info, h_model, - model.learner_model_param->BaseScore(DeviceOrd::CPU()), - tree_weights, &mean_values, &feat_vecs, &contribs, approximate, - condition, condition_feature); - }); - }); + std::vector const *tree_weights, bool approximate, int condition, + unsigned condition_feature) const override { + if (approximate) { + interpretability::ApproxFeatureImportance(this->ctx_, p_fmat, out_contribs, model, + ntree_limit, tree_weights); + } else { + interpretability::ShapValues(this->ctx_, p_fmat, out_contribs, model, ntree_limit, + tree_weights, condition, condition_feature); + } } void PredictInteractionContributions(DMatrix *p_fmat, HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, bst_tree_t ntree_limit, std::vector const *tree_weights, bool approximate) const override { - CHECK(!model.learner_model_param->IsVectorLeaf()) - << "Predict interaction contribution" << MTNotImplemented(); - CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict interaction contribution support for " - "column-wise data split is not yet implemented."; - const MetaInfo &info = p_fmat->Info(); - auto const ngroup = model.learner_model_param->num_output_group; - auto const ncolumns = model.learner_model_param->num_feature; - const unsigned row_chunk = ngroup * (ncolumns + 1) * (ncolumns + 1); - const unsigned mrow_chunk = (ncolumns + 1) * (ncolumns + 1); - const unsigned crow_chunk = ngroup * (ncolumns + 1); - - // allocate space for (number of features^2) times the number of rows and tmp off/on contribs - std::vector &contribs = out_contribs->HostVector(); - contribs.resize(info.num_row_ * ngroup * (ncolumns + 1) * (ncolumns + 1)); - HostDeviceVector contribs_off_hdv(info.num_row_ * ngroup * (ncolumns + 1)); - auto &contribs_off = contribs_off_hdv.HostVector(); - HostDeviceVector contribs_on_hdv(info.num_row_ * ngroup * (ncolumns + 1)); - auto &contribs_on = contribs_on_hdv.HostVector(); - HostDeviceVector contribs_diag_hdv(info.num_row_ * ngroup * 
(ncolumns + 1)); - auto &contribs_diag = contribs_diag_hdv.HostVector(); - - // Compute the difference in effects when conditioning on each of the features on and off - // see: Axiomatic characterizations of probabilistic and - // cardinal-probabilistic interaction indices - PredictContribution(p_fmat, &contribs_diag_hdv, model, ntree_limit, tree_weights, approximate, - 0, 0); - for (size_t i = 0; i < ncolumns + 1; ++i) { - PredictContribution(p_fmat, &contribs_off_hdv, model, ntree_limit, tree_weights, approximate, - -1, i); - PredictContribution(p_fmat, &contribs_on_hdv, model, ntree_limit, tree_weights, approximate, - 1, i); - - for (size_t j = 0; j < info.num_row_; ++j) { - for (std::remove_const_t l = 0; l < ngroup; ++l) { - const unsigned o_offset = j * row_chunk + l * mrow_chunk + i * (ncolumns + 1); - const unsigned c_offset = j * crow_chunk + l * (ncolumns + 1); - contribs[o_offset + i] = 0; - for (size_t k = 0; k < ncolumns + 1; ++k) { - // fill in the diagonal with additive effects, and off-diagonal with the interactions - if (k == i) { - contribs[o_offset + i] += contribs_diag[c_offset + k]; - } else { - contribs[o_offset + k] = - (contribs_on[c_offset + k] - contribs_off[c_offset + k]) / 2.0; - contribs[o_offset + i] -= contribs[o_offset + k]; - } - } - } - } - } + interpretability::ShapInteractionValues(this->ctx_, p_fmat, out_contribs, model, ntree_limit, + tree_weights, approximate); } }; diff --git a/src/predictor/data_accessor.h b/src/predictor/data_accessor.h new file mode 100644 index 000000000000..7c07da7b1fb5 --- /dev/null +++ b/src/predictor/data_accessor.h @@ -0,0 +1,190 @@ +/** + * Copyright 2017-2026, XGBoost Contributors + */ +#pragma once + +#include +#include +#include +#include +#include + +#include "../common/categorical.h" // for IsCat +#include "../common/column_matrix.h" // for ColumnMatrix +#include "../common/common.h" // for Range1d +#include "../common/hist_util.h" // for DispatchBinType, HistogramCuts +#include 
"../common/math.h" // for CheckNAN +#include "../data/cat_container.h" // for NoOpAccessor +#include "../data/gradient_index.h" // for GHistIndexMatrix +#include "xgboost/data.h" // for HostSparsePageView +#include "xgboost/span.h" // for Span +#include "xgboost/tree_model.h" // for RegTree::FVec + +namespace xgboost::predictor { +// Convert a single sample in batch view to FVec. +template +struct DataToFeatVec { + void Fill(bst_idx_t ridx, RegTree::FVec* p_feats) const { + auto& feats = *p_feats; + auto n_valid = static_cast(this)->DoFill(ridx, feats.Data().data()); + feats.HasMissing(n_valid != feats.Size()); + } + + // Fill the data into the feature vector. + void FVecFill(common::Range1d const& block, bst_feature_t n_features, + common::Span s_feats_vec) const { + auto feats_vec = s_feats_vec.data(); + for (std::size_t i = 0; i < block.Size(); ++i) { + RegTree::FVec& feats = feats_vec[i]; + if (feats.Size() == 0) { + feats.Init(n_features); + } + this->Fill(block.begin() + i, &feats); + } + } + // Clear the feature vector. 
+ static void FVecDrop(common::Span s_feats) { + auto p_feats = s_feats.data(); + for (size_t i = 0, n = s_feats.size(); i < n; ++i) { + p_feats[i].Drop(); + } + } +}; + +template +class SparsePageView : public DataToFeatVec> { + EncAccessor acc_; + HostSparsePageView const view_; + + public: + bst_idx_t const base_rowid; + + SparsePageView(HostSparsePageView const p, bst_idx_t base_rowid, EncAccessor acc) + : acc_{std::move(acc)}, view_{p}, base_rowid{base_rowid} {} + + [[nodiscard]] std::size_t Size() const { return view_.Size(); } + + [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float* out) const { + auto p_data = view_[ridx].data(); + + for (std::size_t i = 0, n = view_[ridx].size(); i < n; ++i) { + auto const& entry = p_data[i]; + out[entry.index] = acc_(entry); + } + + return view_[ridx].size(); + } +}; + +template +class GHistIndexMatrixView : public DataToFeatVec> { + private: + GHistIndexMatrix const& page_; + EncAccessor acc_; + common::Span ft_; + + std::vector const& ptrs_; + std::vector const& mins_; + std::vector const& values_; + common::ColumnMatrix const& columns_; + + public: + bst_idx_t const base_rowid; + + public: + GHistIndexMatrixView(GHistIndexMatrix const& page, EncAccessor acc, + common::Span ft) + : page_{page}, + acc_{std::move(acc)}, + ft_{ft}, + ptrs_{page.cut.Ptrs()}, + mins_{page.cut.MinValues()}, + values_{page.cut.Values()}, + columns_{page.Transpose()}, + base_rowid{page.base_rowid} {} + + [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float* out) const { + auto gridx = ridx + this->base_rowid; + auto n_features = page_.Features(); + + bst_idx_t n_non_missings = 0; + if (page_.IsDense()) { + common::DispatchBinType(page_.index.GetBinTypeSize(), [&](auto t) { + using T = decltype(t); + auto ptr = this->page_.index.template data(); + auto rbeg = this->page_.row_ptr[ridx]; + for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) { + bst_bin_t bin_idx; + float fvalue; + if (common::IsCat(ft_, fidx)) { + bin_idx = 
page_.GetGindex(gridx, fidx); + fvalue = this->values_[bin_idx]; + } else { + bin_idx = ptr[rbeg + fidx] + page_.index.Offset()[fidx]; + fvalue = + common::HistogramCuts::NumericBinValue(this->ptrs_, values_, mins_, fidx, bin_idx); + } + out[fidx] = acc_(fvalue, fidx); + } + }); + n_non_missings += n_features; + } else { + for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) { + float fvalue = std::numeric_limits::quiet_NaN(); + bool is_cat = common::IsCat(ft_, fidx); + if (columns_.GetColumnType(fidx) == common::kSparseColumn) { + // Special handling for extremely sparse data. Just binary search. + auto bin_idx = page_.GetGindex(gridx, fidx); + if (bin_idx != -1) { + if (is_cat) { + fvalue = values_[bin_idx]; + } else { + fvalue = common::HistogramCuts::NumericBinValue(this->ptrs_, values_, mins_, fidx, + bin_idx); + } + } + } else { + fvalue = page_.GetFvalue(ptrs_, values_, mins_, gridx, fidx, is_cat); + } + if (!common::CheckNAN(fvalue)) { + out[fidx] = acc_(fvalue, fidx); + n_non_missings++; + } + } + } + return n_non_missings; + } + + [[nodiscard]] bst_idx_t Size() const { return page_.Size(); } +}; + +template +class AdapterView : public DataToFeatVec> { + Adapter const* adapter_; + float missing_; + EncAccessor acc_; + + public: + explicit AdapterView(Adapter const* adapter, float missing, EncAccessor acc) + : adapter_{adapter}, missing_{missing}, acc_{std::move(acc)} {} + + [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float* out) const { + auto const& batch = adapter_->Value(); + auto row = batch.GetLine(ridx); + bst_idx_t n_non_missings = 0; + for (size_t c = 0; c < row.Size(); ++c) { + auto e = row.GetElement(c); + if (missing_ != e.value && !common::CheckNAN(e.value)) { + auto fvalue = this->acc_(e); + out[e.column_idx] = fvalue; + n_non_missings++; + } + } + return n_non_missings; + } + + [[nodiscard]] bst_idx_t Size() const { return adapter_->NumRows(); } + + bst_idx_t const static base_rowid = 0; // NOLINT +}; +} // namespace xgboost::predictor 
diff --git a/src/predictor/gpu_data_accessor.cuh b/src/predictor/gpu_data_accessor.cuh new file mode 100644 index 000000000000..ab068d54164c --- /dev/null +++ b/src/predictor/gpu_data_accessor.cuh @@ -0,0 +1,113 @@ +/** + * Copyright 2017-2026, XGBoost Contributors + */ +#pragma once + +#include +#include +#include + +#include "../common/categorical.h" // for IsCat +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for Entry, SparsePage +#include "xgboost/span.h" // for Span + +namespace xgboost::predictor { +struct SparsePageView { + common::Span d_data; + common::Span d_row_ptr; + bst_feature_t num_features; + + SparsePageView() = default; + explicit SparsePageView(Context const* ctx, SparsePage const& page, bst_feature_t n_features) + : d_data{[&] { + page.data.SetDevice(ctx->Device()); + return page.data.ConstDeviceSpan(); + }()}, + d_row_ptr{[&] { + page.offset.SetDevice(ctx->Device()); + return page.offset.ConstDeviceSpan(); + }()}, + num_features{n_features} {} + + [[nodiscard]] __device__ float GetElement(size_t ridx, size_t fidx) const { + // Binary search + auto begin_ptr = d_data.begin() + d_row_ptr[ridx]; + auto end_ptr = d_data.begin() + d_row_ptr[ridx + 1]; + if (end_ptr - begin_ptr == this->NumCols()) { + // Bypass span check for dense data + return d_data.data()[d_row_ptr[ridx] + fidx].fvalue; + } + common::Span::iterator previous_middle; + while (end_ptr != begin_ptr) { + auto middle = begin_ptr + (end_ptr - begin_ptr) / 2; + if (middle == previous_middle) { + break; + } else { + previous_middle = middle; + } + + if (middle->index == fidx) { + return middle->fvalue; + } else if (middle->index < fidx) { + begin_ptr = middle; + } else { + end_ptr = middle; + } + } + // Value is missing + return std::numeric_limits::quiet_NaN(); + } + + [[nodiscard]] XGBOOST_DEVICE size_t NumRows() const { return d_row_ptr.size() - 1; } + [[nodiscard]] XGBOOST_DEVICE size_t NumCols() const { return num_features; } +}; + +template +struct 
SparsePageLoaderNoShared { + public: + using SupportShmemLoad = std::false_type; + + SparsePageView data; + EncAccessor acc; + + template + [[nodiscard]] __device__ float GetElement(bst_idx_t ridx, Fidx fidx) const { + return acc(data.GetElement(ridx, fidx), fidx); + } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return data.NumRows(); } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return data.NumCols(); } +}; + +template +struct EllpackLoader { + public: + using SupportShmemLoad = std::false_type; + + Accessor matrix; + EncAccessor acc; + + XGBOOST_DEVICE EllpackLoader(Accessor m, bool /*use_shared*/, bst_feature_t /*n_features*/, + bst_idx_t /*n_samples*/, float /*missing*/, EncAccessor&& acc) + : matrix{std::move(m)}, acc{std::forward(acc)} {} + + [[nodiscard]] XGBOOST_DEV_INLINE float GetElement(size_t ridx, size_t fidx) const { + auto gidx = matrix.template GetBinIndex(ridx, fidx); + if (gidx == -1) { + return std::numeric_limits::quiet_NaN(); + } + if (common::IsCat(matrix.feature_types, fidx)) { + return this->acc(matrix.gidx_fvalue_map[gidx], fidx); + } + // The gradient index needs to be shifted by one as min values are not included in the + // cuts. 
+ if (gidx == matrix.feature_segments[fidx]) { + return matrix.min_fvalue[fidx]; + } + return matrix.gidx_fvalue_map[gidx - 1]; + } + + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return this->matrix.NumFeatures(); } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return this->matrix.n_rows; } +}; +} // namespace xgboost::predictor diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index b5ce1a862255..43d2faa47c23 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -1,7 +1,6 @@ /** * Copyright 2017-2026, XGBoost Contributors */ -#include #include #include #include @@ -28,6 +27,8 @@ #include "../gbm/gbtree_model.h" #include "../tree/tree_view.h" #include "gbtree_view.h" // for GBTreeModelView +#include "gpu_data_accessor.cuh" +#include "interpretability/shap.h" #include "predict_fn.h" #include "utils.h" // for CheckProxyDMatrix #include "xgboost/data.h" @@ -42,55 +43,6 @@ DMLC_REGISTRY_FILE_TAG(gpu_predictor); using cuda_impl::StaticBatch; -struct SparsePageView { - common::Span d_data; - common::Span d_row_ptr; - bst_feature_t num_features; - - SparsePageView() = default; - explicit SparsePageView(Context const* ctx, SparsePage const& page, bst_feature_t n_features) - : d_data{[&] { - page.data.SetDevice(ctx->Device()); - return page.data.ConstDeviceSpan(); - }()}, - d_row_ptr{[&] { - page.offset.SetDevice(ctx->Device()); - return page.offset.ConstDeviceSpan(); - }()}, - num_features{n_features} {} - - [[nodiscard]] __device__ float GetElement(size_t ridx, size_t fidx) const { - // Binary search - auto begin_ptr = d_data.begin() + d_row_ptr[ridx]; - auto end_ptr = d_data.begin() + d_row_ptr[ridx + 1]; - if (end_ptr - begin_ptr == this->NumCols()) { - // Bypass span check for dense data - return d_data.data()[d_row_ptr[ridx] + fidx].fvalue; - } - common::Span::iterator previous_middle; - while (end_ptr != begin_ptr) { - auto middle = begin_ptr + (end_ptr - begin_ptr) / 2; - if (middle 
== previous_middle) { - break; - } else { - previous_middle = middle; - } - - if (middle->index == fidx) { - return middle->fvalue; - } else if (middle->index < fidx) { - begin_ptr = middle; - } else { - end_ptr = middle; - } - } - // Value is missing - return std::numeric_limits::quiet_NaN(); - } - [[nodiscard]] XGBOOST_DEVICE size_t NumRows() const { return d_row_ptr.size() - 1; } - [[nodiscard]] XGBOOST_DEVICE size_t NumCols() const { return num_features; } -}; - template struct SparsePageLoader { public: @@ -135,36 +87,6 @@ struct SparsePageLoader { } }; -template -struct EllpackLoader { - public: - using SupportShmemLoad = std::false_type; - - Accessor matrix; - EncAccessor acc; - - XGBOOST_DEVICE EllpackLoader(Accessor m, bool /*use_shared*/, bst_feature_t /*n_features*/, - bst_idx_t /*n_samples*/, float /*missing*/, EncAccessor&& acc) - : matrix{std::move(m)}, acc{std::forward(acc)} {} - [[nodiscard]] XGBOOST_DEV_INLINE float GetElement(size_t ridx, size_t fidx) const { - auto gidx = matrix.template GetBinIndex(ridx, fidx); - if (gidx == -1) { - return std::numeric_limits::quiet_NaN(); - } - if (common::IsCat(matrix.feature_types, fidx)) { - return this->acc(matrix.gidx_fvalue_map[gidx], fidx); - } - // The gradient index needs to be shifted by one as min values are not included in the - // cuts. - if (gidx == matrix.feature_segments[fidx]) { - return matrix.min_fvalue[fidx]; - } - return matrix.gidx_fvalue_map[gidx - 1]; - } - [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return this->matrix.NumFeatures(); } - [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return this->matrix.n_rows; } -}; - /** * @brief Use for in-place predict. 
*/ @@ -258,11 +180,12 @@ __global__ void PredictLeafKernel(Data data, common::Span d_t common::Span d_out_predictions, bst_tree_t tree_begin, bst_tree_t tree_end, bst_feature_t num_features, bool use_shared, float missing, EncAccessor acc) { + auto n_rows = data.NumRows(); bst_idx_t ridx = blockDim.x * blockIdx.x + threadIdx.x; - if (ridx >= data.NumRows()) { + if (ridx >= n_rows) { return; } - Loader loader{std::move(data), use_shared, num_features, data.NumRows(), missing, std::move(acc)}; + Loader loader{std::move(data), use_shared, num_features, n_rows, missing, std::move(acc)}; for (bst_tree_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { auto const& d_tree = d_trees[tree_idx - tree_begin]; cuda::std::visit( @@ -285,9 +208,10 @@ __global__ void PredictKernel(Data data, common::Span d_trees common::Span d_tree_groups, bst_feature_t num_features, bool use_shared, bst_target_t n_groups, float missing, EncAccessor acc) { + auto n_rows = data.NumRows(); bst_idx_t global_idx = blockDim.x * blockIdx.x + threadIdx.x; - Loader loader{std::move(data), use_shared, num_features, data.NumRows(), missing, std::move(acc)}; - if (global_idx >= data.NumRows()) { + Loader loader{std::move(data), use_shared, num_features, n_rows, missing, std::move(acc)}; + if (global_idx >= n_rows) { return; } @@ -340,221 +264,6 @@ struct CopyViews { using DeviceModel = GBTreeModelView; } // namespace -struct ShapSplitCondition { - ShapSplitCondition() = default; - XGBOOST_DEVICE - ShapSplitCondition(float feature_lower_bound, float feature_upper_bound, bool is_missing_branch, - common::CatBitField cats) - : feature_lower_bound(feature_lower_bound), - feature_upper_bound(feature_upper_bound), - is_missing_branch(is_missing_branch), - categories{std::move(cats)} { - assert(feature_lower_bound <= feature_upper_bound); - } - - /*! Feature values >= lower and < upper flow down this path. */ - float feature_lower_bound; - float feature_upper_bound; - /*! 
Feature value set to true flow down this path. */ - common::CatBitField categories; - /*! Do missing values flow down this path? */ - bool is_missing_branch; - - // Does this instance flow down this path? - [[nodiscard]] XGBOOST_DEVICE bool EvaluateSplit(float x) const { - // is nan - if (isnan(x)) { - return is_missing_branch; - } - if (categories.Capacity() != 0) { - auto cat = static_cast(x); - return categories.Check(cat); - } else { - return x >= feature_lower_bound && x < feature_upper_bound; - } - } - - // the &= op in bitfiled is per cuda thread, this one loops over the entire - // bitfield. - XGBOOST_DEVICE static common::CatBitField Intersect(common::CatBitField l, - common::CatBitField r) { - if (l.Data() == r.Data()) { - return l; - } - if (l.Capacity() > r.Capacity()) { - cuda::std::swap(l, r); - } - for (size_t i = 0; i < r.Bits().size(); ++i) { - l.Bits()[i] &= r.Bits()[i]; - } - return l; - } - - // Combine two split conditions on the same feature - XGBOOST_DEVICE void Merge(ShapSplitCondition other) { - // Combine duplicate features - if (categories.Capacity() != 0 || other.categories.Capacity() != 0) { - categories = Intersect(categories, other.categories); - } else { - feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound); - feature_upper_bound = min(feature_upper_bound, other.feature_upper_bound); - } - is_missing_branch = is_missing_branch && other.is_missing_branch; - } -}; - -struct PathInfo { - std::size_t length; - // Node index in tree. 
- // -1 if not a leaf (internal split node) - bst_node_t nidx; - bst_tree_t tree_idx; - - [[nodiscard]] XGBOOST_DEVICE bool IsLeaf() const { return nidx != -1; } -}; -static_assert(sizeof(PathInfo) == 16); - -auto MakeTreeSegments(Context const* ctx, bst_tree_t tree_begin, bst_tree_t tree_end, - gbm::GBTreeModel const& model) { - // Copy decision trees to device - auto tree_segments = HostDeviceVector({}, ctx->Device()); - auto& h_tree_segments = tree_segments.HostVector(); - h_tree_segments.reserve((tree_end - tree_begin) + 1); - std::size_t sum = 0; - h_tree_segments.push_back(sum); - for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { - auto const& p_tree = model.trees.at(tree_idx); - CHECK(!p_tree->IsMultiTarget()) << " SHAP " << MTNotImplemented(); - sum += p_tree->Size(); - h_tree_segments.push_back(sum); - } - return tree_segments; -} - -// Transform model into path element form for GPUTreeShap -void ExtractPaths(Context const* ctx, - dh::device_vector>* paths, - gbm::GBTreeModel const& h_model, DeviceModel const& d_model, - dh::device_vector* path_categories) { - curt::SetDevice(ctx->Ordinal()); - - // Path length and tree index for all leaf nodes - dh::caching_device_vector info(d_model.n_nodes); - auto d_trees = d_model.Trees(); // subset of trees - auto tree_segments = MakeTreeSegments(ctx, d_model.tree_begin, d_model.tree_end, h_model); - CHECK_EQ(tree_segments.ConstHostVector().back(), d_model.n_nodes); - auto d_tree_segments = tree_segments.ConstDeviceSpan(); - - auto path_it = dh::MakeIndexTransformIter( - cuda::proclaim_return_type([=] __device__(size_t idx) -> PathInfo { - bst_tree_t const tree_idx = dh::SegmentId(d_tree_segments, idx); - bst_node_t const nidx = idx - d_tree_segments[tree_idx]; - auto const& tree = cuda::std::get(d_trees[tree_idx]); - if (!tree.IsLeaf(nidx) || tree.IsDeleted(nidx)) { - // -1 if it's an internal split node - return PathInfo{0, -1, 0}; - } - // Get the path length for leaf - std::size_t path_length = 
1; - auto iter_nidx = nidx; - while (!tree.IsRoot(iter_nidx)) { - iter_nidx = tree.Parent(iter_nidx); - path_length++; - } - return PathInfo{path_length, nidx, tree_idx}; - })); - auto end = thrust::copy_if( - ctx->CUDACtx()->CTP(), path_it, path_it + d_model.n_nodes, info.begin(), - cuda::proclaim_return_type([=] __device__(PathInfo const& e) { return e.IsLeaf(); })); - - info.resize(end - info.begin()); - using LenT = decltype(std::declval().length); - auto length_iterator = dh::MakeTransformIterator( - info.begin(), cuda::proclaim_return_type( - [=] __device__(PathInfo const& info) { return info.length; })); - dh::caching_device_vector path_segments(info.size() + 1); - thrust::exclusive_scan(ctx->CUDACtx()->CTP(), length_iterator, length_iterator + info.size() + 1, - path_segments.begin()); - - paths->resize(path_segments.back()); - - auto d_paths = dh::ToSpan(*paths); - auto d_info = info.data().get(); - auto d_tree_groups = d_model.tree_groups; - auto d_path_segments = path_segments.data().get(); - - std::size_t max_cat = 0; - if (std::any_of(h_model.trees.cbegin(), h_model.trees.cend(), - [](auto const& p_tree) { return p_tree->HasCategoricalSplit(); })) { - auto max_elem_it = dh::MakeIndexTransformIter([=] __device__(std::size_t i) -> std::size_t { - auto tree_idx = dh::SegmentId(d_tree_segments, i); - auto nidx = i - d_tree_segments[tree_idx]; - return cuda::std::get(d_trees[tree_idx]) - .GetCategoriesMatrix() - .node_ptr[nidx] - .size; - }); - auto max_cat_it = - thrust::max_element(ctx->CUDACtx()->CTP(), max_elem_it, max_elem_it + d_model.n_nodes); - dh::CachingDeviceUVector d_max_cat(1); - auto s_max_cat = dh::ToSpan(d_max_cat); - dh::LaunchN(1, ctx->CUDACtx()->Stream(), - [=] __device__(std::size_t) { s_max_cat[0] = *max_cat_it; }); - dh::safe_cuda( - cudaMemcpy(&max_cat, s_max_cat.data(), s_max_cat.size_bytes(), cudaMemcpyDeviceToHost)); - CHECK_GE(max_cat, 1); - path_categories->resize(max_cat * paths->size()); - } - - common::Span d_path_categories = 
dh::ToSpan(*path_categories); - - dh::LaunchN(info.size(), ctx->CUDACtx()->Stream(), [=] __device__(size_t idx) { - auto path_info = d_info[idx]; - auto tree = cuda::std::get(d_trees[path_info.tree_idx]); - std::int32_t group = d_tree_groups[path_info.tree_idx]; - auto child_nidx = path_info.nidx; - - float v = tree.LeafValue(child_nidx); - const float inf = std::numeric_limits::infinity(); - size_t output_position = d_path_segments[idx + 1] - 1; - - while (!tree.IsRoot(child_nidx)) { - auto parent_nidx = tree.Parent(child_nidx); - double child_cover = tree.SumHess(child_nidx); - double parent_cover = tree.SumHess(parent_nidx); - double zero_fraction = child_cover / parent_cover; - - bool is_left_path = tree.LeftChild(parent_nidx) == child_nidx; - bool is_missing_path = (!tree.DefaultLeft(parent_nidx) && !is_left_path) || - (tree.DefaultLeft(parent_nidx) && is_left_path); - - float lower_bound = -inf; - float upper_bound = inf; - common::CatBitField bits; - if (common::IsCat(tree.cats.split_type, tree.Parent(child_nidx))) { - auto path_cats = d_path_categories.subspan(max_cat * output_position, max_cat); - auto node_cats = tree.NodeCats(tree.Parent(child_nidx)); - SPAN_CHECK(path_cats.size() >= node_cats.size()); - for (size_t i = 0; i < node_cats.size(); ++i) { - path_cats[i] = is_left_path ? ~node_cats[i] : node_cats[i]; - } - bits = common::CatBitField{path_cats}; - } else { - lower_bound = is_left_path ? -inf : tree.SplitCond(parent_nidx); - upper_bound = is_left_path ? 
tree.SplitCond(parent_nidx) : inf; - } - d_paths[output_position--] = gpu_treeshap::PathElement{ - idx, tree.SplitIndex(parent_nidx), - group, ShapSplitCondition{lower_bound, upper_bound, is_missing_path, bits}, - zero_fraction, v}; - - child_nidx = parent_nidx; - } - // Root node has feature -1 - d_paths[output_position] = {idx, -1, group, ShapSplitCondition{-inf, inf, false, {}}, 1.0, v}; - }); -} - namespace { template [[nodiscard]] std::size_t SharedMemoryBytes(std::size_t n_features, std::size_t max_shmem_bytes) { @@ -777,23 +486,6 @@ class ColumnSplitHelper { using cuda_impl::MakeCatAccessor; -template -struct ShapSparsePageLoader { - public: - using SupportShmemLoad = std::false_type; - - SparsePageView data; - EncAccessor acc; - - template - [[nodiscard]] __device__ float GetElement(bst_idx_t ridx, Fidx fidx) const { - auto fvalue = data.GetElement(ridx, fidx); - return acc(fvalue, fidx); - } - [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return data.NumRows(); } - [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return data.NumCols(); } -}; - // Provide configuration for launching the predict kernel. template class LaunchConfig { @@ -888,35 +580,6 @@ class LaunchConfig { } } } - // Used by the SHAP methods. - template - void ForEachBatch(DMatrix* p_fmat, EncAccessor&& acc, Fn&& fn) { - if (p_fmat->PageExists()) { - for (auto& page : p_fmat->GetBatches()) { - // Shap kernel doesn't use shared memory to stage data. 
- SparsePageView batch{ctx_, page, n_features_}; - auto loader = ShapSparsePageLoader{batch, acc}; - fn(std::move(loader), page.base_rowid); - } - } else { - p_fmat->Info().feature_types.SetDevice(ctx_->Device()); - auto feature_types = p_fmat->Info().feature_types.ConstDeviceSpan(); - - for (auto const& page : p_fmat->GetBatches(ctx_, StaticBatch(true))) { - page.Impl()->Visit(ctx_, feature_types, [&](auto&& batch) { - using Acc = std::remove_reference_t; - // No shared memory use for ellpack - auto loader = EllpackLoader{batch, - /*use_shared=*/false, - this->n_features_, - batch.NumRows(), - std::numeric_limits::quiet_NaN(), - std::forward(acc)}; - fn(std::move(loader), batch.base_rowid); - }); - } - } - } }; template @@ -947,20 +610,6 @@ void LaunchPredict(Context const* ctx, bool is_dense, enc::DeviceColumnsView con } } -template -void LaunchShap(Context const* ctx, enc::DeviceColumnsView const& new_enc, - gbm::GBTreeModel const& model, Kernel&& launch) { - if (model.Cats() && model.Cats()->HasCategorical() && new_enc.HasCategorical()) { - auto [acc, mapping] = MakeCatAccessor(ctx, new_enc, model.Cats()); - auto cfg = - LaunchConfig{ctx, model.learner_model_param->num_feature}; - launch(std::move(cfg), std::move(acc)); - } else { - auto cfg = - LaunchConfig{ctx, model.learner_model_param->num_feature}; - launch(std::move(cfg), NoOpAccessor{}); - } -} } // anonymous namespace class GPUPredictor : public xgboost::Predictor { @@ -993,10 +642,11 @@ class GPUPredictor : public xgboost::Predictor { bst_idx_t batch_offset = 0; cfg.ForEachBatch(p_fmat, [&](auto&& loader_t, auto&& batch) { using Loader = typename common::GetValueT; + auto n_rows = batch.NumRows(); cfg.template LaunchPredictKernel(std::move(batch), std::numeric_limits::quiet_NaN(), n_features, d_model, acc, batch_offset, out_preds); - batch_offset += batch.NumRows() * model.learner_model_param->OutputLength(); + batch_offset += n_rows * model.learner_model_param->OutputLength(); }); }); } @@ -1088,61 
+738,11 @@ class GPUPredictor : public xgboost::Predictor { std::vector const* tree_weights, bool approximate, int, unsigned) const override { xgboost_NVTX_FN_RANGE(); - StringView not_implemented{ - "contribution is not implemented in the GPU predictor, use CPU instead."}; if (approximate) { - LOG(FATAL) << "Approximated " << not_implemented; + LOG(FATAL) << "Approximated contribution is not implemented in the GPU predictor, use CPU " + "instead."; } - if (tree_weights != nullptr) { - LOG(FATAL) << "Dart booster feature " << not_implemented; - } - CHECK(!p_fmat->Info().IsColumnSplit()) - << "Predict contribution support for column-wise data split is not yet implemented."; - dh::safe_cuda(cudaSetDevice(ctx_->Ordinal())); - out_contribs->SetDevice(ctx_->Device()); - tree_end = GetTreeLimit(model.trees, tree_end); - - const int ngroup = model.learner_model_param->num_output_group; - CHECK_NE(ngroup, 0); - // allocate space for (number of features + bias) times the number of rows - size_t contributions_columns = model.learner_model_param->num_feature + 1; // +1 for bias - auto dim_size = contributions_columns * model.learner_model_param->num_output_group; - // Output shape: [n_samples, n_classes, n_features + 1] - out_contribs->Resize(p_fmat->Info().num_row_ * dim_size); - out_contribs->Fill(0.0f); - auto phis = out_contribs->DeviceSpan(); - - dh::device_vector> device_paths; - DeviceModel d_model{this->ctx_->Device(), model, true, 0, tree_end, CopyViews{this->ctx_}}; - - auto new_enc = - p_fmat->Cats()->NeedRecode() ? 
p_fmat->Cats()->DeviceView(ctx_) : enc::DeviceColumnsView{}; - - dh::device_vector categories; - ExtractPaths(ctx_, &device_paths, model, d_model, &categories); - - LaunchShap(this->ctx_, new_enc, model, [&](auto&& cfg, auto&& acc) { - using Config = common::GetValueT; - using EncAccessor = typename Config::EncAccessorT; - - cfg.ForEachBatch( - p_fmat, std::forward(acc), [&](auto&& loader, bst_idx_t base_rowid) { - auto begin = dh::tbegin(phis) + base_rowid * dim_size; - gpu_treeshap::GPUTreeShap>( - loader, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis)); - }); - }); - - // Add the base margin term to last column - p_fmat->Info().base_margin_.SetDevice(ctx_->Device()); - const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); - - auto base_score = model.learner_model_param->BaseScore(ctx_); - bst_idx_t n_samples = p_fmat->Info().num_row_; - dh::LaunchN(n_samples * ngroup, ctx_->CUDACtx()->Stream(), [=] __device__(std::size_t idx) { - auto [_, gid] = linalg::UnravelIndex(idx, n_samples, ngroup); - phis[(idx + 1) * contributions_columns - 1] += margin.empty() ? 
base_score(gid) : margin[idx]; - }); + interpretability::ShapValues(ctx_, p_fmat, out_contribs, model, tree_end, tree_weights, 0, 0); } void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector* out_contribs, @@ -1150,61 +750,12 @@ class GPUPredictor : public xgboost::Predictor { std::vector const* tree_weights, bool approximate) const override { xgboost_NVTX_FN_RANGE(); - std::string not_implemented{ - "contribution is not implemented in GPU predictor, use cpu instead."}; if (approximate) { - LOG(FATAL) << "Approximated " << not_implemented; - } - if (tree_weights != nullptr) { - LOG(FATAL) << "Dart booster feature " << not_implemented; + LOG(FATAL) << "Approximated contribution is not implemented in GPU predictor, use cpu " + "instead."; } - dh::safe_cuda(cudaSetDevice(ctx_->Ordinal())); - out_contribs->SetDevice(ctx_->Device()); - tree_end = GetTreeLimit(model.trees, tree_end); - - const int ngroup = model.learner_model_param->num_output_group; - CHECK_NE(ngroup, 0); - // allocate space for (number of features + bias) times the number of rows - size_t contributions_columns = model.learner_model_param->num_feature + 1; // +1 for bias - auto dim_size = - contributions_columns * contributions_columns * model.learner_model_param->num_output_group; - out_contribs->Resize(p_fmat->Info().num_row_ * dim_size); - out_contribs->Fill(0.0f); - auto phis = out_contribs->DeviceSpan(); - - dh::device_vector> device_paths; - DeviceModel d_model{this->ctx_->Device(), model, true, 0, tree_end, CopyViews{this->ctx_}}; - - dh::device_vector categories; - ExtractPaths(ctx_, &device_paths, model, d_model, &categories); - auto new_enc = - p_fmat->Cats()->NeedRecode() ? 
p_fmat->Cats()->DeviceView(ctx_) : enc::DeviceColumnsView{}; - - LaunchShap(this->ctx_, new_enc, model, [&](auto&& cfg, auto&& acc) { - using Config = common::GetValueT; - using EncAccessor = typename Config::EncAccessorT; - - cfg.ForEachBatch( - p_fmat, std::forward(acc), [&](auto&& loader, bst_idx_t base_rowid) { - auto begin = dh::tbegin(phis) + base_rowid * dim_size; - gpu_treeshap::GPUTreeShapInteractions>( - loader, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis)); - }); - }); - - // Add the base margin term to last column - p_fmat->Info().base_margin_.SetDevice(ctx_->Device()); - const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan(); - - auto base_score = model.learner_model_param->BaseScore(ctx_); - size_t n_features = model.learner_model_param->num_feature; - bst_idx_t n_samples = p_fmat->Info().num_row_; - dh::LaunchN(n_samples * ngroup, ctx_->CUDACtx()->Stream(), [=] __device__(size_t idx) { - auto [ridx, gidx] = linalg::UnravelIndex(idx, n_samples, ngroup); - phis[gpu_treeshap::IndexPhiInteractions(ridx, ngroup, gidx, n_features, n_features, - n_features)] += - margin.empty() ? 
base_score(gidx) : margin[idx]; - }); + interpretability::ShapInteractionValues(ctx_, p_fmat, out_contribs, model, tree_end, + tree_weights, approximate); } void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* predictions, @@ -1234,6 +785,7 @@ class GPUPredictor : public xgboost::Predictor { cfg.ForEachBatch(p_fmat, [&](auto&& loader_t, auto&& batch) { using Loader = typename common::GetValueT; using Config = common::GetValueT; + auto n_rows = batch.NumRows(); auto kernel = PredictLeafKernel, Config::HasMissing(), typename Config::EncAccessorT>; cfg.template Launch(kernel, std::move(batch), d_model.Trees(), @@ -1242,7 +794,7 @@ class GPUPredictor : public xgboost::Predictor { cfg.UseShared(), std::numeric_limits::quiet_NaN(), std::forward(acc)); - batch_offset += batch.NumRows(); + batch_offset += n_rows; }); }); } diff --git a/src/predictor/interpretability/shap.cc b/src/predictor/interpretability/shap.cc new file mode 100644 index 000000000000..025b47f2abd7 --- /dev/null +++ b/src/predictor/interpretability/shap.cc @@ -0,0 +1,299 @@ +/** + * Copyright 2017-2026, XGBoost Contributors + */ +#include "shap.h" + +#include // for fill +#include // for remove_const_t +#include // for vector + +#include "../../common/threading_utils.h" // for ParallelFor +#include "../../gbm/gbtree_model.h" // for GBTreeModel +#include "../../tree/tree_view.h" // for ScalarTreeView +#include "../data_accessor.h" // for GHistIndexMatrixView +#include "../predict_fn.h" // for GetTreeLimit +#include "../treeshap.h" // for CalculateContributions +#include "dmlc/omp.h" // for omp_get_thread_num +#include "xgboost/base.h" // for bst_omp_uint +#include "xgboost/logging.h" // for CHECK +#include "xgboost/multi_target_tree_model.h" // for MTNotImplemented + +namespace xgboost::interpretability { +namespace { +void ValidateTreeWeights(std::vector const *tree_weights, bst_tree_t tree_end) { + if (tree_weights == nullptr) { + return; + } + CHECK_GE(tree_weights->size(), static_cast(tree_end)); +} + 
+float FillNodeMeanValues(tree::ScalarTreeView const &tree, bst_node_t nidx, + std::vector *mean_values) { + float result; + auto &node_mean_values = *mean_values; + if (tree.IsLeaf(nidx)) { + result = tree.LeafValue(nidx); + } else { + result = FillNodeMeanValues(tree, tree.LeftChild(nidx), mean_values) * + tree.Stat(tree.LeftChild(nidx)).sum_hess; + result += FillNodeMeanValues(tree, tree.RightChild(nidx), mean_values) * + tree.Stat(tree.RightChild(nidx)).sum_hess; + result /= tree.Stat(nidx).sum_hess; + } + node_mean_values[nidx] = result; + return result; +} + +void FillNodeMeanValues(tree::ScalarTreeView const &tree, std::vector *mean_values) { + auto n_nodes = tree.Size(); + if (static_cast(mean_values->size()) == n_nodes) { + return; + } + mean_values->resize(n_nodes); + FillNodeMeanValues(tree, 0, mean_values); +} + +void CalculateApproxContributions(tree::ScalarTreeView const &tree, RegTree::FVec const &feats, + std::vector *mean_values, + std::vector *out_contribs) { + CHECK_EQ(out_contribs->size(), feats.Size() + 1); + CalculateContributionsApprox(tree, feats, mean_values, out_contribs->data()); +} + +template +void DispatchByBatchView(Context const *ctx, DMatrix *p_fmat, EncAccessor acc, Fn &&fn) { + using AccT = std::decay_t; + if (p_fmat->PageExists()) { + for (auto const &page : p_fmat->GetBatches()) { + predictor::SparsePageView view{page.GetView(), page.base_rowid, acc}; + fn(view); + } + } else { + auto ft = p_fmat->Info().feature_types.ConstHostVector(); + for (auto const &page : p_fmat->GetBatches(ctx, {})) { + predictor::GHistIndexMatrixView view{page, acc, ft}; + fn(view); + } + } +} + +template +void LaunchShap(Context const *ctx, DMatrix *p_fmat, gbm::GBTreeModel const &model, Fn &&fn) { + if (model.Cats()->HasCategorical() && p_fmat->Cats()->NeedRecode()) { + auto new_enc = p_fmat->Cats()->HostView(); + auto [acc, mapping] = ::xgboost::cpu_impl::MakeCatAccessor(ctx, new_enc, model.Cats()); + DispatchByBatchView(ctx, p_fmat, acc, fn); + } 
else { + DispatchByBatchView(ctx, p_fmat, NoOpAccessor{}, fn); + } +} +} // namespace + +namespace cpu_impl { +void ShapValues(Context const *ctx, DMatrix *p_fmat, HostDeviceVector *out_contribs, + gbm::GBTreeModel const &model, bst_tree_t tree_end, + std::vector const *tree_weights, int condition, unsigned condition_feature) { + CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict contribution" << MTNotImplemented(); + CHECK(!p_fmat->Info().IsColumnSplit()) + << "Predict contribution support for column-wise data split is not yet implemented."; + MetaInfo const &info = p_fmat->Info(); + // number of valid trees + tree_end = predictor::GetTreeLimit(model.trees, tree_end); + CHECK_GE(tree_end, 0); + ValidateTreeWeights(tree_weights, tree_end); + auto const n_trees = static_cast(tree_end); + auto const n_threads = ctx->Threads(); + size_t const ncolumns = model.learner_model_param->num_feature + 1; + // allocate space for (number of features + bias) times the number of rows + std::vector &contribs = out_contribs->HostVector(); + contribs.resize(info.num_row_ * ncolumns * model.learner_model_param->num_output_group); + // make sure contributions is zeroed, we could be reusing a previously allocated one + std::fill(contribs.begin(), contribs.end(), 0); + // initialize tree node mean values + std::vector> mean_values(n_trees); + common::ParallelFor(n_trees, n_threads, [&](auto i) { + FillNodeMeanValues(model.trees[i]->HostScView(), &(mean_values[i])); + }); + + auto const n_groups = model.learner_model_param->num_output_group; + CHECK_NE(n_groups, 0); + auto const base_score = model.learner_model_param->BaseScore(DeviceOrd::CPU()); + auto const h_tree_groups = model.TreeGroups(DeviceOrd::CPU()); + std::vector feats_tloc(n_threads); + std::vector> contribs_tloc(n_threads, std::vector(ncolumns)); + + auto device = ctx->Device().IsSycl() ? 
DeviceOrd::CPU() : ctx->Device(); + auto base_margin = info.base_margin_.View(device); + + auto process_view = [&](auto &&view) { + common::ParallelFor(view.Size(), n_threads, [&](auto i) { + auto tid = omp_get_thread_num(); + auto &feats = feats_tloc[tid]; + if (feats.Size() == 0) { + feats.Init(model.learner_model_param->num_feature); + } + auto &this_tree_contribs = contribs_tloc[tid]; + auto row_idx = view.base_rowid + i; + auto n_valid = view.DoFill(i, feats.Data().data()); + feats.HasMissing(n_valid != feats.Size()); + for (bst_target_t gid = 0; gid < n_groups; ++gid) { + float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; + for (bst_tree_t j = 0; j < tree_end; ++j) { + if (h_tree_groups[j] != gid) { + continue; + } + std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); + auto const sc_tree = model.trees[j]->HostScView(); + CalculateContributions(sc_tree, feats, &mean_values[j], this_tree_contribs.data(), + condition, condition_feature); + for (size_t ci = 0; ci < ncolumns; ++ci) { + p_contribs[ci] += + this_tree_contribs[ci] * (tree_weights == nullptr ? 
1 : (*tree_weights)[j]); + } + } + if (base_margin.Size() != 0) { + CHECK_EQ(base_margin.Shape(1), n_groups); + p_contribs[ncolumns - 1] += base_margin(row_idx, gid); + } else { + p_contribs[ncolumns - 1] += base_score(gid); + } + } + feats.Drop(); + }); + }; + + LaunchShap(ctx, p_fmat, model, process_view); +} + +void ApproxFeatureImportance(Context const *ctx, DMatrix *p_fmat, + HostDeviceVector *out_contribs, gbm::GBTreeModel const &model, + bst_tree_t tree_end, std::vector const *tree_weights) { + CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict contribution" << MTNotImplemented(); + CHECK(!p_fmat->Info().IsColumnSplit()) + << "Predict contribution support for column-wise data split is not yet implemented."; + MetaInfo const &info = p_fmat->Info(); + tree_end = predictor::GetTreeLimit(model.trees, tree_end); + CHECK_GE(tree_end, 0); + ValidateTreeWeights(tree_weights, tree_end); + auto const n_trees = static_cast(tree_end); + auto const n_threads = ctx->Threads(); + size_t const ncolumns = model.learner_model_param->num_feature + 1; + std::vector &contribs = out_contribs->HostVector(); + contribs.resize(info.num_row_ * ncolumns * model.learner_model_param->num_output_group); + std::fill(contribs.begin(), contribs.end(), 0); + std::vector> mean_values(n_trees); + common::ParallelFor(n_trees, n_threads, [&](auto i) { + FillNodeMeanValues(model.trees[i]->HostScView(), &(mean_values[i])); + }); + + auto const n_groups = model.learner_model_param->num_output_group; + CHECK_NE(n_groups, 0); + auto const base_score = model.learner_model_param->BaseScore(DeviceOrd::CPU()); + auto const h_tree_groups = model.TreeGroups(DeviceOrd::CPU()); + std::vector feats_tloc(n_threads); + std::vector> contribs_tloc(n_threads, std::vector(ncolumns)); + + auto device = ctx->Device().IsSycl() ? 
DeviceOrd::CPU() : ctx->Device(); + auto base_margin = info.base_margin_.View(device); + + auto process_view = [&](auto &&view) { + common::ParallelFor(view.Size(), n_threads, [&](auto i) { + auto tid = omp_get_thread_num(); + auto &feats = feats_tloc[tid]; + if (feats.Size() == 0) { + feats.Init(model.learner_model_param->num_feature); + } + auto &this_tree_contribs = contribs_tloc[tid]; + auto row_idx = view.base_rowid + i; + auto n_valid = view.DoFill(i, feats.Data().data()); + feats.HasMissing(n_valid != feats.Size()); + for (bst_target_t gid = 0; gid < n_groups; ++gid) { + float *p_contribs = &contribs[(row_idx * n_groups + gid) * ncolumns]; + for (bst_tree_t j = 0; j < tree_end; ++j) { + if (h_tree_groups[j] != gid) { + continue; + } + std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); + auto const sc_tree = model.trees[j]->HostScView(); + CalculateApproxContributions(sc_tree, feats, &mean_values[j], &this_tree_contribs); + for (size_t ci = 0; ci < ncolumns; ++ci) { + p_contribs[ci] += + this_tree_contribs[ci] * (tree_weights == nullptr ? 
/**
 * @brief CPU implementation of SHAP interaction values.
 *
 * For every sample and output group, produces a dense (num_feature+1) x
 * (num_feature+1) matrix: the diagonal holds the additive (main) effects, the
 * off-diagonal entries hold pairwise feature interactions, and the final
 * row/column corresponds to the bias term.
 *
 * @param out_contribs  Resized to num_row * ngroup * (ncolumns+1)^2.
 * @param tree_end      Upper bound (exclusive) on trees to use; 0 means all.
 * @param tree_weights  Optional per-tree weights (dart); may be nullptr.
 * @param approximate   Use the fast approximation instead of exact TreeSHAP.
 */
void ShapInteractionValues(Context const *ctx, DMatrix *p_fmat,
                           HostDeviceVector<float> *out_contribs, gbm::GBTreeModel const &model,
                           bst_tree_t tree_end, std::vector<float> const *tree_weights,
                           bool approximate) {
  CHECK(!model.learner_model_param->IsVectorLeaf())
      << "Predict interaction contribution" << MTNotImplemented();
  CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict interaction contribution support for "
                                            "column-wise data split is not yet implemented.";
  MetaInfo const &info = p_fmat->Info();
  auto const ngroup = model.learner_model_param->num_output_group;
  auto const ncolumns = model.learner_model_param->num_feature;
  // Strides into the flat buffers:
  //   row_chunk  - one sample's stack of per-group interaction matrices,
  //   mrow_chunk - a single (ncolumns+1)^2 interaction matrix,
  //   crow_chunk - one sample's per-group contribution vectors.
  const unsigned row_chunk = ngroup * (ncolumns + 1) * (ncolumns + 1);
  const unsigned mrow_chunk = (ncolumns + 1) * (ncolumns + 1);
  const unsigned crow_chunk = ngroup * (ncolumns + 1);

  // allocate space for (number of features^2) times the number of rows and tmp off/on contribs
  std::vector<float> &contribs = out_contribs->HostVector();
  contribs.resize(info.num_row_ * ngroup * (ncolumns + 1) * (ncolumns + 1));
  HostDeviceVector<float> contribs_off_hdv(info.num_row_ * ngroup * (ncolumns + 1));
  auto &contribs_off = contribs_off_hdv.HostVector();
  HostDeviceVector<float> contribs_on_hdv(info.num_row_ * ngroup * (ncolumns + 1));
  auto &contribs_on = contribs_on_hdv.HostVector();
  HostDeviceVector<float> contribs_diag_hdv(info.num_row_ * ngroup * (ncolumns + 1));
  auto &contribs_diag = contribs_diag_hdv.HostVector();

  // Compute the difference in effects when conditioning on each of the features on and off
  // see: Axiomatic characterizations of probabilistic and
  // cardinal-probabilistic interaction indices
  if (approximate) {
    ApproxFeatureImportance(ctx, p_fmat, &contribs_diag_hdv, model, tree_end, tree_weights);
  } else {
    // condition == 0: unconditioned SHAP values, used for the diagonal.
    ShapValues(ctx, p_fmat, &contribs_diag_hdv, model, tree_end, tree_weights, 0, 0);
  }
  for (size_t i = 0; i < ncolumns + 1; ++i) {
    if (approximate) {
      // NOTE(review): the approximate path takes no conditioning arguments, so
      // "off" and "on" runs are identical and every off-diagonal term below is
      // zero; all mass lands on the diagonal. Confirm this matches the
      // documented semantics of approximate interaction contributions.
      ApproxFeatureImportance(ctx, p_fmat, &contribs_off_hdv, model, tree_end, tree_weights);
      ApproxFeatureImportance(ctx, p_fmat, &contribs_on_hdv, model, tree_end, tree_weights);
    } else {
      // Condition feature i to be absent (-1) and present (+1) respectively.
      ShapValues(ctx, p_fmat, &contribs_off_hdv, model, tree_end, tree_weights, -1, i);
      ShapValues(ctx, p_fmat, &contribs_on_hdv, model, tree_end, tree_weights, 1, i);
    }

    for (size_t j = 0; j < info.num_row_; ++j) {
      for (std::remove_const_t<decltype(ngroup)> l = 0; l < ngroup; ++l) {
        const unsigned o_offset = j * row_chunk + l * mrow_chunk + i * (ncolumns + 1);
        const unsigned c_offset = j * crow_chunk + l * (ncolumns + 1);
        contribs[o_offset + i] = 0;
        for (size_t k = 0; k < ncolumns + 1; ++k) {
          // fill in the diagonal with additive effects, and off-diagonal with the interactions
          if (k == i) {
            contribs[o_offset + i] += contribs_diag[c_offset + k];
          } else {
            // Interaction is half the difference between the conditioned runs;
            // the symmetric half lives in the transposed entry of iteration k.
            contribs[o_offset + k] = (contribs_on[c_offset + k] - contribs_off[c_offset + k]) / 2.0;
            // Keep row sums consistent: subtract interactions from the main effect.
            contribs[o_offset + i] -= contribs[o_offset + k];
          }
        }
      }
    }
  }
}
}  // namespace cpu_impl
}  // namespace xgboost::interpretability
using predictor::EllpackLoader;
using predictor::GBTreeModelView;
using predictor::SparsePageLoaderNoShared;
using predictor::SparsePageView;
using ::xgboost::cuda_impl::StaticBatch;

// Variant over device-side tree views. Scalar trees only: multi-target trees
// are rejected up-front by MakeTreeSegments. (Angle-bracket contents were lost
// in extraction — TODO(review): confirm the exact alternative list.)
using TreeViewVar = cuda::std::variant<tree::ScalarTreeView>;

/**
 * @brief Functor handed to GBTreeModelView for copying host tree views into
 *        device memory on the context's CUDA stream.
 */
struct CopyViews {
  Context const* ctx;
  explicit CopyViews(Context const* ctx) : ctx{ctx} {}

  // Async H2D copy; the destination is sized to match the source first.
  void operator()(dh::DeviceUVector<TreeViewVar>* p_dst, std::vector<TreeViewVar>&& src) {
    xgboost_NVTX_FN_RANGE();
    p_dst->resize(src.size());
    auto d_dst = dh::ToSpan(*p_dst);
    dh::safe_cuda(cudaMemcpyAsync(d_dst.data(), src.data(), d_dst.size_bytes(), cudaMemcpyDefault,
                                  ctx->CUDACtx()->Stream()));
  }
};

using DeviceModel = GBTreeModelView;

/**
 * @brief Split condition attached to each gpu_treeshap path element.
 *
 * A feature value satisfies the condition when it falls inside
 * [feature_lower_bound, feature_upper_bound) for numerical splits, or when its
 * category bit is set for categorical splits. Missing values follow
 * is_missing_branch.
 */
struct ShapSplitCondition {
  ShapSplitCondition() = default;
  XGBOOST_DEVICE
  ShapSplitCondition(float feature_lower_bound, float feature_upper_bound, bool is_missing_branch,
                     common::CatBitField cats)
      : feature_lower_bound(feature_lower_bound),
        feature_upper_bound(feature_upper_bound),
        is_missing_branch(is_missing_branch),
        categories{std::move(cats)} {
    assert(feature_lower_bound <= feature_upper_bound);
  }

  // Numerical bounds; values in [lower, upper) flow down this path.
  float feature_lower_bound;
  float feature_upper_bound;
  // Non-empty for categorical splits; matching categories flow down this path.
  common::CatBitField categories;
  // Whether missing values flow down this path.
  bool is_missing_branch;

  // Evaluate whether feature value x flows down this path.
  [[nodiscard]] XGBOOST_DEVICE bool EvaluateSplit(float x) const {
    if (isnan(x)) {
      return is_missing_branch;
    }
    if (categories.Capacity() != 0) {
      auto cat = static_cast<uint32_t>(x);
      return categories.Check(cat);
    } else {
      return x >= feature_lower_bound && x < feature_upper_bound;
    }
  }

  // In-place intersection of two category bitfields; the result is written into
  // the one with smaller capacity (they may alias the same storage).
  XGBOOST_DEVICE static common::CatBitField Intersect(common::CatBitField l,
                                                      common::CatBitField r) {
    if (l.Data() == r.Data()) {
      return l;
    }
    if (l.Capacity() > r.Capacity()) {
      cuda::std::swap(l, r);
    }
    auto l_bits = l.Bits();
    auto r_bits = r.Bits();
    auto n_bits = l_bits.size() < r_bits.size() ? l_bits.size() : r_bits.size();
    for (size_t i = 0; i < n_bits; ++i) {
      l_bits[i] &= r_bits[i];
    }
    return l;
  }

  // Combine two conditions on the same feature encountered along one path.
  XGBOOST_DEVICE void Merge(ShapSplitCondition other) {
    if (categories.Capacity() != 0 || other.categories.Capacity() != 0) {
      categories = Intersect(categories, other.categories);
    } else {
      feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);
      feature_upper_bound = min(feature_upper_bound, other.feature_upper_bound);
    }
    // The merged path accepts missing values only if both conditions do.
    is_missing_branch = is_missing_branch && other.is_missing_branch;
  }
};

/**
 * @brief Per-leaf metadata gathered in ExtractPaths: the root-to-leaf path
 *        length, the leaf node index, and the owning tree.
 */
struct PathInfo {
  std::size_t length;   // path length from root to this leaf, inclusive
  bst_node_t nidx;      // leaf index; -1 marks a non-leaf placeholder
  bst_tree_t tree_idx;  // index of the tree containing the leaf

  // Placeholders produced for internal/deleted nodes carry nidx == -1.
  [[nodiscard]] XGBOOST_DEVICE bool IsLeaf() const { return nidx != -1; }
};
// Keep the struct compact: it is streamed through thrust::copy_if in bulk.
static_assert(sizeof(PathInfo) == 16);
/**
 * @brief Convert every root-to-leaf path of the model into a flat array of
 *        gpu_treeshap path elements, ordered root-first per leaf.
 *
 * @param paths            [out] One PathElement per edge, plus a synthetic root
 *                         element (feature index -1) per leaf.
 * @param h_model          Host model, used for segment sizes and the
 *                         has-categorical-split check.
 * @param d_model          Device view of the same model.
 * @param path_categories  [out] Backing storage for per-path category bitsets;
 *                         resized only when the model has categorical splits.
 */
void ExtractPaths(Context const* ctx,
                  dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>>* paths,
                  gbm::GBTreeModel const& h_model, DeviceModel const& d_model,
                  dh::device_vector<uint32_t>* path_categories) {
  curt::SetDevice(ctx->Ordinal());

  dh::caching_device_vector<PathInfo> info(d_model.n_nodes);
  auto d_trees = d_model.Trees();
  auto tree_segments = MakeTreeSegments(ctx, d_model.tree_begin, d_model.tree_end, h_model);
  CHECK_EQ(tree_segments.ConstHostVector().back(), d_model.n_nodes);
  auto d_tree_segments = tree_segments.ConstDeviceSpan();

  // Pass 1: for every node in every tree, emit a PathInfo. Non-leaf and
  // deleted nodes produce a placeholder (nidx == -1) filtered out below.
  auto path_it = dh::MakeIndexTransformIter(
      cuda::proclaim_return_type<PathInfo>([=] __device__(size_t idx) -> PathInfo {
        bst_tree_t const tree_idx = dh::SegmentId(d_tree_segments, idx);
        bst_node_t const nidx = idx - d_tree_segments[tree_idx];
        auto const& tree = cuda::std::get<tree::ScalarTreeView>(d_trees[tree_idx]);
        if (!tree.IsLeaf(nidx) || tree.IsDeleted(nidx)) {
          return PathInfo{0, -1, 0};
        }
        // Walk up to the root to measure the path length (leaf inclusive).
        std::size_t path_length = 1;
        auto iter_nidx = nidx;
        while (!tree.IsRoot(iter_nidx)) {
          iter_nidx = tree.Parent(iter_nidx);
          path_length++;
        }
        return PathInfo{path_length, nidx, tree_idx};
      }));
  auto end = thrust::copy_if(
      ctx->CUDACtx()->CTP(), path_it, path_it + d_model.n_nodes, info.begin(),
      cuda::proclaim_return_type<bool>([=] __device__(PathInfo const& e) { return e.IsLeaf(); }));

  info.resize(end - info.begin());
  // Prefix-sum the path lengths to obtain each leaf's output segment.
  using LenT = decltype(std::declval<PathInfo>().length);
  auto length_iterator = dh::MakeTransformIterator<LenT>(
      info.begin(), cuda::proclaim_return_type<LenT>(
                        [=] __device__(PathInfo const& info) { return info.length; }));
  dh::caching_device_vector<std::size_t> path_segments(info.size() + 1);
  thrust::fill_n(ctx->CUDACtx()->CTP(), path_segments.begin(), 1, std::size_t{0});
  thrust::inclusive_scan(ctx->CUDACtx()->CTP(), length_iterator, length_iterator + info.size(),
                         path_segments.begin() + 1);

  paths->resize(path_segments.back());

  auto d_paths = dh::ToSpan(*paths);
  auto d_info = info.data().get();
  auto d_tree_groups = d_model.tree_groups;
  auto d_path_segments = path_segments.data().get();

  // When the model has categorical splits, reserve one fixed-width category
  // bitset slot per path element, sized for the widest node bitset.
  std::size_t max_cat = 0;
  if (std::any_of(h_model.trees.cbegin(), h_model.trees.cend(),
                  [](auto const& p_tree) { return p_tree->HasCategoricalSplit(); })) {
    auto max_elem_it = dh::MakeIndexTransformIter([=] __device__(std::size_t i) -> std::size_t {
      auto tree_idx = dh::SegmentId(d_tree_segments, i);
      auto nidx = i - d_tree_segments[tree_idx];
      return cuda::std::get<tree::ScalarTreeView>(d_trees[tree_idx])
          .GetCategoriesMatrix()
          .node_ptr[nidx]
          .size;
    });
    auto max_cat_it =
        thrust::max_element(ctx->CUDACtx()->CTP(), max_elem_it, max_elem_it + d_model.n_nodes);
    // Dereference the device iterator inside a kernel, then copy to host.
    dh::CachingDeviceUVector<std::size_t> d_max_cat(1);
    auto s_max_cat = dh::ToSpan(d_max_cat);
    dh::LaunchN(1, ctx->CUDACtx()->Stream(),
                [=] __device__(std::size_t) { s_max_cat[0] = *max_cat_it; });
    dh::safe_cuda(
        cudaMemcpy(&max_cat, s_max_cat.data(), s_max_cat.size_bytes(), cudaMemcpyDeviceToHost));
    CHECK_GE(max_cat, 1);
    path_categories->resize(max_cat * paths->size());
  }

  common::Span<uint32_t> d_path_categories = dh::ToSpan(*path_categories);

  // Pass 2: one thread per leaf, unwinding from the leaf to the root and
  // writing path elements back-to-front into the leaf's segment.
  dh::LaunchN(info.size(), ctx->CUDACtx()->Stream(), [=] __device__(size_t idx) {
    auto path_info = d_info[idx];
    auto tree = cuda::std::get<tree::ScalarTreeView>(d_trees[path_info.tree_idx]);
    std::int32_t group = d_tree_groups[path_info.tree_idx];
    auto child_nidx = path_info.nidx;

    float v = tree.LeafValue(child_nidx);
    const float inf = std::numeric_limits<float>::infinity();
    size_t output_position = d_path_segments[idx + 1] - 1;

    while (!tree.IsRoot(child_nidx)) {
      auto parent_nidx = tree.Parent(child_nidx);
      // Fraction of training instances that follow this edge ("cover" ratio).
      double child_cover = tree.SumHess(child_nidx);
      double parent_cover = tree.SumHess(parent_nidx);
      double zero_fraction = child_cover / parent_cover;

      bool is_left_path = tree.LeftChild(parent_nidx) == child_nidx;
      // Missing values follow this edge iff the default direction matches it.
      bool is_missing_path = (!tree.DefaultLeft(parent_nidx) && !is_left_path) ||
                             (tree.DefaultLeft(parent_nidx) && is_left_path);

      float lower_bound = -inf;
      float upper_bound = inf;
      common::CatBitField bits;
      if (common::IsCat(tree.cats.split_type, tree.Parent(child_nidx))) {
        // The left branch takes the complement of the node's category set.
        auto path_cats = d_path_categories.subspan(max_cat * output_position, max_cat);
        auto node_cats = tree.NodeCats(tree.Parent(child_nidx));
        SPAN_CHECK(path_cats.size() >= node_cats.size());
        for (size_t i = 0; i < node_cats.size(); ++i) {
          path_cats[i] = is_left_path ? ~node_cats[i] : node_cats[i];
        }
        bits = common::CatBitField{path_cats};
      } else {
        lower_bound = is_left_path ? -inf : tree.SplitCond(parent_nidx);
        upper_bound = is_left_path ? tree.SplitCond(parent_nidx) : inf;
      }
      d_paths[output_position--] = gpu_treeshap::PathElement<ShapSplitCondition>{
          idx,           tree.SplitIndex(parent_nidx),
          group,         ShapSplitCondition{lower_bound, upper_bound, is_missing_path, bits},
          zero_fraction, v};

      child_nidx = parent_nidx;
    }
    // Synthetic root element: feature -1, unconstrained, zero_fraction 1.
    d_paths[output_position] = {idx, -1, group, ShapSplitCondition{-inf, inf, false, {}}, 1.0, v};
  });
}

/**
 * @brief Iterate over the data in whichever page format the DMatrix holds
 *        (in-core SparsePage or compressed EllpackPage) and invoke @p fn with a
 *        device data loader plus the batch's base row index.
 */
template <typename EncAccessor, typename Fn>
void DispatchByBatchLoader(Context const* ctx, DMatrix* p_fmat, bst_feature_t n_features,
                           EncAccessor acc, Fn&& fn) {
  using AccT = std::decay_t<EncAccessor>;
  if (p_fmat->PageExists<SparsePage>()) {
    for (auto& page : p_fmat->GetBatches<SparsePage>()) {
      SparsePageView batch{ctx, page, n_features};
      auto loader = SparsePageLoaderNoShared{batch, acc};
      fn(std::move(loader), page.base_rowid);
    }
  } else {
    p_fmat->Info().feature_types.SetDevice(ctx->Device());
    auto feature_types = p_fmat->Info().feature_types.ConstDeviceSpan();

    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx, StaticBatch(true))) {
      page.Impl()->Visit(ctx, feature_types, [&](auto&& batch) {
        using BatchT = std::remove_reference_t<decltype(batch)>;
        // TODO(review): loader template arguments were stripped during
        // extraction; confirm EllpackLoader's parameter list.
        auto loader = EllpackLoader<BatchT, AccT>{batch,
                                                  /*use_shared=*/false,
                                                  n_features,
                                                  batch.NumRows(),
                                                  std::numeric_limits<float>::quiet_NaN(),
                                                  AccT{acc}};
        fn(std::move(loader), batch.base_rowid);
      });
    }
  }
}
/**
 * @brief GPU implementation of (unconditional) SHAP values via GPUTreeShap.
 *
 * NOTE(review): the trailing `int` / `unsigned` parameters mirror the CPU
 * signature's (condition, condition_feature) but are deliberately unnamed and
 * ignored — conditional SHAP is not supported on GPU. The inline dispatcher in
 * shap.h forwards caller-supplied values here; confirm callers never request a
 * non-zero condition on a CUDA context.
 */
void ShapValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
                gbm::GBTreeModel const& model, bst_tree_t tree_end,
                std::vector<float> const* tree_weights, int, unsigned) {
  xgboost_NVTX_FN_RANGE();
  StringView not_implemented{
      "contribution is not implemented in the GPU predictor, use CPU instead."};
  if (tree_weights != nullptr) {
    LOG(FATAL) << "Dart booster feature " << not_implemented;
  }
  CHECK(!p_fmat->Info().IsColumnSplit())
      << "Predict contribution support for column-wise data split is not yet implemented.";
  dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
  out_contribs->SetDevice(ctx->Device());
  tree_end = predictor::GetTreeLimit(model.trees, tree_end);

  const int ngroup = model.learner_model_param->num_output_group;
  CHECK_NE(ngroup, 0);
  // One extra column per group for the bias term.
  size_t contributions_columns = model.learner_model_param->num_feature + 1;
  auto dim_size = contributions_columns * model.learner_model_param->num_output_group;
  out_contribs->Resize(p_fmat->Info().num_row_ * dim_size);
  out_contribs->Fill(0.0f);
  auto phis = out_contribs->DeviceSpan();

  dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>> device_paths;
  DeviceModel d_model{ctx->Device(), model, true, 0, tree_end, CopyViews{ctx}};

  // NOTE(review): assumes p_fmat->Cats() is never null here — confirm.
  auto new_enc =
      p_fmat->Cats()->NeedRecode() ? p_fmat->Cats()->DeviceView(ctx) : enc::DeviceColumnsView{};

  dh::device_vector<uint32_t> categories;
  ExtractPaths(ctx, &device_paths, model, d_model, &categories);

  // Run GPUTreeShap batch-by-batch; each batch writes into its row range.
  LaunchShap(ctx, p_fmat, new_enc, model, [&](auto&& loader, bst_idx_t base_rowid) {
    auto begin = dh::tbegin(phis) + base_rowid * dim_size;
    gpu_treeshap::GPUTreeShap<dh::XGBDeviceAllocator<int>>(
        loader, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis));
  });

  p_fmat->Info().base_margin_.SetDevice(ctx->Device());
  const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan();

  // Add the bias: base margin when supplied, intercept otherwise.
  auto base_score = model.learner_model_param->BaseScore(ctx);
  bst_idx_t n_samples = p_fmat->Info().num_row_;
  dh::LaunchN(n_samples * ngroup, ctx->CUDACtx()->Stream(), [=] __device__(std::size_t idx) {
    auto [_, gid] = linalg::UnravelIndex(idx, n_samples, ngroup);
    phis[(idx + 1) * contributions_columns - 1] += margin.empty() ? base_score(gid) : margin[idx];
  });
}

/**
 * @brief GPU implementation of SHAP interaction values via
 *        GPUTreeShapInteractions. Approximate mode and dart tree weights are
 *        not supported on GPU and abort with LOG(FATAL).
 */
void ShapInteractionValues(Context const* ctx, DMatrix* p_fmat,
                           HostDeviceVector<float>* out_contribs, gbm::GBTreeModel const& model,
                           bst_tree_t tree_end, std::vector<float> const* tree_weights,
                           bool approximate) {
  xgboost_NVTX_FN_RANGE();
  // TODO(review): unify wording/capitalisation with the message in ShapValues.
  std::string not_implemented{"contribution is not implemented in GPU predictor, use cpu instead."};
  if (approximate) {
    LOG(FATAL) << "Approximated " << not_implemented;
  }
  if (tree_weights != nullptr) {
    LOG(FATAL) << "Dart booster feature " << not_implemented;
  }
  dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
  out_contribs->SetDevice(ctx->Device());
  tree_end = predictor::GetTreeLimit(model.trees, tree_end);

  const int ngroup = model.learner_model_param->num_output_group;
  CHECK_NE(ngroup, 0);
  // Output is a (num_feature+1)^2 matrix per row per group.
  size_t contributions_columns = model.learner_model_param->num_feature + 1;
  auto dim_size =
      contributions_columns * contributions_columns * model.learner_model_param->num_output_group;
  out_contribs->Resize(p_fmat->Info().num_row_ * dim_size);
  out_contribs->Fill(0.0f);
  auto phis = out_contribs->DeviceSpan();

  dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>> device_paths;
  DeviceModel d_model{ctx->Device(), model, true, 0, tree_end, CopyViews{ctx}};

  dh::device_vector<uint32_t> categories;
  ExtractPaths(ctx, &device_paths, model, d_model, &categories);
  auto new_enc =
      p_fmat->Cats()->NeedRecode() ? p_fmat->Cats()->DeviceView(ctx) : enc::DeviceColumnsView{};

  LaunchShap(ctx, p_fmat, new_enc, model, [&](auto&& loader, bst_idx_t base_rowid) {
    auto begin = dh::tbegin(phis) + base_rowid * dim_size;
    gpu_treeshap::GPUTreeShapInteractions<dh::XGBDeviceAllocator<int>>(
        loader, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis));
  });

  p_fmat->Info().base_margin_.SetDevice(ctx->Device());
  const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan();

  // Add the bias to the (bias, bias) cell of each interaction matrix.
  auto base_score = model.learner_model_param->BaseScore(ctx);
  size_t n_features = model.learner_model_param->num_feature;
  bst_idx_t n_samples = p_fmat->Info().num_row_;
  dh::LaunchN(n_samples * ngroup, ctx->CUDACtx()->Stream(), [=] __device__(size_t idx) {
    auto [ridx, gidx] = linalg::UnravelIndex(idx, n_samples, ngroup);
    phis[gpu_treeshap::IndexPhiInteractions(ridx, ngroup, gidx, n_features, n_features,
                                            n_features)] +=
        margin.empty() ? base_score(gidx) : margin[idx];
  });
}

/**
 * @brief Approximate contributions are CPU-only; this GPU stub always aborts.
 */
void ApproxFeatureImportance(Context const* ctx, DMatrix*, HostDeviceVector<float>*,
                             gbm::GBTreeModel const&, bst_tree_t, std::vector<float> const*) {
  StringView not_implemented{
      "contribution is not implemented in the GPU predictor, use CPU instead."};
  LOG(FATAL) << "Approximated " << not_implemented;
}
}  // namespace xgboost::interpretability::cuda_impl
/**
 * @brief Compute SHAP (feature contribution) values, dispatching to the CUDA
 *        implementation when the context is on a CUDA device and to the CPU
 *        implementation otherwise.
 *
 * @param out_contribs       Filled with num_row * (num_feature+1) * n_groups values.
 * @param tree_end           Upper bound (exclusive) on trees to use; 0 means all.
 * @param tree_weights       Optional per-tree weights (dart); may be nullptr.
 * @param condition          Condition a feature on (1), off (-1), or not at all (0).
 * @param condition_feature  The feature index @p condition refers to.
 */
inline void ShapValues(Context const* ctx, DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
                       gbm::GBTreeModel const& model, bst_tree_t tree_end,
                       std::vector<float> const* tree_weights, int condition,
                       unsigned condition_feature) {
#if defined(XGBOOST_USE_CUDA)
  if (ctx->IsCUDA()) {
    cuda_impl::ShapValues(ctx, p_fmat, out_contribs, model, tree_end, tree_weights, condition,
                          condition_feature);
    return;
  }
#endif  // defined(XGBOOST_USE_CUDA)
  cpu_impl::ShapValues(ctx, p_fmat, out_contribs, model, tree_end, tree_weights, condition,
                       condition_feature);
}

/**
 * @brief Compute approximate feature contributions, dispatching on device.
 *        The CUDA overload is a fatal stub: approximation is CPU-only.
 */
inline void ApproxFeatureImportance(Context const* ctx, DMatrix* p_fmat,
                                    HostDeviceVector<float>* out_contribs,
                                    gbm::GBTreeModel const& model, bst_tree_t tree_end,
                                    std::vector<float> const* tree_weights) {
#if defined(XGBOOST_USE_CUDA)
  if (ctx->IsCUDA()) {
    cuda_impl::ApproxFeatureImportance(ctx, p_fmat, out_contribs, model, tree_end, tree_weights);
    return;
  }
#endif  // defined(XGBOOST_USE_CUDA)
  cpu_impl::ApproxFeatureImportance(ctx, p_fmat, out_contribs, model, tree_end, tree_weights);
}
/**
 * @brief Test helper: round-trip a trained Learner through its JSON model and
 *        rebuild a standalone gbm::GBTreeModel from it.
 *
 * Reconstructs a LearnerModelParam (feature count, groups/targets, base score
 * transformed to margin space via the objective) into @p out_param, which must
 * outlive the returned model since GBTreeModel keeps a pointer to it.
 */
std::unique_ptr<gbm::GBTreeModel> LoadGBTreeModel(Learner* learner, Context const* ctx,
                                                  Args const& model_args,
                                                  LearnerModelParam* out_param) {
  Json model{Object{}};
  learner->SaveModel(&model);

  auto const& model_obj = get<Object const>(model);
  auto const& learner_obj = get<Object const>(model_obj.at("learner"));
  auto const& lmp = get<Object const>(learner_obj.at("learner_model_param"));

  // Fetch a string field from learner_model_param, falling back to a default.
  auto get_or = [&](char const* key, std::string dft) {
    auto it = lmp.find(key);
    return it == lmp.cend() ? dft : get<String const>(it->second);
  };
  auto const& num_feature = get_or("num_feature", "0");
  auto const& num_class = get_or("num_class", "0");
  auto const& num_target = get_or("num_target", "1");
  auto const& base_score_str = get_or("base_score", "0");

  // base_score is serialized as a (possibly vector-valued) ParamArray string.
  common::ParamArray<float> base_score_arr{"base_score"};
  std::stringstream ss;
  ss << base_score_str;
  ss >> base_score_arr;

  std::size_t shape[1]{base_score_arr.size()};
  linalg::Vector<float> base_score_vec{shape, ctx->Device()};
  auto& h_base = base_score_vec.Data()->HostVector();
  h_base.assign(base_score_arr.cbegin(), base_score_arr.cend());

  // The serialized base score is in output space; convert it back to margin
  // space with the same objective the model was trained with.
  std::string objective{"reg:squarederror"};
  for (auto const& kv : model_args) {
    if (kv.first == "objective") {
      objective = kv.second;
      break;
    }
  }
  auto obj = std::unique_ptr<ObjFunction>(ObjFunction::Create(objective, ctx));
  obj->Configure(model_args);
  obj->ProbToMargin(&base_score_vec);
  // Keep both host/device views readable, matching LearnerModelParam invariants.
  std::as_const(base_score_vec).HostView();
  if (!ctx->Device().IsCPU()) {
    std::as_const(base_score_vec).View(ctx->Device());
  }

  auto n_features = static_cast<bst_feature_t>(std::stol(num_feature));
  auto n_classes = static_cast<bst_target_t>(std::stol(num_class));
  auto n_targets = static_cast<bst_target_t>(std::stol(num_target));
  // Single-output models serialize num_class == 0; groups is the larger of the two.
  auto n_groups = static_cast<bst_target_t>(std::max(n_classes, n_targets));
  LearnerModelParam tmp{n_features, std::move(base_score_vec), n_groups, n_targets,
                        MultiStrategy::kOneOutputPerTree};
  out_param->Copy(tmp);

  auto gbtree = std::make_unique<gbm::GBTreeModel>(out_param, ctx);
  auto const& gbm_obj = get<Object const>(learner_obj.at("gradient_booster"));
  gbtree->LoadModel(gbm_obj.at("model"));
  return gbtree;
}
SetLabels(dmat.get(), 1); cases.emplace_back(dmat, BaseParams(ctx, "reg:squarederror", "2")); } { // medium dense training DMatrix, moderate depth - auto dmat = RandomDataGenerator(512, 10, 0.0).Device(device).GenerateDMatrix(true); + auto dmat = RandomDataGenerator(64, 6, 0.0).Device(device).GenerateDMatrix(true); SetLabels(dmat.get(), 1); - cases.emplace_back(dmat, BaseParams(ctx, "reg:squarederror", "6")); + cases.emplace_back(dmat, BaseParams(ctx, "reg:squarederror", "4")); } { // quantile DMatrix with explicit bins, deeper tree auto dmat = - RandomDataGenerator(2048, 12, 0.0).Bins(64).Device(device).GenerateQuantileDMatrix(false); + RandomDataGenerator(128, 8, 0.0).Bins(32).Device(device).GenerateQuantileDMatrix(false); SetLabels(dmat.get(), 1); - auto args = BaseParams(ctx, "reg:squarederror", "8"); - args.emplace_back("max_bin", "64"); + auto args = BaseParams(ctx, "reg:squarederror", "5"); + args.emplace_back("max_bin", "32"); cases.emplace_back(dmat, std::move(args)); } { // external memory quantile DMatrix, moderate depth - bst_bin_t max_bin{64}; + bst_bin_t max_bin{32}; auto dmat = RandomDataGenerator(4096, 10, 0.0) .Batches(2) .Bins(max_bin) @@ -88,32 +155,30 @@ std::vector BuildShapTestCases(Context const* ctx) { { // external memory sparse page DMatrix, moderate depth - auto dmat = RandomDataGenerator(4096, 10, 0.0) + auto dmat = RandomDataGenerator(256, 8, 0.0) .Batches(2) .Device(device) .GenerateSparsePageDMatrix("shap_extmem", true); SetLabels(dmat.get(), 1); - cases.emplace_back(dmat, BaseParams(ctx, "reg:squarederror", "6")); + cases.emplace_back(dmat, BaseParams(ctx, "reg:squarederror", "4")); } { // multi-class dense training DMatrix, medium depth bst_target_t n_classes{3}; - auto dmat = RandomDataGenerator(256, 8, 0.0) - .Classes(n_classes) - .Device(device) - .GenerateDMatrix(true); + auto dmat = + RandomDataGenerator(64, 6, 0.0).Classes(n_classes).Device(device).GenerateDMatrix(true); SetLabels(dmat.get(), n_classes); - auto args = 
BaseParams(ctx, "multi:softprob", "4"); + auto args = BaseParams(ctx, "multi:softprob", "3"); args.emplace_back("num_class", std::to_string(n_classes)); cases.emplace_back(dmat, std::move(args)); } { - // large dense, deeper tree and classification objective - auto dmat = RandomDataGenerator(10000, 12, 0.0).Device(device).GenerateDMatrix(); + // compact dense classification case to keep runtime bounded + auto dmat = RandomDataGenerator(256, 8, 0.0).Device(device).GenerateDMatrix(); SetLabels(dmat.get(), 1); - cases.emplace_back(dmat, BaseParams(ctx, "binary:logistic", "10")); + cases.emplace_back(dmat, BaseParams(ctx, "binary:logistic", "4")); } return cases; @@ -127,7 +192,7 @@ void CheckShapOutput(DMatrix* dmat, Args const& model_args) { std::unique_ptr learner{Learner::Create({p_dmat})}; learner->SetParams(model_args); learner->Configure(); - for (size_t i = 0; i < 5; ++i) { + for (size_t i = 0; i < 2; ++i) { learner->UpdateOneIter(i, p_dmat); } @@ -135,15 +200,18 @@ void CheckShapOutput(DMatrix* dmat, Args const& model_args) { learner->Predict(p_dmat, true, &margin_predt, 0, 0, false, false, false, false, false); size_t const n_outputs = margin_predt.HostVector().size() / kRows; + LearnerModelParam mparam; + auto gbtree = LoadGBTreeModel(learner.get(), dmat->Ctx(), model_args, &mparam); + HostDeviceVector shap_values; - learner->Predict(p_dmat, false, &shap_values, 0, 0, false, false, true, false, false); + interpretability::ShapValues(dmat->Ctx(), p_dmat.get(), &shap_values, *gbtree, 0, nullptr, 0, 0); ASSERT_EQ(shap_values.HostVector().size(), kRows * (kCols + 1) * n_outputs); CheckShapAdditivity(kRows, kCols, shap_values, margin_predt); HostDeviceVector shap_interactions; - learner->Predict(p_dmat, false, &shap_interactions, 0, 0, false, false, false, false, true); - ASSERT_EQ(shap_interactions.HostVector().size(), - kRows * (kCols + 1) * (kCols + 1) * n_outputs); + interpretability::ShapInteractionValues(dmat->Ctx(), p_dmat.get(), &shap_interactions, 
*gbtree, 0, + {}, false); + ASSERT_EQ(shap_interactions.HostVector().size(), kRows * (kCols + 1) * (kCols + 1) * n_outputs); CheckShapAdditivity(kRows, kCols, shap_interactions, margin_predt); } @@ -207,8 +275,12 @@ TEST(Predictor, ApproxContribsBasic) { HostDeviceVector margin_predt; learner->Predict(dmat, true, &margin_predt, 0, 0, false, false, false, false, false); + LearnerModelParam mparam; + auto gbtree = LoadGBTreeModel(learner.get(), dmat->Ctx(), args, &mparam); + HostDeviceVector approx_contribs; - learner->Predict(dmat, false, &approx_contribs, 0, 0, false, false, true, true, false); + interpretability::ApproxFeatureImportance(dmat->Ctx(), dmat.get(), &approx_contribs, *gbtree, 0, + {}); auto const& h_margin = margin_predt.ConstHostVector(); auto const& h_contribs = approx_contribs.ConstHostVector();