Find similar observations using leaf node matching #11919
Description
I'm proposing a feature to find which training samples are "similar" to a prediction sample, based on whether they end up in the same leaf nodes across trees.
The idea is simple: if two observations consistently land in the same leaves across many trees, the model "sees" them as similar.
I already built a function to do this for my use case. I was predicting sales potential for a set of geographies, and the model was predicting far too high for one zip code. I couldn't figure out why from the features or feature importances alone.
So I wrote a wrapper that checks: for this zip code, which training observations land in the same leaf nodes most often? It turned out that a tiny zip code on the other side of the world matched 83% of the time. Looking at the features, they were strangely similar. It was an outlier, and I later excluded it from the training set.
I would never have found this by looking at Euclidean distance or raw feature values. The model knew these were similar. I just needed a way to ask it.
How it works
- Use `pred_leaf=True` to get the leaf index for each tree
- For each tree, check whether the query and reference observations land in the same leaf (boolean)
- Average across trees → a similarity score between 0 and 1 (I also weight trees by the variance of their leaf predictions, since trees where all leaves predict roughly the same value aren't very discriminative.)
What the API could look like

```python
# High-level
similar_idx, scores = model.find_similar(
    query=X_query,
    reference=X_train,
    k=5,
)

# Or lower-level on the booster
query_leaves = booster.predict(DMatrix(X_query), pred_leaf=True)
ref_leaves = booster.predict(DMatrix(X_train), pred_leaf=True)
similarity = booster.compute_leaf_similarity(query_leaves, ref_leaves)
```

Why this is useful
- Debugging predictions: "why is this prediction so high?" → find similar training samples and inspect them
- Finding bad training data: outliers in training can affect predictions in unexpected places
- Explaining to stakeholders: "this prediction is similar to these 5 historical cases" is easier to trust than a black box number
I haven't used them myself, but I found out that Random Forests have something similar: proximity matrices. It would be nice to have this in XGBoost.
Questions
- Is this something that fits in XGBoost's scope, or better as a separate utility?
- Any concerns about scaling to large datasets? (The leaf prediction itself is fast; the similarity computation is just broadcasting.)
- Happy to put together a PR if there's interest
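On the scaling question: a naive broadcast materializes an `n_query × n_ref × n_trees` boolean array, which blows up for large training sets. Chunking over the reference axis keeps peak memory bounded. A sketch (the function name `chunked_leaf_similarity` is illustrative, not an existing XGBoost API):

```python
import numpy as np

def chunked_leaf_similarity(query_leaves, ref_leaves, chunk_size=10_000):
    """Leaf-match similarity computed in reference-axis chunks to bound memory."""
    n_query, n_ref = query_leaves.shape[0], ref_leaves.shape[0]
    out = np.empty((n_query, n_ref))
    for start in range(0, n_ref, chunk_size):
        chunk = ref_leaves[start:start + chunk_size]
        # Temporary array is only (n_query, chunk_size, n_trees).
        out[:, start:start + chunk.shape[0]] = (
            query_leaves[:, None, :] == chunk[None, :, :]
        ).mean(axis=2)
    return out

# Fake leaf indices stand in for booster.predict(..., pred_leaf=True) output.
rng = np.random.default_rng(1)
q = rng.integers(0, 8, size=(4, 50))     # 4 queries, 50 trees
r = rng.integers(0, 8, size=(1000, 50))  # 1000 reference rows
sim = chunked_leaf_similarity(q, r, chunk_size=128)
print(sim.shape)
```

The chunked result is identical to the full broadcast, so the chunk size is purely a memory/speed knob.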