Skip to content

Commit d1f0040

Browse files
committed
Deduplicate merged GPU sketch entries
1 parent b2f15e6 commit d1f0040

File tree

2 files changed

+113
-20
lines changed

2 files changed

+113
-20
lines changed

src/common/quantile.cu

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -498,24 +498,17 @@ void SketchContainer::Merge(Context const *ctx, Span<OffsetT const> d_that_colum
498498

499499
timer_.Start(__func__);
500500
auto normalize_merged = [&] {
501-
if (this->HasCategorical()) {
502-
// Numerical summaries are normalized during prune. Categorical features can still
503-
// produce repeated category values, so compact those here before exposing the sketch.
504-
auto d_feature_types = this->FeatureTypes().ConstDeviceSpan();
505-
auto d_column_scan = columns_ptr.DeviceSpan();
506-
auto merged_entries = dh::ToSpan(entries);
507-
HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
508-
scan_out.SetDevice(ctx->Device());
509-
auto n_uniques = dh::SegmentedUnique(
510-
ctx->CUDACtx()->CTP(), d_column_scan.data(), d_column_scan.data() + d_column_scan.size(),
511-
merged_entries.data(), merged_entries.data() + merged_entries.size(),
512-
scan_out.DevicePointer(), merged_entries.data(), detail::SketchUnique{},
513-
[d_feature_types] __device__(size_t l_fidx, size_t r_fidx) {
514-
return l_fidx == r_fidx && IsCat(d_feature_types, l_fidx);
515-
});
516-
columns_ptr.Copy(scan_out);
517-
entries.resize(n_uniques);
518-
}
501+
// Merge can leave adjacent duplicate values in both numerical and categorical summaries.
502+
auto d_column_scan = columns_ptr.DeviceSpan();
503+
auto merged_entries = dh::ToSpan(entries);
504+
HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
505+
scan_out.SetDevice(ctx->Device());
506+
auto n_uniques = dh::SegmentedUnique(
507+
ctx->CUDACtx()->CTP(), d_column_scan.data(), d_column_scan.data() + d_column_scan.size(),
508+
merged_entries.data(), merged_entries.data() + merged_entries.size(),
509+
scan_out.DevicePointer(), merged_entries.data(), detail::SketchUnique{});
510+
columns_ptr.Copy(scan_out);
511+
entries.resize(n_uniques);
519512
this->FixError();
520513
};
521514
if (entries.empty()) {

tests/cpp/common/test_quantile.cu

Lines changed: 102 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,20 @@ auto MakeFullRowSplitDMatrix(std::size_t rows_per_worker, std::size_t cols, std:
5050
}
5151
return GetDMatrixFromData(full_data, rows_per_worker * world, cols);
5252
}
53+
54+
auto MakeHostSummary(std::vector<std::pair<float, float>> const& items)
55+
-> common::WQSummaryContainer {
56+
common::WQSummaryContainer summary;
57+
summary.Reserve(items.size());
58+
summary.SetFromSorted(items);
59+
return summary;
60+
}
61+
62+
auto CopySummaryEntries(common::WQSummaryContainer const& summary)
63+
-> std::vector<common::SketchEntry> {
64+
auto entries = summary.Entries();
65+
return {entries.cbegin(), entries.cend()};
66+
}
5367
} // namespace
5468

5569
namespace common {
@@ -251,14 +265,18 @@ TEST(GPUQuantile, MergeBasic) {
251265
auto columns_ptr = sketch_0.ColumnsPtr();
252266
std::vector<bst_idx_t> h_columns_ptr(columns_ptr.size());
253267
dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
254-
ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge);
268+
ASSERT_LE(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge);
255269

256270
std::vector<SketchEntry> h_data(sketch_0.Data().size());
257271
dh::CopyDeviceSpanToVector(&h_data, sketch_0.Data());
258272
for (size_t i = 1; i < h_columns_ptr.size(); ++i) {
259273
auto begin = h_columns_ptr[i - 1];
260274
auto column = Span<SketchEntry>{h_data}.subspan(begin, h_columns_ptr[i] - begin);
261275
ASSERT_TRUE(std::is_sorted(column.begin(), column.end(), IsSorted{}));
276+
ASSERT_TRUE(std::adjacent_find(column.begin(), column.end(),
277+
[](SketchEntry const& l, SketchEntry const& r) {
278+
return l.value == r.value;
279+
}) == column.end());
262280
}
263281
});
264282
}
@@ -309,14 +327,18 @@ void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) {
309327
auto columns_ptr = sketch_0.ColumnsPtr();
310328
std::vector<bst_idx_t> h_columns_ptr(columns_ptr.size());
311329
dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
312-
ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge);
330+
ASSERT_LE(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge);
313331

314332
std::vector<SketchEntry> h_data(sketch_0.Data().size());
315333
dh::CopyDeviceSpanToVector(&h_data, sketch_0.Data());
316334
for (size_t i = 1; i < h_columns_ptr.size(); ++i) {
317335
auto begin = h_columns_ptr[i - 1];
318336
auto column = Span<SketchEntry>{h_data}.subspan(begin, h_columns_ptr[i] - begin);
319337
ASSERT_TRUE(std::is_sorted(column.begin(), column.end(), IsSorted{}));
338+
ASSERT_TRUE(std::adjacent_find(column.begin(), column.end(),
339+
[](SketchEntry const& l, SketchEntry const& r) {
340+
return l.value == r.value;
341+
}) == column.end());
320342
}
321343
}
322344

