Skip to content

Commit 23e82ef

Browse files
committed
Fix weighted hessian quantile test references
1 parent b9ee89e commit 23e82ef

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

tests/cpp/common/test_hist_util.cu

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -685,23 +685,18 @@ class DeviceSketchWithHessianTest
685685
HostDeviceVector<float> const& hessian, std::vector<float> const& w,
686686
std::size_t n_elements) const {
687687
auto const& h_hess = hessian.ConstHostVector();
688-
{
689-
auto& h_weight = p_fmat->Info().weights_.HostVector();
690-
h_weight = w;
691-
}
688+
auto& h_weight = p_fmat->Info().weights_.HostVector();
689+
h_weight = w;
692690

693691
HistogramCuts cuts_hess =
694692
DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan(), n_elements);
695-
ValidateCuts(cuts_hess, p_fmat.get(), n_bins, kMaxWeightedNormalizedRankError);
696693

697694
// merge hessian
698-
{
699-
auto& h_weight = p_fmat->Info().weights_.HostVector();
700-
ASSERT_EQ(h_weight.size(), h_hess.size());
701-
for (std::size_t i = 0; i < h_weight.size(); ++i) {
702-
h_weight[i] = w[i] * h_hess[i];
703-
}
695+
ASSERT_EQ(h_weight.size(), h_hess.size());
696+
for (std::size_t i = 0; i < h_weight.size(); ++i) {
697+
h_weight[i] = w[i] * h_hess[i];
704698
}
699+
ValidateCuts(cuts_hess, p_fmat.get(), n_bins, kMaxWeightedNormalizedRankError);
705700

706701
HistogramCuts cuts_wh = DeviceSketch(ctx, p_fmat.get(), n_bins, n_elements);
707702
ValidateCuts(cuts_wh, p_fmat.get(), n_bins, kMaxWeightedNormalizedRankError);
@@ -750,7 +745,11 @@ class DeviceSketchWithHessianTest
750745
cuts_hess =
751746
DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan(), n_elements);
752747
// make validation easier by converting it into sample weight.
753-
p_fmat->Info().weights_.HostVector() = h_hess;
748+
p_fmat->Info().weights_.Resize(n_samples);
749+
for (std::size_t i = 0; i < h_hess.size(); ++i) {
750+
auto gidx = dh::SegmentId(Span{gptr.data(), gptr.size()}, i);
751+
p_fmat->Info().weights_.HostVector()[i] = w[gidx] * h_hess[i];
752+
}
754753
p_fmat->Info().group_ptr_.clear();
755754
ValidateCuts(cuts_hess, p_fmat.get(), n_bins, kMaxWeightedNormalizedRankError);
756755

0 commit comments

Comments
 (0)