@@ -304,18 +304,8 @@ void MergeImpl(Context const *ctx, Span<SketchEntry const> const &d_x,
304304void SketchContainer::Push (Context const *ctx, Span<Entry const > entries, Span<size_t > columns_ptr,
305305 common::Span<OffsetT> cuts_ptr, size_t total_cuts, Span<float > weights) {
306306 curt::SetDevice (ctx->Ordinal ());
307- auto ¤t = this ->entries_ ;
308- auto &columns_ptr_out = this ->columns_ptr_ ;
309- Span<SketchEntry> out;
310- dh::device_vector<SketchEntry> cuts;
311- bool first_window = current.empty ();
312- if (!first_window) {
313- cuts.resize (total_cuts);
314- out = dh::ToSpan (cuts);
315- } else {
316- current.resize (total_cuts);
317- out = dh::ToSpan (current);
318- }
307+ dh::device_vector<SketchEntry> cuts (total_cuts);
308+ auto out = dh::ToSpan (cuts);
319309 auto ft = this ->feature_types_ .ConstDeviceSpan ();
320310 if (weights.empty ()) {
321311 auto to_sketch_entry = [] __device__ (size_t sample_idx, Span<Entry const > const &column,
@@ -340,19 +330,10 @@ void SketchContainer::Push(Context const *ctx, Span<Entry const> entries, Span<s
340330 PruneImpl<Entry>(cuts_ptr, entries, columns_ptr, ft, out, to_sketch_entry);
341331 }
342332 auto n_uniques = this ->ScanInput (ctx, out, cuts_ptr);
343-
344- if (!first_window) {
345- CHECK_EQ (columns_ptr_out.Size (), cuts_ptr.size ());
346- out = out.subspan (0 , n_uniques);
347- this ->Merge (ctx, cuts_ptr, out);
348- } else {
349- current.resize (n_uniques);
350- columns_ptr_out.SetDevice (ctx->Device ());
351- columns_ptr_out.Resize (cuts_ptr.size ());
352-
353- auto d_cuts_ptr = columns_ptr_out.DeviceSpan ();
354- CopyTo (d_cuts_ptr, cuts_ptr);
355- }
333+ CHECK_EQ (this ->columns_ptr_ .Size (), cuts_ptr.size ());
334+ this ->Merge (ctx, cuts_ptr, out.subspan (0 , n_uniques));
335+ auto intermediate_num_cuts = static_cast <bst_idx_t >(num_bins_ * kFactor );
336+ this ->Prune (ctx, intermediate_num_cuts);
356337}
357338
358339size_t SketchContainer::ScanInput (Context const *ctx, Span<SketchEntry> entries,
0 commit comments