Skip to content
Open
202 changes: 202 additions & 0 deletions python-package/xgboost/testing/quantile_dmatrix.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""QuantileDMatrix related tests."""

from dataclasses import dataclass
from typing import Any, Callable, Optional

import numpy as np
import pytest
from sklearn.model_selection import train_test_split
Expand All @@ -8,6 +11,203 @@

from .data import make_batches, make_categorical

# Tolerances on how far a quantile cut's rank may drift from the ideal evenly
# spaced rank.  The error is normalized: it is expressed in units of the average
# bin weight (total column weight / number of cuts), so a value of 2.0 means a
# cut may be off by up to two average-sized bins.
MAX_NORMALIZED_RANK_ERROR = 2.0
# Weighted sketches get a looser budget; used when any sample weight != 1.0.
MAX_WEIGHTED_NORMALIZED_RANK_ERROR = 14.0
Comment on lines +14 to +15
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please provide some brief comments on utilities here?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will do a bit more rewriting — also, the weighted rank error tolerance of 14 is considerably larger than it should be.



@dataclass(frozen=True)
class _RankContext:
    """Per-column state needed to evaluate the rank error of quantile cuts."""

    # Column values sorted ascending (stable sort of the valid entries).
    sorted_x: np.ndarray
    # prefix_sum[i] is the total weight of sorted_x[:i]; has a leading 0 entry.
    prefix_sum: np.ndarray
    # Sum of all sample weights present in the column.
    total_weight: float
    # Number of cut values recorded for the column.
    num_cuts: int
    # total_weight / num_cuts: the ideal weight of one histogram bin.
    avg_bin_weight: float


def _to_numpy(data: Any) -> np.ndarray:
if hasattr(data, "get"):
data = data.get()
elif hasattr(data, "to_pandas"):
data = data.to_pandas()
if hasattr(data, "to_numpy"):
data = data.to_numpy()
return np.asarray(data)


def _distance_to_interval(target: float, lo: float, hi: float) -> float:
if target < lo:
return lo - target
if target > hi:
return target - hi
return 0.0


def _prepare_validation_input(
    x: Any, w: Optional[Any]
) -> tuple[Any, np.ndarray, float]:
    """Normalize inputs for cut-rank validation.

    Returns a tuple of (host-side feature data, 1-d float64 weight vector,
    default maximum normalized rank error appropriate for these weights).
    Missing weights are treated as all-ones.
    """
    data = x.get() if hasattr(x, "get") else x  # cupy/cudf device -> host
    if hasattr(data, "to_pandas"):
        data = data.to_pandas()

    n_samples = data.shape[0]
    if w is None:
        weights = np.ones(n_samples, dtype=np.float64)
    else:
        weights = _to_numpy(w).astype(np.float64, copy=False)
        assert weights.ndim == 1
        assert weights.shape[0] == n_samples

    # Weighted sketches are allowed a looser rank-error budget.
    if np.all(weights == 1.0):
        budget = MAX_NORMALIZED_RANK_ERROR
    else:
        budget = MAX_WEIGHTED_NORMALIZED_RANK_ERROR
    return data, weights, budget


def _column_getter(
    x_data: Any, weights: np.ndarray
) -> tuple[int, Callable[[int], tuple[np.ndarray, np.ndarray]]]:
    """Return ``(n_features, getter)``.

    ``getter(fidx)`` yields the valid values of feature *fidx* and their
    matching sample weights.  Sparse input (detected via ``tocsc``/``indptr``)
    only yields stored entries; dense input drops NaN entries.
    """
    if hasattr(x_data, "tocsc") and hasattr(x_data, "indptr"):
        csc = x_data.tocsc()

        def sparse_column(fidx: int) -> tuple[np.ndarray, np.ndarray]:
            lo, hi = int(csc.indptr[fidx]), int(csc.indptr[fidx + 1])
            rows = csc.indices[lo:hi]
            return np.asarray(csc.data[lo:hi]), weights[rows]

        return csc.shape[1], sparse_column

    dense = _to_numpy(x_data)
    assert dense.ndim == 2

    def dense_column(fidx: int) -> tuple[np.ndarray, np.ndarray]:
        values = dense[:, fidx]
        mask = ~np.isnan(values)
        return values[mask], weights[mask]

    return dense.shape[1], dense_column


def _sorted_rank_state(
column: np.ndarray, column_w: np.ndarray
) -> tuple[np.ndarray, np.ndarray, float]:
sorted_idx = np.argsort(column, kind="stable")
sorted_x = column[sorted_idx]
sorted_w = column_w[sorted_idx]
prefix_sum = np.concatenate(([0.0], np.cumsum(sorted_w, dtype=np.float64)))
return sorted_x, prefix_sum, float(prefix_sum[-1])


def _make_rank_context(
    column: np.ndarray, column_w: np.ndarray, column_cuts: np.ndarray
) -> _RankContext | None:
    """Build the per-column rank-lookup state.

    Returns ``None`` when the column carries zero total weight, in which case
    a rank check is meaningless.
    """
    values, prefix, total = _sorted_rank_state(column, column_w)
    if total == 0.0:
        return None
    n_cuts = column_cuts.shape[0]
    return _RankContext(
        sorted_x=values,
        prefix_sum=prefix,
        total_weight=total,
        num_cuts=n_cuts,
        avg_bin_weight=total / float(n_cuts),
    )


def _rank_error_candidate(
    cut_idx: int,
    cut: float,
    rank_ctx: _RankContext,
) -> tuple[float, dict[str, float | int]]:
    """Compute the normalized rank error of one cut value.

    The feasible rank of *cut* is the weight interval spanned by the entries
    equal to it (left/right ``searchsorted`` into the sorted column).  The
    error is the distance from the ideal evenly spaced rank to that interval,
    normalized by the average bin weight.  Also returns a detail dict used in
    assertion messages.
    """
    lo_idx = np.searchsorted(rank_ctx.sorted_x, cut, side="left")
    hi_idx = np.searchsorted(rank_ctx.sorted_x, cut, side="right")
    rank_lo = float(rank_ctx.prefix_sum[lo_idx])
    rank_hi = float(rank_ctx.prefix_sum[hi_idx])
    # Ideal rank of the (cut_idx+1)-th of num_cuts evenly spaced quantiles.
    ideal = ((cut_idx + 1) * rank_ctx.total_weight) / float(rank_ctx.num_cuts)
    abs_err = _distance_to_interval(ideal, rank_lo, rank_hi)
    details: dict[str, float | int] = {
        "cut": cut_idx,
        "absolute_error": abs_err,
        "target_rank": ideal,
        "rank_lo": rank_lo,
        "rank_hi": rank_hi,
    }
    return abs_err / rank_ctx.avg_bin_weight, details


def _max_rank_error_for_column(
    column: np.ndarray, column_w: np.ndarray, column_cuts: np.ndarray
) -> tuple[float, str]:
    """Return the worst normalized rank error over this column's cuts.

    The final cut is the column's upper bound rather than a quantile, so it is
    excluded from the check.  A zero-weight column yields ``(0.0, "")``.  The
    second element is a human-readable summary of the worst offender.
    """
    ctx = _make_rank_context(column, column_w, column_cuts)
    if ctx is None:
        return 0.0, ""

    worst = 0.0
    worst_state: dict[str, float | int] = {
        "cut": 0,
        "absolute_error": 0.0,
        "target_rank": 0.0,
        "rank_lo": 0.0,
        "rank_hi": 0.0,
    }
    for idx, cut in enumerate(column_cuts[:-1]):
        err, state = _rank_error_candidate(idx, cut, ctx)
        if err > worst:
            worst, worst_state = err, state

    message = (
        f"cut={worst_state['cut']}, normalized_error={worst}, "
        f"absolute_error={worst_state['absolute_error']}, "
        f"target_rank={worst_state['target_rank']}, rank_lo={worst_state['rank_lo']}, "
        f"rank_hi={worst_state['rank_hi']}, total_weight={ctx.total_weight}, "
        f"num_cuts={column_cuts.shape[0]}"
    )
    return worst, message


def _assert_feature_rank_error(
    indptr: np.ndarray,
    cuts: np.ndarray,
    get_column: Callable[[int], tuple[np.ndarray, np.ndarray]],
    fidx: int,
    max_normalized_rank_error: float,
) -> None:
    """Assert that feature *fidx*'s cuts are sorted and within the rank budget."""
    column, column_w = get_column(fidx)
    if column.shape[0] == 0:
        return  # no present values for this feature; nothing to check

    lo, hi = int(indptr[fidx]), int(indptr[fidx + 1])
    column_cuts = cuts[lo:hi]
    # Cut values must be non-decreasing within a feature.
    assert np.all(np.diff(column_cuts) >= 0.0)
    if column_cuts.shape[0] <= 1:
        return  # a single (or empty) cut set cannot violate rank spacing

    max_error, details = _max_rank_error_for_column(column, column_w, column_cuts)
    assert max_error <= max_normalized_rank_error, f"feature={fidx}, {details}"


