Skip to content

Commit 45a8bb7

Browse files
committed
Replace GPU quantile double buffering
1 parent ef7e924 commit 45a8bb7

File tree

2 files changed

+46
-52
lines changed

2 files changed

+46
-52
lines changed

src/common/quantile.cu

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,11 @@ XGBOOST_DEVICE thrust::tuple<uint64_t, uint64_t> MergePartition(Span<SketchEntry
183183
return thrust::make_tuple(a_ind, k - a_ind);
184184
}
185185

186+
// Replace the current column pointers: resize columns_ptr_ to the length of
// `columns_ptr`, then pass the device span of columns_ptr_ and `columns_ptr` to CopyTo.
// NOTE(review): CopyTo is defined outside this view; presumably it copies
// `columns_ptr` into the first argument -- confirm the argument order and
// which memory space `columns_ptr` is expected to live in.
void SketchContainer::SetCurrentColumns(Span<OffsetT const> columns_ptr) {
187+
columns_ptr_.Resize(columns_ptr.size());
188+
CopyTo(columns_ptr_.DeviceSpan(), columns_ptr);
189+
}
190+
186191
// Merge d_x and d_y into out. Because the final output depends on predicate (which
187192
// summary does the output element come from) result by definition of merged rank. So we
188193
// compute the partition for each output directly and customize the standard merge
@@ -373,7 +378,7 @@ size_t SketchContainer::ScanInput(Context const *ctx, Span<SketchEntry> entries,
373378
return l;
374379
});
375380

