@@ -183,6 +183,11 @@ XGBOOST_DEVICE thrust::tuple<uint64_t, uint64_t> MergePartition(Span<SketchEntry
183183 return thrust::make_tuple (a_ind, k - a_ind);
184184}
185185
186+ void SketchContainer::SetCurrentColumns (Span<OffsetT const > columns_ptr) {  // Overwrite the stored column pointers (columns_ptr_) with a copy of `columns_ptr`.
187+ columns_ptr_.Resize (columns_ptr.size ());  // Size the member to the input's element count before copying into it.
188+ CopyTo (columns_ptr_.DeviceSpan (), columns_ptr);  // `columns_ptr` is a span of const, so it is the source; the span obtained from columns_ptr_ is the destination.
189+ }
190+
186191// Merge d_x and d_y into out. Because the final output depends on predicate (which
187192// summary does the output element come from) result by definition of merged rank. So we
188193// compute the partition for each output directly and customize the standard merge
@@ -373,7 +378,7 @@ size_t SketchContainer::ScanInput(Context const *ctx, Span<SketchEntry> entries,
373378 return l;
374379 });
375380
376- auto d_columns_ptr_out = columns_ptr_b_. DeviceSpan ();
381+ auto d_columns_ptr_out = this -> ScratchColumns ();
377382 // thrust unique_by_key preserves the first element.
378383 auto n_uniques =
379384 dh::SegmentedUnique (ctx->CUDACtx ()->CTP (), d_columns_ptr_in.data (),
@@ -391,7 +396,7 @@ void SketchContainer::Prune(Context const *ctx, std::size_t to) {
391396 curt::SetDevice (ctx->Ordinal ());
392397
393398 OffsetT to_total = 0 ;
394- auto &h_columns_ptr = columns_ptr_b_ .HostVector ();
399+ auto &h_columns_ptr = columns_ptr_tmp_ .HostVector ();
395400 h_columns_ptr[0 ] = to_total;
396401 auto const &h_feature_types = feature_types_.ConstHostSpan ();
397402 for (bst_feature_t i = 0 ; i < num_columns_; ++i) {
@@ -403,16 +408,16 @@ void SketchContainer::Prune(Context const *ctx, std::size_t to) {
403408 to_total += length;
404409 h_columns_ptr[i + 1 ] = to_total;
405410 }
406- this ->Other ().resize (to_total);
411+ this ->Scratch ().resize (to_total);
407412
408413 auto d_columns_ptr_in = this ->columns_ptr_ .ConstDeviceSpan ();
409- auto d_columns_ptr_out = columns_ptr_b_ .ConstDeviceSpan ();
410- auto out = dh::ToSpan (this ->Other ());
414+ auto d_columns_ptr_out = columns_ptr_tmp_ .ConstDeviceSpan ();
415+ auto out = dh::ToSpan (this ->Scratch ());
411416 auto in = dh::ToSpan (this ->Current ());
412417 auto ft = this ->feature_types_ .ConstDeviceSpan ();
413418 dh::device_vector<size_t > selected_idx (out.size ());
414419 auto d_selected_idx = dh::ToSpan (selected_idx);
415- HostDeviceVector<OffsetT> selected_columns_ptr (columns_ptr_b_ .Size ());
420+ HostDeviceVector<OffsetT> selected_columns_ptr (columns_ptr_tmp_ .Size ());
416421 selected_columns_ptr.SetDevice (ctx->Device ());
417422 auto entry_from_index = [=] __device__ (size_t abs_idx) {
418423 return in[abs_idx];
@@ -427,8 +432,8 @@ void SketchContainer::Prune(Context const *ctx, std::size_t to) {
427432 d_selected_idx.data (), thrust::equal_to<size_t >{});
428433 GatherPruneEntries (Span<size_t const >{d_selected_idx.data (), n_selected}, out, entry_from_index,
429434 stream);
435+ this ->entries_ .swap (this ->entries_tmp_ );
430436 this ->columns_ptr_ .Copy (selected_columns_ptr);
431- this ->Alternate ();
432437 this ->Current ().resize (n_selected);
433438 auto d_column_scan = this ->columns_ptr_ .DeviceSpan ();
434439 HostDeviceVector<OffsetT> scan_out (d_column_scan.size ());
@@ -489,21 +494,20 @@ void SketchContainer::Merge(Context const *ctx, Span<OffsetT const> d_that_colum
489494
490495 std::size_t new_size = this ->Current ().size () + that.size ();
491496 try {
492- this ->Other ().resize (new_size);
497+ this ->Scratch ().resize (new_size);
493498 } catch (dmlc::Error const &) {
494499 // Retry
495- this ->Other ().clear ();
496- this ->Other ().shrink_to_fit ();
497- this ->Other ().resize (new_size);
500+ this ->Scratch ().clear ();
501+ this ->Scratch ().shrink_to_fit ();
502+ this ->Scratch ().resize (new_size);
498503 }
499504
500505 CHECK_EQ (d_that_columns_ptr.size (), this ->columns_ptr_ .Size ());
501506
502507 MergeImpl (ctx, this ->Data (), this ->ColumnsPtr (), that, d_that_columns_ptr,
503- dh::ToSpan (this ->Other ()), columns_ptr_b_ .DeviceSpan ());
504- this ->columns_ptr_ . Copy (columns_ptr_b_ );
508+ dh::ToSpan (this ->Scratch ()), columns_ptr_tmp_ .DeviceSpan ());
509+ this ->CommitScratch (new_size );
505510 CHECK_EQ (this ->columns_ptr_ .Size (), num_columns_ + 1 );
506- this ->Alternate ();
507511 normalize_merged ();
508512 timer_.Stop (__func__);
509513}
0 commit comments