def assert_cut_rank_error_within_tolerance(
    indptr: np.ndarray,
    cuts: np.ndarray,
    x: Any,
    w: Optional[Any] = None,
    max_normalized_rank_error: Optional[float] = None,
) -> None:
    """Assert that every numerical feature cut stays within the allowed rank error."""
    x_data, weights, default_rank_error = _prepare_validation_input(x, w)
    # Fall back to the weight-dependent default when no explicit budget given.
    tolerance = (
        default_rank_error
        if max_normalized_rank_error is None
        else max_normalized_rank_error
    )

    n_features, get_column = _column_getter(x_data, weights)
    for fidx in range(n_features):
        _assert_feature_rank_error(indptr, cuts, get_column, fidx, tolerance)


def check_ref_quantile_cut(device: str) -> None:
"""Check obtaining the same cut values given a reference."""
Expand All @@ -30,10 +230,12 @@ def check_ref_quantile_cut(device: str) -> None:

np.testing.assert_allclose(cut_train[0], cut_valid[0])
np.testing.assert_allclose(cut_train[1], cut_valid[1])
assert_cut_rank_error_within_tolerance(cut_train[0], cut_train[1], X_train)

Xy_valid = xgb.QuantileDMatrix(X_valid, y_valid)
cut_valid = Xy_valid.get_quantile_cut()
assert not np.allclose(cut_train[1], cut_valid[1])
assert_cut_rank_error_within_tolerance(cut_valid[0], cut_valid[1], X_valid)