376-
auto d_columns_ptr_out = columns_ptr_b_.DeviceSpan();
381+
auto d_columns_ptr_out = this->ScratchColumns();
377382
// thrust unique_by_key preserves the first element.
378383
auto n_uniques =
379384
dh::SegmentedUnique(ctx->CUDACtx()->CTP(), d_columns_ptr_in.data(),
@@ -391,7 +396,7 @@ void SketchContainer::Prune(Context const *ctx, std::size_t to) {
391396
curt::SetDevice(ctx->Ordinal());
392397

393398
OffsetT to_total = 0;
394-
auto &h_columns_ptr = columns_ptr_b_.HostVector();
399+
auto &h_columns_ptr = columns_ptr_tmp_.HostVector();
395400
h_columns_ptr[0] = to_total;
396401
auto const &h_feature_types = feature_types_.ConstHostSpan();
397402
for (bst_feature_t i = 0; i < num_columns_; ++i) {
@@ -403,16 +408,16 @@ void SketchContainer::Prune(Context const *ctx, std::size_t to) {
403408
to_total += length;
404409
h_columns_ptr[i + 1] = to_total;
405410
}
406-
this->Other().resize(to_total);
411+
this->Scratch().resize(to_total);
407412

408413
auto d_columns_ptr_in = this->columns_ptr_.ConstDeviceSpan();
409-
auto d_columns_ptr_out = columns_ptr_b_.ConstDeviceSpan();
410-
auto out = dh::ToSpan(this->Other());
414+
auto d_columns_ptr_out = columns_ptr_tmp_.ConstDeviceSpan();
415+
auto out = dh::ToSpan(this->Scratch());
411416
auto in = dh::ToSpan(this->Current());
412417
auto ft = this->feature_types_.ConstDeviceSpan();
413418
dh::device_vector<size_t> selected_idx(out.size());
414419
auto d_selected_idx = dh::ToSpan(selected_idx);
415-
HostDeviceVector<OffsetT> selected_columns_ptr(columns_ptr_b_.Size());
420+
HostDeviceVector<OffsetT> selected_columns_ptr(columns_ptr_tmp_.Size());
416421
selected_columns_ptr.SetDevice(ctx->Device());
417422
auto entry_from_index = [=] __device__(size_t abs_idx) {
418423
return in[abs_idx];
@@ -427,8 +432,8 @@ void SketchContainer::Prune(Context const *ctx, std::size_t to) {
427432
d_selected_idx.data(), thrust::equal_to<size_t>{});
428433
GatherPruneEntries(Span<size_t const>{d_selected_idx.data(), n_selected}, out, entry_from_index,
429434
stream);
435+
this->entries_.swap(this->entries_tmp_);
430436
this->columns_ptr_.Copy(selected_columns_ptr);
431-
this->Alternate();
432437
this->Current().resize(n_selected);
433438
auto d_column_scan = this->columns_ptr_.DeviceSpan();
434439
HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
@@ -489,21 +494,20 @@ void SketchContainer::Merge(Context const *ctx, Span<OffsetT const> d_that_colum
489494

490495
std::size_t new_size = this->Current().size() + that.size();
491496
try {
492-
this->Other().resize(new_size);
497+
this->Scratch().resize(new_size);
493498
} catch (dmlc::Error const &) {
494499
// Retry
495-
this->Other().clear();
496-
this->Other().shrink_to_fit();
497-
this->Other().resize(new_size);
500+
this->Scratch().clear();
501+
this->Scratch().shrink_to_fit();
502+
this->Scratch().resize(new_size);
498503
}
499504

500505
CHECK_EQ(d_that_columns_ptr.size(), this->columns_ptr_.Size());
501506

502507
MergeImpl(ctx, this->Data(), this->ColumnsPtr(), that, d_that_columns_ptr,
503-
dh::ToSpan(this->Other()), columns_ptr_b_.DeviceSpan());
504-
this->columns_ptr_.Copy(columns_ptr_b_);
508+
dh::ToSpan(this->Scratch()), columns_ptr_tmp_.DeviceSpan());
509+
this->CommitScratch(new_size);
505510
CHECK_EQ(this->columns_ptr_.Size(), num_columns_ + 1);
506-
this->Alternate();
507511
normalize_merged();
508512
timer_.Stop(__func__);
509513
}

src/common/quantile.cuh

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -49,38 +49,28 @@ class SketchContainer {
4949
bst_feature_t num_columns_;
5050
int32_t num_bins_;
5151

52-
// Double buffer as neither prune nor merge can be performed inplace.
53-
dh::device_vector<SketchEntry> entries_a_;
54-
dh::device_vector<SketchEntry> entries_b_;
55-
bool current_buffer_ {true};
56-
// The container is just a CSC matrix.
52+
// The container is just a CSC matrix plus scratch storage for out-of-place transforms.
53+
dh::device_vector<SketchEntry> entries_;
54+
dh::device_vector<SketchEntry> entries_tmp_;
5755
HostDeviceVector<OffsetT> columns_ptr_;
58-
HostDeviceVector<OffsetT> columns_ptr_b_;
56+
HostDeviceVector<OffsetT> columns_ptr_tmp_;
5957

6058
bool has_categorical_{false};
6159

62-
dh::device_vector<SketchEntry>& Current() {
63-
if (current_buffer_) {
64-
return entries_a_;
65-
} else {
66-
return entries_b_;
67-
}
68-
}
69-
dh::device_vector<SketchEntry>& Other() {
70-
if (!current_buffer_) {
71-
return entries_a_;
72-
} else {
73-
return entries_b_;
74-
}
75-
}
76-
dh::device_vector<SketchEntry> const& Current() const {
77-
return const_cast<SketchContainer*>(this)->Current();
78-
}
79-
dh::device_vector<SketchEntry> const& Other() const {
80-
return const_cast<SketchContainer*>(this)->Other();
81-
}
82-
void Alternate() {
83-
current_buffer_ = !current_buffer_;
60+
// Mutable access to entries_, the buffer holding the current sketch entries.
dh::device_vector<SketchEntry>& Current() { return entries_; }
61+
// Mutable access to entries_tmp_, the scratch buffer that Prune and Merge resize
// and use as the output of their out-of-place transforms.
dh::device_vector<SketchEntry>& Scratch() { return entries_tmp_; }
62+
// Read-only access to entries_ (const overload of Current()).
dh::device_vector<SketchEntry> const& Current() const { return entries_; }
63+
// Read-only access to entries_tmp_ (const overload of Scratch()).
dh::device_vector<SketchEntry> const& Scratch() const { return entries_tmp_; }
64+
// Read-only span over columns_ptr_ obtained through ConstHostSpan().
Span<OffsetT const> CurrentColumnsHost() const { return columns_ptr_.ConstHostSpan(); }
65+
// Mutable host vector backing columns_ptr_tmp_ (the scratch column pointers).
std::vector<OffsetT>& ScratchColumnsHost() { return columns_ptr_tmp_.HostVector(); }
66+
// Read-only span over columns_ptr_ obtained through ConstDeviceSpan().
Span<OffsetT const> CurrentColumns() const { return columns_ptr_.ConstDeviceSpan(); }
67+
// Writable span over columns_ptr_tmp_ obtained through DeviceSpan(); ScanInput
// uses it as the output column-pointer buffer.
Span<OffsetT> ScratchColumns() { return columns_ptr_tmp_.DeviceSpan(); }
68+
// Read-only span over columns_ptr_tmp_ obtained through ConstDeviceSpan().
Span<OffsetT const> ScratchColumns() const { return columns_ptr_tmp_.ConstDeviceSpan(); }
69+
void SetCurrentColumns(Span<OffsetT const> columns_ptr);
70+
// Promote the scratch result to be the current sketch:
//   1. swap entries_ with entries_tmp_, so the previous entries end up in entries_tmp_;
//   2. copy columns_ptr_tmp_ into columns_ptr_;
//   3. resize entries_ to `n_entries`.
// Callers are expected to have written both the entries and the column pointers
// into the scratch buffers beforehand (Merge does this via MergeImpl).
void CommitScratch(std::size_t n_entries) {
71+
entries_.swap(entries_tmp_);
72+
columns_ptr_.Copy(columns_ptr_tmp_);
73+
entries_.resize(n_entries);
8474
}
8575

8676
// Get the span of one column.
@@ -105,8 +95,8 @@ class SketchContainer {
10595
// Initialize Sketches for this dmatrix
10696
this->columns_ptr_.SetDevice(device);
10797
this->columns_ptr_.Resize(num_columns + 1, 0);
108-
this->columns_ptr_b_.SetDevice(device);
109-
this->columns_ptr_b_.Resize(num_columns + 1, 0);
98+
this->columns_ptr_tmp_.SetDevice(device);
99+
this->columns_ptr_tmp_.Resize(num_columns + 1, 0);
110100

111101
this->feature_types_.Resize(feature_types.Size());
112102
this->feature_types_.Copy(feature_types);
@@ -127,17 +117,17 @@ class SketchContainer {
127117
* @brief Calculate the memory cost of the container.
128118
*/
129119
// Returns a byte count built from three terms:
//   (capacity() of both entry buffers) * sizeof(one entry)
// + (Size() of both column-pointer vectors) * sizeof(OffsetT)
// + feature_types_.Size() * sizeof(FeatureType).
// Diff view note: the three lines that mention entries_a_/entries_b_/columns_ptr_b_
// are the removed ("-") versions of the three added ("+") lines that follow them.
[[nodiscard]] std::size_t MemCapacityBytes() const {
130-
auto constexpr kE = sizeof(typename decltype(this->entries_a_)::value_type);
131-
auto n_bytes = (this->entries_a_.capacity() + this->entries_b_.capacity()) * kE;
132-
n_bytes += (this->columns_ptr_.Size() + this->columns_ptr_b_.Size()) * sizeof(OffsetT);
120+
auto constexpr kE = sizeof(typename decltype(this->entries_)::value_type);
121+
auto n_bytes = (this->entries_.capacity() + this->entries_tmp_.capacity()) * kE;
122+
n_bytes += (this->columns_ptr_.Size() + this->columns_ptr_tmp_.Size()) * sizeof(OffsetT);
133123
n_bytes += this->feature_types_.Size() * sizeof(FeatureType);
134124

135125
return n_bytes;
136126
}
137127
[[nodiscard]] std::size_t MemCostBytes() const {
138-
auto constexpr kE = sizeof(typename decltype(this->entries_a_)::value_type);
139-
auto n_bytes = (this->entries_a_.size() + this->entries_b_.size()) * kE;
140-
n_bytes += (this->columns_ptr_.Size() + this->columns_ptr_b_.Size()) * sizeof(OffsetT);
128+
auto constexpr kE = sizeof(typename decltype(this->entries_)::value_type);
129+
auto n_bytes = (this->entries_.size() + this->entries_tmp_.size()) * kE;
130+
n_bytes += (this->columns_ptr_.Size() + this->columns_ptr_tmp_.Size()) * sizeof(OffsetT);
141131
n_bytes += this->feature_types_.Size() * sizeof(FeatureType);
142132

143133
return n_bytes;
@@ -181,8 +171,8 @@ class SketchContainer {
181171
*/
182172
// Calls shrink_to_fit() on the current entry buffer, clears the secondary buffer
// and calls shrink_to_fit() on it as well, then logs MemCapacityBytes() at DEBUG level.
// Diff view note: the Other() lines are the removed ("-") versions of the
// Scratch() lines ("+") that follow them.
void ShrinkToFit() {
183173
this->Current().shrink_to_fit();
184-
this->Other().clear();
185-
this->Other().shrink_to_fit();
174+
this->Scratch().clear();
175+
this->Scratch().shrink_to_fit();
186176
LOG(DEBUG) << "Quantile memory cost:" << common::HumanMemUnit(this->MemCapacityBytes());
187177
}
188178

0 commit comments

Comments
 (0)