Skip to content
3 changes: 2 additions & 1 deletion src/common/hist_util.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,8 @@ void AdapterDeviceSketch(Context const* ctx, Batch batch, bst_bin_t num_bins, Me
// approximation here is reasonably accurate. It doesn't hurt accuracy since the
// estimated n_samples must be greater or equal to the actual n_samples thanks to the
// dense assumption.
auto approx_n_samples = std::max(sketch_batch_num_elements / num_cols, bst_idx_t{1});
auto approx_n_samples =
std::max(common::DivRoundUp(sketch_batch_num_elements, num_cols), bst_idx_t{1});
num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, approx_n_samples);
bst_idx_t end =
std::min(batch.Size(), static_cast<std::size_t>(begin + sketch_batch_num_elements));
Expand Down
23 changes: 1 addition & 22 deletions src/common/quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -492,22 +492,8 @@ auto HostSketchContainer::AllReduce(Context const *ctx, MetaInfo const &info,
}

void AddCutPoints(WQSummaryContainer const &summary, size_t max_bin, HistogramCuts *cuts) {
size_t required_cuts = std::min(summary.Size(), static_cast<size_t>(max_bin));
auto &cut_values = cuts->cut_values_.HostVector();
auto const entries = summary.Entries();
// Use raw pointer in the cut extraction loop to avoid per-access bounds checks.
auto const *summary_data = entries.data();
// summary[0] is the observed minimum; the first bin lower bound is implicit.
for (size_t i = 1; i < required_cuts; ++i) {
bst_float cpt = summary_data[i].value;
if (i == 1 || cpt > cut_values.back()) {
cut_values.push_back(cpt);
}
}
auto const cpt = !entries.empty() ? entries.back().value : 1e-5f;
// This must be bigger than the last observed cut value.
auto const last = cpt + (std::fabs(cpt) + 1e-5f);
cut_values.push_back(last);
QueryCutValues(summary, max_bin, [&](float cpt) { cut_values.push_back(cpt); });
}

