Skip to content

Commit 8069896

Browse files
committed
Prune GPU sketch after each merge
1 parent 5a754f4 commit 8069896

File tree

1 file changed

+6
-25
lines changed

1 file changed

+6
-25
lines changed

src/common/quantile.cu

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -304,18 +304,8 @@ void MergeImpl(Context const *ctx, Span<SketchEntry const> const &d_x,
304304
void SketchContainer::Push(Context const *ctx, Span<Entry const> entries, Span<size_t> columns_ptr,
305305
common::Span<OffsetT> cuts_ptr, size_t total_cuts, Span<float> weights) {
306306
curt::SetDevice(ctx->Ordinal());
307-
auto &current = this->entries_;
308-
auto &columns_ptr_out = this->columns_ptr_;
309-
Span<SketchEntry> out;
310-
dh::device_vector<SketchEntry> cuts;
311-
bool first_window = current.empty();
312-
if (!first_window) {
313-
cuts.resize(total_cuts);
314-
out = dh::ToSpan(cuts);
315-
} else {
316-
current.resize(total_cuts);
317-
out = dh::ToSpan(current);
318-
}
307+
dh::device_vector<SketchEntry> cuts(total_cuts);
308+
auto out = dh::ToSpan(cuts);
319309
auto ft = this->feature_types_.ConstDeviceSpan();
320310
if (weights.empty()) {
321311
auto to_sketch_entry = [] __device__(size_t sample_idx, Span<Entry const> const &column,
@@ -340,19 +330,10 @@ void SketchContainer::Push(Context const *ctx, Span<Entry const> entries, Span<s
340330
PruneImpl<Entry>(cuts_ptr, entries, columns_ptr, ft, out, to_sketch_entry);
341331
}
342332
auto n_uniques = this->ScanInput(ctx, out, cuts_ptr);
343-
344-
if (!first_window) {
345-
CHECK_EQ(columns_ptr_out.Size(), cuts_ptr.size());
346-
out = out.subspan(0, n_uniques);
347-
this->Merge(ctx, cuts_ptr, out);
348-
} else {
349-
current.resize(n_uniques);
350-
columns_ptr_out.SetDevice(ctx->Device());
351-
columns_ptr_out.Resize(cuts_ptr.size());
352-
353-
auto d_cuts_ptr = columns_ptr_out.DeviceSpan();
354-
CopyTo(d_cuts_ptr, cuts_ptr);
355-
}
333+
CHECK_EQ(this->columns_ptr_.Size(), cuts_ptr.size());
334+
this->Merge(ctx, cuts_ptr, out.subspan(0, n_uniques));
335+
auto intermediate_num_cuts = static_cast<bst_idx_t>(num_bins_ * kFactor);
336+
this->Prune(ctx, intermediate_num_cuts);
356337
}
357338

358339
size_t SketchContainer::ScanInput(Context const *ctx, Span<SketchEntry> entries,

0 commit comments

Comments
 (0)