@@ -370,6 +392,84 @@ TEST(GPUQuantile, MergeCategorical) {
370392
}) == cat_column.end());
371393
}
372394

395+
TEST(GPUQuantile, MergeSameValue) {
396+
auto ctx = MakeCUDACtx(0);
397+
constexpr bst_feature_t kCols = 1;
398+
bst_bin_t n_bins = 16;
399+
400+
HostDeviceVector<FeatureType> ft;
401+
SketchContainer sketch_0(ft, n_bins, kCols, ctx.Device());
402+
SketchContainer sketch_1(ft, n_bins, kCols, ctx.Device());
403+
404+
std::vector<Entry> entries_0{{0, 0.5f}};
405+
std::vector<Entry> entries_1{{0, 0.5f}};
406+
dh::device_vector<Entry> d_entries_0{entries_0};
407+
dh::device_vector<Entry> d_entries_1{entries_1};
408+
dh::device_vector<size_t> columns_ptr{0, 1};
409+
dh::device_vector<size_t> cuts_ptr{0, 1};
410+
411+
sketch_0.Push(&ctx, dh::ToSpan(d_entries_0), dh::ToSpan(columns_ptr), dh::ToSpan(cuts_ptr), 1, 1,
412+
{});
413+
sketch_1.Push(&ctx, dh::ToSpan(d_entries_1), dh::ToSpan(columns_ptr), dh::ToSpan(cuts_ptr), 1, 1,
414+
{});
415+
416+
sketch_0.Merge(&ctx, sketch_1.ColumnsPtr(), sketch_1.Data());
417+
418+
std::vector<bst_idx_t> h_columns_ptr(sketch_0.ColumnsPtr().size());
419+
dh::CopyDeviceSpanToVector(&h_columns_ptr, sketch_0.ColumnsPtr());
420+
std::vector<SketchEntry> h_data(sketch_0.Data().size());
421+
dh::CopyDeviceSpanToVector(&h_data, sketch_0.Data());
422+
423+
ASSERT_EQ(h_columns_ptr.back(), 1);
424+
ASSERT_EQ(h_data.size(), 1);
425+
EXPECT_FLOAT_EQ(h_data.front().value, 0.5f);
426+
EXPECT_FLOAT_EQ(h_data.front().rmin, 0.0f);
427+
EXPECT_FLOAT_EQ(h_data.front().wmin, 2.0f);
428+
EXPECT_FLOAT_EQ(h_data.front().rmax, 2.0f);
429+
}
430+
431+
TEST(GPUQuantile, MergeMatchesCpuCombine) {
432+
auto ctx = MakeCUDACtx(0);
433+
constexpr bst_feature_t kCols = 1;
434+
bst_bin_t n_bins = 16;
435+
436+
auto lhs = MakeHostSummary({{0.1f, 1.0f}, {0.3f, 2.0f}, {0.5f, 1.0f}});
437+
auto rhs = MakeHostSummary({{0.3f, 1.5f}, {0.4f, 1.0f}, {0.5f, 0.5f}});
438+
439+
common::WQSummaryContainer expected;
440+
expected.Reserve(lhs.Size() + rhs.Size());
441+
expected.CopyFrom(lhs);
442+
expected.SetCombine(rhs);
443+
444+
auto lhs_entries = CopySummaryEntries(lhs);
445+
auto rhs_entries = CopySummaryEntries(rhs);
446+
447+
dh::device_vector<SketchEntry> d_lhs{lhs_entries};
448+
dh::device_vector<SketchEntry> d_rhs{rhs_entries};
449+
dh::device_vector<size_t> lhs_ptr{0, lhs.Size()};
450+
dh::device_vector<size_t> rhs_ptr{0, rhs.Size()};
451+
452+
HostDeviceVector<FeatureType> ft;
453+
SketchContainer sketch(ft, n_bins, kCols, ctx.Device());
454+
sketch.Merge(&ctx, dh::ToSpan(lhs_ptr), dh::ToSpan(d_lhs));
455+
sketch.Merge(&ctx, dh::ToSpan(rhs_ptr), dh::ToSpan(d_rhs));
456+
457+
std::vector<bst_idx_t> h_columns_ptr(sketch.ColumnsPtr().size());
458+
dh::CopyDeviceSpanToVector(&h_columns_ptr, sketch.ColumnsPtr());
459+
auto h_data = std::vector<SketchEntry>(sketch.Data().size());
460+
dh::CopyDeviceSpanToVector(&h_data, sketch.Data());
461+
462+
ASSERT_EQ(h_columns_ptr.back(), expected.Size());
463+
auto expected_entries = expected.Entries();
464+
ASSERT_EQ(h_data.size(), expected_entries.size());
465+
for (std::size_t i = 0; i < h_data.size(); ++i) {
466+
EXPECT_FLOAT_EQ(h_data[i].value, expected_entries[i].value);
467+
EXPECT_FLOAT_EQ(h_data[i].rmin, expected_entries[i].rmin);
468+
EXPECT_FLOAT_EQ(h_data[i].rmax, expected_entries[i].rmax);
469+
EXPECT_FLOAT_EQ(h_data[i].wmin, expected_entries[i].wmin);
470+
}
471+
}
472+
373473
TEST(GPUQuantile, MultiMerge) {
374474
constexpr size_t kRows = 20, kCols = 1;
375475
int32_t world = 2;

0 commit comments

Comments
 (0)