@@ -50,6 +50,20 @@ auto MakeFullRowSplitDMatrix(std::size_t rows_per_worker, std::size_t cols, std:
5050 }
5151 return GetDMatrixFromData (full_data, rows_per_worker * world, cols);
5252}
53+
54+ auto MakeHostSummary (std::vector<std::pair<float , float >> const & items)
55+ -> common::WQSummaryContainer {
56+ common::WQSummaryContainer summary;
57+ summary.Reserve (items.size ());
58+ summary.SetFromSorted (items);
59+ return summary;
60+ }
61+
62+ auto CopySummaryEntries (common::WQSummaryContainer const & summary)
63+ -> std::vector<common::SketchEntry> {
64+ auto entries = summary.Entries ();
65+ return {entries.cbegin (), entries.cend ()};
66+ }
5367} // namespace
5468
5569namespace common {
@@ -251,14 +265,18 @@ TEST(GPUQuantile, MergeBasic) {
251265 auto columns_ptr = sketch_0.ColumnsPtr ();
252266 std::vector<bst_idx_t > h_columns_ptr (columns_ptr.size ());
253267 dh::CopyDeviceSpanToVector (&h_columns_ptr, columns_ptr);
254- ASSERT_EQ (h_columns_ptr.back (), sketch_1.Data ().size () + size_before_merge);
268+ ASSERT_LE (h_columns_ptr.back (), sketch_1.Data ().size () + size_before_merge);
255269
256270 std::vector<SketchEntry> h_data (sketch_0.Data ().size ());
257271 dh::CopyDeviceSpanToVector (&h_data, sketch_0.Data ());
258272 for (size_t i = 1 ; i < h_columns_ptr.size (); ++i) {
259273 auto begin = h_columns_ptr[i - 1 ];
260274 auto column = Span<SketchEntry>{h_data}.subspan (begin, h_columns_ptr[i] - begin);
261275 ASSERT_TRUE (std::is_sorted (column.begin (), column.end (), IsSorted{}));
276+ ASSERT_TRUE (std::adjacent_find (column.begin (), column.end (),
277+ [](SketchEntry const & l, SketchEntry const & r) {
278+ return l.value == r.value ;
279+ }) == column.end ());
262280 }
263281 });
264282}
@@ -309,14 +327,18 @@ void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) {
309327 auto columns_ptr = sketch_0.ColumnsPtr ();
310328 std::vector<bst_idx_t > h_columns_ptr (columns_ptr.size ());
311329 dh::CopyDeviceSpanToVector (&h_columns_ptr, columns_ptr);
312- ASSERT_EQ (h_columns_ptr.back (), sketch_1.Data ().size () + size_before_merge);
330+ ASSERT_LE (h_columns_ptr.back (), sketch_1.Data ().size () + size_before_merge);
313331
314332 std::vector<SketchEntry> h_data (sketch_0.Data ().size ());
315333 dh::CopyDeviceSpanToVector (&h_data, sketch_0.Data ());
316334 for (size_t i = 1 ; i < h_columns_ptr.size (); ++i) {
317335 auto begin = h_columns_ptr[i - 1 ];
318336 auto column = Span<SketchEntry>{h_data}.subspan (begin, h_columns_ptr[i] - begin);
319337 ASSERT_TRUE (std::is_sorted (column.begin (), column.end (), IsSorted{}));
338+ ASSERT_TRUE (std::adjacent_find (column.begin (), column.end (),
339+ [](SketchEntry const & l, SketchEntry const & r) {
340+ return l.value == r.value ;
341+ }) == column.end ());
320342 }
321343}
322344
@@ -370,6 +392,84 @@ TEST(GPUQuantile, MergeCategorical) {
370392 }) == cat_column.end ());
371393}
372394
395+ TEST (GPUQuantile, MergeSameValue) {
396+ auto ctx = MakeCUDACtx (0 );
397+ constexpr bst_feature_t kCols = 1 ;
398+ bst_bin_t n_bins = 16 ;
399+
400+ HostDeviceVector<FeatureType> ft;
401+ SketchContainer sketch_0 (ft, n_bins, kCols , ctx.Device ());
402+ SketchContainer sketch_1 (ft, n_bins, kCols , ctx.Device ());
403+
404+ std::vector<Entry> entries_0{{0 , 0 .5f }};
405+ std::vector<Entry> entries_1{{0 , 0 .5f }};
406+ dh::device_vector<Entry> d_entries_0{entries_0};
407+ dh::device_vector<Entry> d_entries_1{entries_1};
408+ dh::device_vector<size_t > columns_ptr{0 , 1 };
409+ dh::device_vector<size_t > cuts_ptr{0 , 1 };
410+
411+ sketch_0.Push (&ctx, dh::ToSpan (d_entries_0), dh::ToSpan (columns_ptr), dh::ToSpan (cuts_ptr), 1 , 1 ,
412+ {});
413+ sketch_1.Push (&ctx, dh::ToSpan (d_entries_1), dh::ToSpan (columns_ptr), dh::ToSpan (cuts_ptr), 1 , 1 ,
414+ {});
415+
416+ sketch_0.Merge (&ctx, sketch_1.ColumnsPtr (), sketch_1.Data ());
417+
418+ std::vector<bst_idx_t > h_columns_ptr (sketch_0.ColumnsPtr ().size ());
419+ dh::CopyDeviceSpanToVector (&h_columns_ptr, sketch_0.ColumnsPtr ());
420+ std::vector<SketchEntry> h_data (sketch_0.Data ().size ());
421+ dh::CopyDeviceSpanToVector (&h_data, sketch_0.Data ());
422+
423+ ASSERT_EQ (h_columns_ptr.back (), 1 );
424+ ASSERT_EQ (h_data.size (), 1 );
425+ EXPECT_FLOAT_EQ (h_data.front ().value , 0 .5f );
426+ EXPECT_FLOAT_EQ (h_data.front ().rmin , 0 .0f );
427+ EXPECT_FLOAT_EQ (h_data.front ().wmin , 2 .0f );
428+ EXPECT_FLOAT_EQ (h_data.front ().rmax , 2 .0f );
429+ }
430+
431+ TEST (GPUQuantile, MergeMatchesCpuCombine) {
432+ auto ctx = MakeCUDACtx (0 );
433+ constexpr bst_feature_t kCols = 1 ;
434+ bst_bin_t n_bins = 16 ;
435+
436+ auto lhs = MakeHostSummary ({{0 .1f , 1 .0f }, {0 .3f , 2 .0f }, {0 .5f , 1 .0f }});
437+ auto rhs = MakeHostSummary ({{0 .3f , 1 .5f }, {0 .4f , 1 .0f }, {0 .5f , 0 .5f }});
438+
439+ common::WQSummaryContainer expected;
440+ expected.Reserve (lhs.Size () + rhs.Size ());
441+ expected.CopyFrom (lhs);
442+ expected.SetCombine (rhs);
443+
444+ auto lhs_entries = CopySummaryEntries (lhs);
445+ auto rhs_entries = CopySummaryEntries (rhs);
446+
447+ dh::device_vector<SketchEntry> d_lhs{lhs_entries};
448+ dh::device_vector<SketchEntry> d_rhs{rhs_entries};
449+ dh::device_vector<size_t > lhs_ptr{0 , lhs.Size ()};
450+ dh::device_vector<size_t > rhs_ptr{0 , rhs.Size ()};
451+
452+ HostDeviceVector<FeatureType> ft;
453+ SketchContainer sketch (ft, n_bins, kCols , ctx.Device ());
454+ sketch.Merge (&ctx, dh::ToSpan (lhs_ptr), dh::ToSpan (d_lhs));
455+ sketch.Merge (&ctx, dh::ToSpan (rhs_ptr), dh::ToSpan (d_rhs));
456+
457+ std::vector<bst_idx_t > h_columns_ptr (sketch.ColumnsPtr ().size ());
458+ dh::CopyDeviceSpanToVector (&h_columns_ptr, sketch.ColumnsPtr ());
459+ auto h_data = std::vector<SketchEntry>(sketch.Data ().size ());
460+ dh::CopyDeviceSpanToVector (&h_data, sketch.Data ());
461+
462+ ASSERT_EQ (h_columns_ptr.back (), expected.Size ());
463+ auto expected_entries = expected.Entries ();
464+ ASSERT_EQ (h_data.size (), expected_entries.size ());
465+ for (std::size_t i = 0 ; i < h_data.size (); ++i) {
466+ EXPECT_FLOAT_EQ (h_data[i].value , expected_entries[i].value );
467+ EXPECT_FLOAT_EQ (h_data[i].rmin , expected_entries[i].rmin );
468+ EXPECT_FLOAT_EQ (h_data[i].rmax , expected_entries[i].rmax );
469+ EXPECT_FLOAT_EQ (h_data[i].wmin , expected_entries[i].wmin );
470+ }
471+ }
472+
373473TEST (GPUQuantile, MultiMerge) {
374474 constexpr size_t kRows = 20 , kCols = 1 ;
375475 int32_t world = 2 ;
0 commit comments