void AddCategories(std::set<float> const &categories, float *max_cat, HistogramCuts *cuts) {
Expand Down Expand Up @@ -551,13 +537,6 @@ HistogramCuts HostSketchContainer::MakeCuts(Context const *ctx, MetaInfo const &
}

auto &h_cut_ptrs = p_cuts->cut_ptrs_.HostVector();
// Prune size down to max_bins + 1 (reserve one extra for the max value)
// before extracting cut points.
ParallelFor(numeric_features.size(), n_threads_, Sched::Guided(), [&](size_t idx) {
auto fidx = numeric_features[idx];
reduced_numerical.at(fidx).SetPrune(max_bins_ + 1); // reserve one extra for the max value
});

float max_cat{-1.f};
for (size_t fid = 0; fid < reduced_numerical.size(); ++fid) {
size_t max_num_bins = std::min(reduced_numerical[fid].Size(), static_cast<size_t>(max_bins_));
Expand Down
163 changes: 32 additions & 131 deletions src/common/quantile.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#include "hist_util.h"
#include "quantile.cuh"
#include "quantile.h"
#include "transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/span.h"

namespace xgboost::common {
Expand Down Expand Up @@ -663,19 +662,6 @@ void SketchContainer::AllReduce(Context const *ctx, bool is_column_split) {
LOG(FATAL) << "Distributed GPU quantile sketch reduction requires NCCL support.";
}

namespace {
struct InvalidCatOp {
Span<SketchEntry const> values;
Span<size_t const> ptrs;
Span<FeatureType const> ft;

XGBOOST_DEVICE bool operator()(size_t i) const {
auto fidx = dh::SegmentId(ptrs, i);
return IsCat(ft, fidx) && InvalidCat(values[i].value);
}
};
} // anonymous namespace

HistogramCuts SketchContainer::MakeCuts(Context const *ctx, bool is_column_split) {
curt::SetDevice(ctx->Ordinal());
HistogramCuts cuts{num_columns_};
Expand All @@ -685,133 +671,48 @@ HistogramCuts SketchContainer::MakeCuts(Context const *ctx, bool is_column_split
this->AllReduce(ctx, is_column_split);

timer_.Start(__func__);
// Prune to final number of bins.
this->Prune(ctx, num_bins_ + 1);

// Set up inputs
auto d_in_columns_ptr = this->columns_ptr_.ConstDeviceSpan();

auto const in_cut_values = dh::ToSpan(this->entries_);

// Set up output ptr
p_cuts->cut_ptrs_.SetDevice(ctx->Device());
auto &h_out_columns_ptr = p_cuts->cut_ptrs_.HostVector();
h_out_columns_ptr.front() = 0;
auto const &h_feature_types = this->feature_types_.ConstHostSpan();
h_out_columns_ptr.assign(num_columns_ + 1, 0);
auto &h_out_cut_values = p_cuts->cut_values_.HostVector();
h_out_cut_values.clear();

auto d_ft = feature_types_.ConstDeviceSpan();
auto const &h_in_columns_ptr = this->columns_ptr_.ConstHostVector();
std::vector<SketchEntry> h_entries(this->entries_.size());
dh::CopyDeviceSpanToVector(&h_entries, dh::ToSpan(this->entries_));
auto const &h_feature_types = this->feature_types_.ConstHostSpan();

std::vector<SketchEntry> max_values;
// TODO(rory): Port query-based cut extraction back onto the device and remove this temporary
// host-side extraction path once the GPU semantics are validated.
float max_cat{-1.f};
if (has_categorical_) {
auto key_it = dh::MakeTransformIterator<bst_feature_t>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) -> bst_feature_t {
return dh::SegmentId(d_in_columns_ptr, i);
});
auto invalid_op = InvalidCatOp{in_cut_values, d_in_columns_ptr, d_ft};
auto val_it = dh::MakeTransformIterator<SketchEntry>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) {
auto fidx = dh::SegmentId(d_in_columns_ptr, i);
auto v = in_cut_values[i];
if (IsCat(d_ft, fidx)) {
if (invalid_op(i)) {
// use inf to indicate invalid value, this way we can keep it as in
// indicator in the reduce operation as it's always the greatest value.
v.value = std::numeric_limits<float>::infinity();
}
}
return v;
});
CHECK_EQ(num_columns_, d_in_columns_ptr.size() - 1);
max_values.resize(d_in_columns_ptr.size() - 1);

// In some cases (e.g. column-wise data split), we may have empty columns, so we need to keep
// track of the unique keys (feature indices) after the thrust::reduce_by_key` call.
dh::caching_device_vector<size_t> d_max_keys(d_in_columns_ptr.size() - 1);
dh::caching_device_vector<SketchEntry> d_max_values(d_in_columns_ptr.size() - 1);
auto new_end = thrust::reduce_by_key(
ctx->CUDACtx()->CTP(), key_it, key_it + in_cut_values.size(), val_it, d_max_keys.begin(),
d_max_values.begin(), thrust::equal_to<bst_feature_t>{},
[] __device__(auto l, auto r) { return l.value > r.value ? l : r; });
d_max_keys.erase(new_end.first, d_max_keys.end());
d_max_values.erase(new_end.second, d_max_values.end());

// The device vector needs to be initialized explicitly since we may have some missing columns.
SketchEntry default_entry{};
dh::caching_device_vector<SketchEntry> d_max_results(d_in_columns_ptr.size() - 1,
default_entry);
thrust::scatter(ctx->CUDACtx()->CTP(), d_max_values.begin(), d_max_values.end(),
d_max_keys.begin(), d_max_results.begin());
dh::CopyDeviceSpanToVector(&max_values, dh::ToSpan(d_max_results));
auto max_it = MakeIndexTransformIter([&](auto i) {
if (IsCat(h_feature_types, i)) {
return max_values[i].value;
}
return -1.f;
});
max_cat = *std::max_element(max_it, max_it + max_values.size());
if (std::isinf(max_cat)) {
InvalidCategory();
}
}

// Set up output cuts
WQSummaryContainer summary;
for (bst_feature_t i = 0; i < num_columns_; ++i) {
size_t column_size = std::max(static_cast<size_t>(1ul), this->Column(i).size());
auto begin = h_in_columns_ptr[i];
auto end = h_in_columns_ptr[i + 1];
auto column = Span<SketchEntry const>{h_entries.data() + begin, end - begin};

if (IsCat(h_feature_types, i)) {
// column_size is the number of unique values in that feature.
CheckMaxCat(max_values[i].value, column_size);
h_out_columns_ptr[i + 1] = max_values[i].value + 1; // includes both max_cat and 0.
auto column_size = std::max(static_cast<std::size_t>(1), column.size());
auto feature_max = column.empty() ? 0.0f : column.back().value;
if (std::any_of(column.cbegin(), column.cend(),
[](auto const &entry) { return InvalidCat(entry.value); })) {
InvalidCategory();
}
CheckMaxCat(feature_max, column_size);
max_cat = std::max(max_cat, feature_max);
for (std::size_t cat = 0; cat <= static_cast<std::size_t>(feature_max); ++cat) {
h_out_cut_values.push_back(cat);
}
} else {
h_out_columns_ptr[i + 1] =
std::min(static_cast<size_t>(column_size), static_cast<size_t>(num_bins_));
summary.Reserve(column.size());
std::copy(column.cbegin(), column.cend(), summary.space.begin());
summary.SetSize(column.size());
QueryCutValues(summary, static_cast<std::size_t>(num_bins_),
[&](float cpt) { h_out_cut_values.push_back(cpt); });
}
h_out_columns_ptr[i + 1] = h_out_cut_values.size();
}
std::partial_sum(h_out_columns_ptr.begin(), h_out_columns_ptr.end(), h_out_columns_ptr.begin());
auto d_out_columns_ptr = p_cuts->cut_ptrs_.ConstDeviceSpan();

size_t total_bins = h_out_columns_ptr.back();
p_cuts->cut_values_.SetDevice(ctx->Device());
p_cuts->cut_values_.Resize(total_bins);
auto out_cut_values = p_cuts->cut_values_.DeviceSpan();

dh::LaunchN(total_bins, [=] __device__(size_t idx) {
auto column_id = dh::SegmentId(d_out_columns_ptr, idx);
auto in_column = in_cut_values.subspan(
d_in_columns_ptr[column_id], d_in_columns_ptr[column_id + 1] - d_in_columns_ptr[column_id]);
auto out_column =
out_cut_values.subspan(d_out_columns_ptr[column_id],
d_out_columns_ptr[column_id + 1] - d_out_columns_ptr[column_id]);
idx -= d_out_columns_ptr[column_id];
if (in_column.size() == 0) {
// If the column is empty, we push a dummy value. It won't affect training as the
// column is empty, trees cannot split on it. This is just to be consistent with
// rest of the library.
if (idx == 0) {
out_column[0] = kRtEps;
assert(out_column.size() == 1);
}
return;
}

if (IsCat(d_ft, column_id)) {
out_column[idx] = idx;
return;
}

// Last thread is responsible for setting a value that's greater than other cuts.
if (idx == out_column.size() - 1) {
const bst_float cpt = in_column.back().value;
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5);
out_column[idx] = last;
return;
}
assert(idx + 1 < in_column.size());
out_column[idx] = in_column[idx + 1].value;
});

