@@ -278,7 +278,8 @@ inline HistogramCuts DeviceSketch(
278278template <typename AdapterBatch>
279279void ProcessSlidingWindow (Context const * ctx, AdapterBatch const & batch, MetaInfo const & info,
280280 size_t n_features, size_t begin, size_t end, float missing,
281- SketchContainer* sketch_container, int num_cuts) {
281+ SketchContainer* sketch_container, int num_cuts,
282+ bst_idx_t approx_n_samples) {
282283 // Copy current subset of valid elements into temporary storage and sort
283284 dh::device_vector<Entry> sorted_entries;
284285 dh::caching_device_vector<size_t > column_sizes_scan;
@@ -303,7 +304,7 @@ void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInf
303304 auto const & h_cuts_ptr = cuts_ptr.HostVector ();
304305 // Extract the cuts from all columns concurrently
305306 sketch_container->Push (ctx, dh::ToSpan (sorted_entries), dh::ToSpan (column_sizes_scan), d_cuts_ptr,
306- h_cuts_ptr.back ());
307+ h_cuts_ptr.back (), approx_n_samples );
307308
308309 sorted_entries.clear ();
309310 sorted_entries.shrink_to_fit ();
@@ -313,7 +314,7 @@ template <typename Batch>
313314void ProcessWeightedSlidingWindow (Context const * ctx, Batch batch, MetaInfo const & info,
314315 int num_cuts_per_feature, bool is_ranking, float missing,
315316 size_t columns, size_t begin, size_t end,
316- SketchContainer* sketch_container) {
317+ SketchContainer* sketch_container, bst_idx_t approx_n_samples ) {
317318 curt::SetDevice (ctx->Ordinal ());
318319 info.weights_ .SetDevice (ctx->Device ());
319320 auto weights = info.weights_ .ConstDeviceSpan ();
@@ -379,7 +380,7 @@ void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo cons
379380
380381 // Extract cuts
381382 sketch_container->Push (ctx, dh::ToSpan (sorted_entries), dh::ToSpan (column_sizes_scan), d_cuts_ptr,
382- h_cuts_ptr.back (), dh::ToSpan (temp_weights));
383+ h_cuts_ptr.back (), approx_n_samples, dh::ToSpan (temp_weights));
383384 sorted_entries.clear ();
384385 sorted_entries.shrink_to_fit ();
385386}
@@ -431,10 +432,10 @@ void AdapterDeviceSketch(Context const* ctx, Batch batch, bst_bin_t num_bins, Me
431432 if (weighted) {
432433 ProcessWeightedSlidingWindow (ctx, batch, info, num_cuts_per_feature,
433434 HostSketchContainer::UseGroup (info), missing, num_cols, begin,
434- end, sketch_container);
435+ end, sketch_container, approx_n_samples );
435436 } else {
436437 ProcessSlidingWindow (ctx, batch, info, num_cols, begin, end, missing, sketch_container,
437- num_cuts_per_feature);
438+ num_cuts_per_feature, approx_n_samples );
438439 }
439440 begin += sketch_batch_num_elements;
440441 }
0 commit comments