Find similar observations using leaf node matching #11919

@mfdel

Description

I'm proposing a feature to find which training samples are "similar" to a prediction sample, based on whether they end up in the same leaf nodes across trees.

The idea is simple: if two observations consistently land in the same leaves across many trees, the model "sees" them as similar.

I already built a function to do this for my use case. I was doing a prediction to assess sales potential for some geographies and the model was predicting way too high for a zip code. I couldn't figure out why from the features or feature importances alone.

So I wrote a wrapper that checks: for this zip code, which training observations land in the same leaf nodes most often? It turned out a tiny zip code on the other side of the world matched 83% of the time. Looking at the features, they were strangely similar; it was an outlier, and I later excluded it from the training set.

I would never have found this by looking at Euclidean distance or raw feature values. The model knew these were similar. I just needed a way to ask it.

How it works

  1. Use pred_leaf=True to get leaf indices for each tree
  2. For each tree, check if query and reference land in same leaf (boolean)
  3. Average across trees → similarity score between 0 and 1. (I also weight trees by the variance of their leaf predictions: trees where all leaves predict roughly the same value aren't very discriminative.)
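The steps above can be sketched in a few lines of NumPy. This is a hypothetical helper, not existing XGBoost API; `tree_weights` stands in for the per-tree leaf-prediction variances, which the caller would need to compute from the model.

```python
import numpy as np

def weighted_leaf_similarity(query_leaves, ref_leaves, tree_weights=None):
    """Leaf-co-occurrence similarity between one query row and each reference row.

    query_leaves: (n_trees,) leaf indices for the query row (from pred_leaf=True)
    ref_leaves:   (n_ref, n_trees) leaf indices for the reference rows
    tree_weights: optional (n_trees,) per-tree weights, e.g. the variance of
                  each tree's leaf predictions (hypothetical weighting scheme)
    """
    # Boolean match per (reference row, tree), broadcast over rows
    matches = (np.asarray(ref_leaves) == np.asarray(query_leaves)).astype(float)
    if tree_weights is None:
        # Plain fraction of trees in which the two rows share a leaf
        return matches.mean(axis=1)
    w = np.asarray(tree_weights, dtype=float)
    # Weighted average: discriminative trees count more
    return matches @ w / w.sum()
```

With two trees, a reference row that matches only in the second tree scores 0.5 unweighted, but 0.75 if that tree carries three times the weight.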

What the API could look like

# High-level
similar_idx, scores = model.find_similar(
    query=X_query,
    reference=X_train,
    k=5,
)

# Or lower-level on booster
query_leaves = booster.predict(DMatrix(X_query), pred_leaf=True)
ref_leaves = booster.predict(DMatrix(X_train), pred_leaf=True)
similarity = booster.compute_leaf_similarity(query_leaves, ref_leaves)
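As a sketch, the high-level call could be a thin wrapper over the leaf matrices that `pred_leaf=True` already returns. The `find_similar` name, argument shapes, and return convention here are assumptions for illustration, not existing XGBoost API:

```python
import numpy as np

def find_similar(query_leaves, ref_leaves, k=5):
    """Return the indices and scores of the k reference rows most similar
    to one query row, by fraction of shared leaves across trees.

    query_leaves: (n_trees,) leaf indices for the query row
    ref_leaves:   (n_ref, n_trees) leaf indices for the reference set
    """
    # Fraction of trees in which each reference row lands in the query's leaf
    scores = (np.asarray(ref_leaves) == np.asarray(query_leaves)).mean(axis=1)
    # Indices of the k highest-scoring reference rows, best first
    top = np.argsort(scores)[::-1][:k]
    return top, scores[top]
```

In the proposal above, `query_leaves` and `ref_leaves` would come from `booster.predict(..., pred_leaf=True)`; the similarity step itself is just a broadcasted comparison, so it scales as O(n_ref × n_trees) per query row.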

Why this is useful

  • Debugging predictions: "why is this prediction so high?" → find similar training samples and inspect them
  • Finding bad training data: outliers in training can affect predictions in unexpected places
  • Explaining to stakeholders: "this prediction is similar to these 5 historical cases" is easier to trust than a black box number

I haven't used them myself, but I found out Random Forests have something similar to this: proximity matrices. It would be nice to have this in XGBoost.

Questions

  • Is this something that fits in XGBoost's scope, or better as a separate utility?
  • Any concerns about scaling to large datasets? (the leaf prediction is fast, similarity is just broadcasting)
  • Happy to put together a PR if there's interest
