Skip to content

Commit 4b81e2f

Browse files
authored
Prune GPU sketch after each merge (#12105)
1 parent a1f9c11 commit 4b81e2f

File tree

8 files changed

+133
-66
lines changed

8 files changed

+133
-66
lines changed

src/common/hist_util.cu

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ size_t RequiredSampleCuts(bst_idx_t num_rows, bst_feature_t num_columns, size_t
4646
size_t RequiredMemory(bst_idx_t num_rows, bst_feature_t num_columns, size_t nnz, size_t num_bins,
4747
bool with_weights) {
4848
size_t peak = 0;
49+
auto cuts_bytes = RequiredSampleCuts(num_rows, num_bins, num_bins, nnz) * sizeof(SketchEntry);
4950
// 0. Allocate cut pointer in quantile container by increasing: n_columns + 1
5051
size_t total = (num_columns + 1) * sizeof(SketchContainer::OffsetT);
5152
// 1. Copy and sort: 2 * bytes_per_element * shape
@@ -58,16 +59,22 @@ size_t RequiredMemory(bst_idx_t num_rows, bst_feature_t num_columns, size_t nnz,
5859
// 4. Allocate cut pointer by increasing: n_columns + 1
5960
total += (num_columns + 1) * sizeof(SketchContainer::OffsetT);
6061
// 5. Allocate cuts: assuming rows is greater than bins: n_columns * limit_size
61-
total += RequiredSampleCuts(num_rows, num_bins, num_bins, nnz) * sizeof(SketchEntry);
62-
// 6. Deallocate copied entries by reducing: bytes_per_element * shape.
62+
total += cuts_bytes;
63+
// 6. Install the first batch summary into the resident sketch while the temporary pruned
64+
// summary is still live.
65+
total += cuts_bytes;
66+
// 7. Deallocate copied entries by reducing: bytes_per_element * shape.
6367
peak = std::max(peak, total);
6468
total -= (BytesPerElement(with_weights) * num_rows * num_columns) / 2;
65-
// 7. Deallocate column size scan.
69+
// 8. Deallocate the temporary pruned batch summary after merge/prune commit.
70+
peak = std::max(peak, total);
71+
total -= cuts_bytes;
72+
// 9. Deallocate column size scan.
6673
peak = std::max(peak, total);
6774
total -= (num_columns + 1) * sizeof(SketchContainer::OffsetT);
68-
// 8. Deallocate cut size scan.
75+
// 10. Deallocate cut size scan.
6976
total -= (num_columns + 1) * sizeof(SketchContainer::OffsetT);
70-
// 9. Allocate final cut values and cut ptrs: std::min(rows, bins + 1) * n_columns +
77+
// 11. Allocate final cut values and cut ptrs: std::min(rows, bins + 1) * n_columns +
7178
// n_columns + 1
7279
total += std::min(num_rows, num_bins) * num_columns * sizeof(float);
7380
total +=
@@ -269,8 +276,9 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
269276
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
270277

271278
// Add cuts into sketches
279+
auto approx_n_samples = std::max<bst_idx_t>(1, (end - begin + info.num_col_ - 1) / info.num_col_);
272280
sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
273-
h_cuts_ptr.back(), dh::ToSpan(entry_weight));
281+
h_cuts_ptr.back(), approx_n_samples, dh::ToSpan(entry_weight));
274282

275283
sorted_entries.clear();
276284
sorted_entries.shrink_to_fit();

src/common/hist_util.cuh

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,8 @@ inline HistogramCuts DeviceSketch(
278278
template <typename AdapterBatch>
279279
void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInfo const& info,
280280
size_t n_features, size_t begin, size_t end, float missing,
281-
SketchContainer* sketch_container, int num_cuts) {
281+
SketchContainer* sketch_container, int num_cuts,
282+
bst_idx_t approx_n_samples) {
282283
// Copy current subset of valid elements into temporary storage and sort
283284
dh::device_vector<Entry> sorted_entries;
284285
dh::caching_device_vector<size_t> column_sizes_scan;
@@ -303,7 +304,7 @@ void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInf
303304
auto const& h_cuts_ptr = cuts_ptr.HostVector();
304305
// Extract the cuts from all columns concurrently
305306
sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
306-
h_cuts_ptr.back());
307+
h_cuts_ptr.back(), approx_n_samples);
307308

308309
sorted_entries.clear();
309310
sorted_entries.shrink_to_fit();
@@ -313,7 +314,7 @@ template <typename Batch>
313314
void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo const& info,
314315
int num_cuts_per_feature, bool is_ranking, float missing,
315316
size_t columns, size_t begin, size_t end,
316-
SketchContainer* sketch_container) {
317+
SketchContainer* sketch_container, bst_idx_t approx_n_samples) {
317318
curt::SetDevice(ctx->Ordinal());
318319
info.weights_.SetDevice(ctx->Device());
319320
auto weights = info.weights_.ConstDeviceSpan();
@@ -379,7 +380,7 @@ void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo cons
379380

380381
// Extract cuts
381382
sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
382-
h_cuts_ptr.back(), dh::ToSpan(temp_weights));
383+
h_cuts_ptr.back(), approx_n_samples, dh::ToSpan(temp_weights));
383384
sorted_entries.clear();
384385
sorted_entries.shrink_to_fit();
385386
}
@@ -431,10 +432,10 @@ void AdapterDeviceSketch(Context const* ctx, Batch batch, bst_bin_t num_bins, Me
431432
if (weighted) {
432433
ProcessWeightedSlidingWindow(ctx, batch, info, num_cuts_per_feature,
433434
HostSketchContainer::UseGroup(info), missing, num_cols, begin,
434-
end, sketch_container);
435+
end, sketch_container, approx_n_samples);
435436
} else {
436437
ProcessSlidingWindow(ctx, batch, info, num_cols, begin, end, missing, sketch_container,
437-
num_cuts_per_feature);
438+
num_cuts_per_feature, approx_n_samples);
438439
}
439440
begin += sketch_batch_num_elements;
440441
}

