@@ -685,23 +685,18 @@ class DeviceSketchWithHessianTest
685685 HostDeviceVector<float > const & hessian, std::vector<float > const & w,
686686 std::size_t n_elements) const {
687687 auto const & h_hess = hessian.ConstHostVector ();
688- {
689- auto & h_weight = p_fmat->Info ().weights_ .HostVector ();
690- h_weight = w;
691- }
688+ auto & h_weight = p_fmat->Info ().weights_ .HostVector ();
689+ h_weight = w;
692690
693691 HistogramCuts cuts_hess =
694692 DeviceSketchWithHessian (ctx, p_fmat.get (), n_bins, hessian.ConstDeviceSpan (), n_elements);
695- ValidateCuts (cuts_hess, p_fmat.get (), n_bins, kMaxWeightedNormalizedRankError );
696693
697694 // merge hessian
698- {
699- auto & h_weight = p_fmat->Info ().weights_ .HostVector ();
700- ASSERT_EQ (h_weight.size (), h_hess.size ());
701- for (std::size_t i = 0 ; i < h_weight.size (); ++i) {
702- h_weight[i] = w[i] * h_hess[i];
703- }
695+ ASSERT_EQ (h_weight.size (), h_hess.size ());
696+ for (std::size_t i = 0 ; i < h_weight.size (); ++i) {
697+ h_weight[i] = w[i] * h_hess[i];
704698 }
699+ ValidateCuts (cuts_hess, p_fmat.get (), n_bins, kMaxWeightedNormalizedRankError );
705700
706701 HistogramCuts cuts_wh = DeviceSketch (ctx, p_fmat.get (), n_bins, n_elements);
707702 ValidateCuts (cuts_wh, p_fmat.get (), n_bins, kMaxWeightedNormalizedRankError );
@@ -750,7 +745,11 @@ class DeviceSketchWithHessianTest
750745 cuts_hess =
751746 DeviceSketchWithHessian (ctx, p_fmat.get (), n_bins, hessian.ConstDeviceSpan (), n_elements);
752747 // make validation easier by converting it into sample weight.
753- p_fmat->Info ().weights_ .HostVector () = h_hess;
748+ p_fmat->Info ().weights_ .Resize (n_samples);
749+ for (std::size_t i = 0 ; i < h_hess.size (); ++i) {
750+ auto gidx = dh::SegmentId (Span{gptr.data (), gptr.size ()}, i);
751+ p_fmat->Info ().weights_ .HostVector ()[i] = w[gidx] * h_hess[i];
752+ }
754753 p_fmat->Info ().group_ptr_ .clear ();
755754 ValidateCuts (cuts_hess, p_fmat.get (), n_bins, kMaxWeightedNormalizedRankError );
756755
0 commit comments