p_cuts->SetCategorical(this->has_categorical_, max_cat);
p_cuts->SetDevice(ctx->Device());
timer_.Stop(__func__);
return cuts;
}
Expand Down
96 changes: 96 additions & 0 deletions src/common/quantile.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,60 @@ struct WQSummary {
dst_data[current_elements_++] = src_data[src_size - 1];
}
}

[[nodiscard]] Entry const &Query(double rank) const {
CHECK(!this->Empty());
auto const entries = this->Entries();
if (entries.size() == 1 || rank < entries.front().rmax) {
return entries.front();
}
if (rank >= entries.back().rmin) {
return entries.back();
}

auto rank2 = static_cast<double>(2.0) * rank;
auto it = std::upper_bound(entries.cbegin() + 1, entries.cend() - 1, rank2,
[](double lhs, Entry const &rhs) {
return lhs < static_cast<double>(rhs.rmin + rhs.rmax);
});
auto i = static_cast<std::size_t>(std::distance(entries.cbegin(), it) - 1);
if (rank2 < static_cast<double>(entries[i].RMinNext() + entries[i + 1].RMaxPrev())) {
return entries[i];
}
return entries[i + 1];
}

template <typename Fn>
void QueryRanks(std::size_t num_cuts, Fn &&fn) const {
CHECK(!this->Empty());
if (num_cuts <= 1) {
return;
}

auto const entries = this->Entries();
if (entries.size() == 1) {
for (std::size_t i = 1; i < num_cuts; ++i) {
fn(i, entries.front());
}
return;
}

auto total = static_cast<double>(entries.back().rmax);
std::size_t cursor = 0;
for (std::size_t i = 1; i < num_cuts; ++i) {
auto rank = static_cast<double>(i) * total / static_cast<double>(num_cuts);
auto rank2 = static_cast<double>(2.0) * rank;
while (cursor < entries.size() - 2 &&
rank2 >= static_cast<double>(entries[cursor + 1].rmin + entries[cursor + 1].rmax)) {
++cursor;
}
auto const &queried =
rank2 < static_cast<double>(entries[cursor].RMinNext() + entries[cursor + 1].RMaxPrev())
? entries[cursor]
: entries[cursor + 1];
fn(i, queried);
}
}
/*!
* \brief combine `other` into `this`.
*
Expand Down Expand Up @@ -449,9 +503,51 @@ struct WQSummaryContainer : public WQSummary<> {
}
};

template <typename Summary>
auto NextGreaterSummaryValue(Summary const &summary, float value) -> float {
auto const entries = summary.Entries();
auto it = std::upper_bound(entries.cbegin(), entries.cend(), value,
[](float lhs, auto const &rhs) { return lhs < rhs.value; });
if (it == entries.cend()) {
return value;
}
return it->value;
}

template <typename Summary, typename Fn>
void QueryCutValues(Summary const &summary, std::size_t max_bin, Fn &&fn) {
auto required_cuts = std::min(summary.Size(), max_bin);
auto const entries = summary.Entries();

if (summary.Size() <= max_bin) {
for (std::size_t i = 1; i < required_cuts; ++i) {
fn(entries[i].value);
}
} else {
auto last_cut = entries.front().value;
summary.QueryRanks(required_cuts, [&](std::size_t, auto const &queried) {
auto cpt = queried.value;
if (cpt <= last_cut) {
cpt = NextGreaterSummaryValue(summary, last_cut);
}
if (cpt > last_cut) {
fn(cpt);
last_cut = cpt;
}
});
}

auto cpt = !entries.empty() ? entries.back().value : 1e-5f;
fn(cpt + (std::fabs(cpt) + 1e-5f));
}

/*! \brief Weighted quantile sketch algorithm using merge/prune. */
class WQuantileSketch {
public:
// Sketch epsilon is approximately `1 / (kFactor * max_bin)` once `max_bin` limits the budget.
// Our current cut-rank measurements suggest an empirical constant of about 2 for the final
// emitted cuts, so the observed normalized cut error is about `2 / kFactor`. With
// `kFactor = 8`, that is roughly `0.25` bins of rank mass, i.e. about a quarter-bin offset.
static float constexpr kFactor = 8.0;

public:
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/common/test_hist_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ struct RankErrorSummary {
};

inline constexpr double kMaxNormalizedRankError = 2.0;
inline constexpr double kMaxWeightedNormalizedRankError = 15.0;
inline constexpr double kMaxWeightedNormalizedRankError = 10.0;

inline double DistanceToInterval(double target, double lo, double hi) {
if (target < lo) {
Expand Down
Loading
Loading