def check_categorical_strings(device: str) -> None:
Expand Down
3 changes: 3 additions & 0 deletions python-package/xgboost/testing/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ..training import train
from .data import IteratorForTest, make_batches, make_categorical
from .data_iter import CatIter
from .quantile_dmatrix import assert_cut_rank_error_within_tolerance
from .utils import Device, assert_allclose, non_increasing


Expand Down Expand Up @@ -334,11 +335,13 @@ def check_get_quantile_cut_device(tree_method: str, use_cupy: bool) -> None:
Xyw: DMatrix = QuantileDMatrix(X, y, weight=w, max_bin=max_bin)
indptr, data = Xyw.get_quantile_cut()
check_cut((max_bin + 1) * n_features, indptr, data, dtypes)
assert_cut_rank_error_within_tolerance(indptr, data, X, w)
# - dm
Xyw = DMatrix(X, y, weight=w)
train({"tree_method": tree_method, "max_bin": max_bin}, Xyw)
indptr, data = Xyw.get_quantile_cut()
check_cut((max_bin + 1) * n_features, indptr, data, dtypes)
assert_cut_rank_error_within_tolerance(indptr, data, X, w)
# - ext mem
n_batches = 3
n_samples_per_batch = 256
Expand Down
4 changes: 1 addition & 3 deletions src/common/hist_util.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ constexpr float SketchContainer::kFactor;

namespace detail {
size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows) {
double eps = 1.0 / (WQSketch::kFactor * max_bins);
size_t num_cuts = WQuantileSketch::LimitSizeLevel(num_rows, eps);
return std::min(num_cuts, num_rows);
return std::min(SketchSummaryBudget(max_bins, num_rows), num_rows);
}

size_t RequiredSampleCuts(bst_idx_t num_rows, bst_feature_t num_columns, size_t max_bins,
Expand Down
Loading
Loading