src/common/quantile.cu

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -301,21 +301,17 @@ void MergeImpl(Context const *ctx, Span<SketchEntry const> const &d_x,
301301
});
302302
}
303303

304+
// Convert one sorted batch into a temporary pruned summary in `prune_buffer_`, normalize
305+
// duplicated raw values in place, then merge that summary into the resident sketch in
306+
// `entries_`. Out-of-place merge/prune results use `entries_tmp_` as scratch before being
307+
// committed back into `entries_`.
304308
void SketchContainer::Push(Context const *ctx, Span<Entry const> entries, Span<size_t> columns_ptr,
305-
common::Span<OffsetT> cuts_ptr, size_t total_cuts, Span<float> weights) {
309+
common::Span<OffsetT> cuts_ptr, size_t total_cuts,
310+
bst_idx_t n_rows_in_batch, Span<float> weights) {
306311
curt::SetDevice(ctx->Ordinal());
307-
auto &current = this->entries_;
308-
auto &columns_ptr_out = this->columns_ptr_;
309-
Span<SketchEntry> out;
310-
dh::device_vector<SketchEntry> cuts;
311-
bool first_window = current.empty();
312-
if (!first_window) {
313-
cuts.resize(total_cuts);
314-
out = dh::ToSpan(cuts);
315-
} else {
316-
current.resize(total_cuts);
317-
out = dh::ToSpan(current);
318-
}
312+
rows_seen_ += n_rows_in_batch;
313+
this->prune_buffer_.resize(total_cuts);
314+
auto out = dh::ToSpan(this->prune_buffer_);
319315
auto ft = this->feature_types_.ConstDeviceSpan();
320316
if (weights.empty()) {
321317
auto to_sketch_entry = [] __device__(size_t sample_idx, Span<Entry const> const &column,
@@ -340,19 +336,13 @@ void SketchContainer::Push(Context const *ctx, Span<Entry const> entries, Span<s
340336
PruneImpl<Entry>(cuts_ptr, entries, columns_ptr, ft, out, to_sketch_entry);
341337
}
342338
auto n_uniques = this->ScanInput(ctx, out, cuts_ptr);
343-
344-
if (!first_window) {
345-
CHECK_EQ(columns_ptr_out.Size(), cuts_ptr.size());
346-
out = out.subspan(0, n_uniques);
347-
this->Merge(ctx, cuts_ptr, out);
348-
} else {
349-
current.resize(n_uniques);
350-
columns_ptr_out.SetDevice(ctx->Device());
351-
columns_ptr_out.Resize(cuts_ptr.size());
352-
353-
auto d_cuts_ptr = columns_ptr_out.DeviceSpan();
354-
CopyTo(d_cuts_ptr, cuts_ptr);
339+
CHECK_EQ(this->columns_ptr_.Size(), cuts_ptr.size());
340+
if (n_uniques == 0) {
341+
return;
355342
}
343+
this->Merge(ctx, cuts_ptr, out.subspan(0, n_uniques));
344+
auto intermediate_num_cuts = static_cast<bst_idx_t>(this->IntermediateNumCuts());
345+
this->Prune(ctx, intermediate_num_cuts);
356346
}
357347

358348
size_t SketchContainer::ScanInput(Context const *ctx, Span<SketchEntry> entries,
@@ -404,6 +394,11 @@ void SketchContainer::Prune(Context const *ctx, std::size_t to) {
404394
auto &columns_ptr_tmp = this->columns_ptr_tmp_;
405395
auto const &feature_types = this->feature_types_;
406396

397+
if (entries.size() <= to * num_columns_) {
398+
timer_.Stop(__func__);
399+
return;
400+
}
401+
407402
OffsetT to_total = 0;
408403
auto &h_columns_ptr = columns_ptr_tmp.HostVector();
409404
h_columns_ptr[0] = to_total;

src/common/quantile.cuh

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include <thrust/logical.h> // for any_of
88

9+
#include <algorithm>
910
#include <cstddef> // for size_t
1011
#include <functional> // for equal_to
1112

@@ -52,17 +53,25 @@ class SketchContainer {
5253
// The container is just a CSC matrix plus scratch storage for out-of-place transforms.
5354
dh::device_vector<SketchEntry> entries_;
5455
dh::device_vector<SketchEntry> entries_tmp_;
56+
dh::device_vector<SketchEntry> prune_buffer_;
5557
HostDeviceVector<OffsetT> columns_ptr_;
5658
HostDeviceVector<OffsetT> columns_ptr_tmp_;
5759

5860
bool has_categorical_{false};
61+
std::size_t rows_seen_{0};
5962

6063
void SetCurrentColumns(Span<OffsetT const> columns_ptr);
6164
void CommitScratch(std::size_t n_entries) {
6265
entries_.swap(entries_tmp_);
6366
columns_ptr_.Copy(columns_ptr_tmp_);
6467
entries_.resize(n_entries);
6568
}
69+
[[nodiscard]] std::size_t IntermediateNumCuts() const {
70+
auto const base = static_cast<std::size_t>(num_bins_) * kFactor;
71+
auto const eps = 1.0 / static_cast<double>(base);
72+
auto const per_feature = WQSketch::LimitSizeLevel(std::max<std::size_t>(1, rows_seen_), eps);
73+
return per_feature * num_columns_;
74+
}
6675

6776
// Get the span of one column.
6877
Span<SketchEntry> Column(bst_feature_t i) {
@@ -109,15 +118,18 @@ class SketchContainer {
109118
*/
110119
[[nodiscard]] std::size_t MemCapacityBytes() const {
111120
auto constexpr kE = sizeof(typename decltype(this->entries_)::value_type);
112-
auto n_bytes = (this->entries_.capacity() + this->entries_tmp_.capacity()) * kE;
121+
auto n_bytes =
122+
(this->entries_.capacity() + this->entries_tmp_.capacity() + this->prune_buffer_.capacity()) *
123+
kE;
113124
n_bytes += (this->columns_ptr_.Size() + this->columns_ptr_tmp_.Size()) * sizeof(OffsetT);
114125
n_bytes += this->feature_types_.Size() * sizeof(FeatureType);
115126

116127
return n_bytes;
117128
}
118129
[[nodiscard]] std::size_t MemCostBytes() const {
119130
auto constexpr kE = sizeof(typename decltype(this->entries_)::value_type);
120-
auto n_bytes = (this->entries_.size() + this->entries_tmp_.size()) * kE;
131+
auto n_bytes =
132+
(this->entries_.size() + this->entries_tmp_.size() + this->prune_buffer_.size()) * kE;
121133
n_bytes += (this->columns_ptr_.Size() + this->columns_ptr_tmp_.Size()) * sizeof(OffsetT);
122134
n_bytes += this->feature_types_.Size() * sizeof(FeatureType);
123135

@@ -140,7 +152,8 @@ class SketchContainer {
140152
* \param weights (optional) data weights.
141153
*/
142154
void Push(Context const* ctx, Span<Entry const> entries, Span<size_t> columns_ptr,
143-
common::Span<OffsetT> cuts_ptr, size_t total_cuts, Span<float> weights = {});
155+
common::Span<OffsetT> cuts_ptr, size_t total_cuts, bst_idx_t n_rows_in_batch,
156+
Span<float> weights = {});
144157
/**
145158
* @brief Prune the quantile structure.
146159
*

tests/cpp/common/test_hist_util.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,11 @@ inline void TestRank(const std::vector<float>& column_cuts, const std::vector<fl
111111
j++;
112112
}
113113
double expected_rank = ((i + 1) * total_weight) / column_cuts.size();
114-
double acceptable_error = std::max(2.9, total_weight * eps);
114+
// For small sketches, a purely relative tolerance can be tighter than one bin's
115+
// expected mass. Use the larger of the relative tolerance and the average per-cut
116+
// mass instead of a hard-coded floor.
117+
double acceptable_error =
118+
std::max(total_weight * eps, total_weight / static_cast<double>(column_cuts.size()));
115119
EXPECT_LE(std::abs(expected_rank - sum_weight), acceptable_error);
116120
}
117121
}

tests/cpp/common/test_quantile.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ TEST(GPUQuantile, Basic) {
4545
dh::device_vector<bst_idx_t> cuts_ptr(kCols + 1);
4646
thrust::fill(cuts_ptr.begin(), cuts_ptr.end(), 0);
4747
// Push empty
48-
sketch.Push(&ctx, dh::ToSpan(entries), dh::ToSpan(cuts_ptr), dh::ToSpan(cuts_ptr), 0);
48+
sketch.Push(&ctx, dh::ToSpan(entries), dh::ToSpan(cuts_ptr), dh::ToSpan(cuts_ptr), 0, 0);
4949
ASSERT_EQ(sketch.Data().size(), 0);
5050
}
5151

@@ -332,9 +332,9 @@ TEST(GPUQuantile, MergeCategorical) {
332332
dh::device_vector<size_t> cuts_ptr_1{0, 5, 8};
333333

334334
sketch_0.Push(&ctx, dh::ToSpan(d_entries_0), dh::ToSpan(columns_ptr_0), dh::ToSpan(cuts_ptr_0),
335-
entries_0.size(), {});
335+
entries_0.size(), 5, {});
336336
sketch_1.Push(&ctx, dh::ToSpan(d_entries_1), dh::ToSpan(columns_ptr_1), dh::ToSpan(cuts_ptr_1),
337-
entries_1.size(), {});
337+
entries_1.size(), 5, {});
338338

339339
sketch_0.Merge(&ctx, sketch_1.ColumnsPtr(), sketch_1.Data());
340340
TestQuantileElemRank(ctx.Device(), sketch_0.Data(), sketch_0.ColumnsPtr());
@@ -639,7 +639,7 @@ TEST(GPUQuantile, Push) {
639639
HostDeviceVector<FeatureType> ft;
640640
SketchContainer sketch(ft, n_bins, kCols, ctx.Device());
641641
sketch.Push(&ctx, dh::ToSpan(d_entries), dh::ToSpan(columns_ptr), dh::ToSpan(columns_ptr), kRows,
642-
{});
642+
kRows, {});
643643

644644
auto sketch_data = sketch.Data();
645645

@@ -690,7 +690,7 @@ TEST(GPUQuantile, MultiColPush) {
690690
dh::device_vector<size_t> cuts_ptr(columns_ptr);
691691

692692
sketch.Push(&ctx, dh::ToSpan(d_entries), dh::ToSpan(columns_ptr), dh::ToSpan(cuts_ptr),
693-
kRows * kCols, {});
693+
kRows * kCols, kRows, {});
694694

695695
auto sketch_data = sketch.Data();
696696
ASSERT_EQ(sketch_data.size(), kCols * 2);

0 commit comments

Comments
 (0)