
[c++] Add survival_cox objective for Cox proportional hazards modelling#7212

Open
ohines wants to merge 28 commits into lightgbm-org:master from ohines:oh-survival

Conversation


@ohines ohines commented Mar 27, 2026

Overview:

  • Adds Cox Proportional Hazards loss requested in Cox Proportional Hazard Regression #1837 and with several up-votes in Feature Requests & Voting Hub #2302.
  • Also added a metric to compute Harrell's concordance (C-index) popular in survival analysis.
  • These can be implemented with custom losses and metrics, but the computation to pre-sort the data and compute Breslow baseline hazards is a bit fiddly (especially with tied times), so a built-in implementation is nice.
  • For context: I was using a custom python+numba implementation in a data analysis, which motivates this PR.
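For readers unfamiliar with the loss: a minimal NumPy sketch of the negative Cox partial log-likelihood with Breslow's handling of ties might look like the following. This is an illustrative reference implementation, not this PR's C++ code; the function name and variable names are mine.

```python
import numpy as np

def cox_breslow_nll(times, events, eta):
    """Negative Cox partial log-likelihood with Breslow's method for ties.

    times  : observed times (event or censoring)
    events : 1 if the event was observed, 0 if censored
    eta    : model scores (log relative hazards)
    """
    order = np.argsort(-times)           # sort by time, descending
    t, e, x = times[order], events[order], eta[order]
    csum = np.cumsum(np.exp(x))          # running risk-set denominator
    # With ties, each subject's risk set must include its whole tie group,
    # so every position reads the cumulative sum at the LAST index of its group.
    _, first, cnt = np.unique(-t, return_index=True, return_counts=True)
    denom = csum[np.repeat(first + cnt - 1, cnt)]
    # Only observed events contribute terms to the partial likelihood.
    return -np.sum(e * (x - np.log(denom)))
```

For example, with times [1, 1, 2], all events observed, and all scores zero, the Breslow partial likelihood is (1/3)·(1/3)·(1/1), so the NLL is 2·log 3.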

Naming:

  • I wasn't sure what to call the objective. I went for survival_cox with cox, cox_ph, and survival as aliases.
  • Similar for the negative partial log likelihood metric. I went for survival_cox_nll with aliases cox_nll, and survival_nll.

Related:

@ohines ohines changed the title Add survival_cox objective for Cox proportional hazards modelling [c++] Add survival_cox objective for Cox proportional hazards modelling Mar 27, 2026
Member

@jameslamb jameslamb left a comment

Thanks for your interest in LightGBM. Someone will review this when we have time.

Until then, please:

  • update this branch to latest master
  • fix all the linting issues with pre-commit run --all-files

import lightgbm as lgb
from sklearn.datasets import fetch_openml

# Load FLCHAIN dataset (serum free light chain and mortality)
data = fetch_openml("flchain", version=1, as_frame=True, parser="auto")
Member

The AppVeyor builds are failing like this:

TypeError: fetch_openml() got an unexpected keyword argument 'parser'

https://ci.appveyor.com/project/guolinke/lightgbm/builds/53791302/job/oj3cfvbuifsjc7au?fullLog=true

Those jobs use a very old scikit-learn (1.0), which I guess must not have had that. Can you please figure out a more portable pattern? A different dataset, omitting the parser argument, something like that?

Author

Thanks for pointing this out. I wasn't sure how to debug the failing AppVeyor tests.
You're right: parser is not required, so I removed it.

Author

AppVeyor fails again with

sklearn.datasets._openml.OpenMLError: Dataset flchain with version 1 not found

I will look into it. It seems that fetch_openml is marked as experimental in the scikit-learn 1.0 docs. Is there a particular reason we test against this version, which was released in 2021?

Member

In general, we try to support a wide range of lightgbm's main dependencies, for the benefit of users who can't easily upgrade to newer versions (e.g. they're using managed environments like Databricks notebooks or constrained to older operating systems).

We'd prefer to have a compelling reason to bump a runtime floor, and "makes this example in documentation easier to test" isn't that compelling, in my opinion.

That said we do already have a Linux job testing an even older scikit-learn:

scikit-learn==0.24.2

So I wouldn't be opposed to updating the pin for Python 3.9 environments like the one on Appveyor. That could be done here:

scikit-learn=1.0.*

I'd support trying to bump that up to a newer scikit-learn if you'd like. But it'll probably require pinning more than just scikit-learn, so might take a bit of trial and error.

Author

I tested this on my local machine with python=3.10.20 and sklearn=1.0.2 and it works fine. Is it possible that the runner does not have access to OpenML?

These are the API calls that are made (generated by adding a print statement here):

downloading data from https://openml.org/api/v1/json/data/list/data_name/flchain/limit/2/data_version/1
downloading data from https://openml.org/api/v1/json/data/46161
downloading data from https://openml.org/api/v1/json/data/features/46161
downloading data from https://openml.org/api/v1/json/data/qualities/46161
downloading data from https://openml.org/data/v1/download/22120605

In any case, I can just replace the example with synthetic data.

Author

Updated in eb083b1

Member

I tested this on my local machine with python=3.10.20 and sklearn=1.0.2 and it works fine. Is it possible that the runner does not have access to OpenML?

That job uses Python 3.9, not 3.10. It's the standard Appveyor runner for open source projects and should have full access to the internet.

I suspect that maybe that "not found" error is actually from a broad try-catch and that something else in the environment (like some other dependency version) is causing it to fail.

The approach with synthetic data looks good to me!

Thanks for working through that.

Member

@jameslamb jameslamb left a comment

Thanks, I went through this more thoroughly and left some more suggestions; please take a look at them.

I don't feel qualified to review the objective and metric implementations... once all my other suggestions are addressed, I can try to recruit another maintainer (or an outside reviewer) to look those over.



@lru_cache(maxsize=None)
def load_survival():
Member

Suggested change
-def load_survival():
+def make_survival(*, n_samples, random_state):

Since this is generating random data, not loading an existing dataset, let's follow the existing conventions here and call it make_* instead of load_*.

And can you please parameterize at least the number of samples and the random seed?

Comment on lines +40 to +47
n = 500
p = 5
censoring_rate = 0.3
rng = np.random.RandomState(seed=42)
X = rng.randn(n, p)
log_hazard = X[:, 0] + 0.5 * X[:, 1]
times = rng.exponential(np.exp(-log_hazard))
censor_times = rng.exponential(np.median(times) / censoring_rate, n)
Member

Suggested change
-n = 500
-p = 5
+n_features = 5
 censoring_rate = 0.3
 rng = np.random.RandomState(seed=42)
-X = rng.randn(n, p)
+X = rng.randn(n_samples, n_features)
 log_hazard = X[:, 0] + 0.5 * X[:, 1]
 times = rng.exponential(np.exp(-log_hazard))
-censor_times = rng.exponential(np.median(times) / censoring_rate, n)
+censor_times = rng.exponential(np.median(times) / censoring_rate, n_samples)

Let's please use more informative variable names, and match the names used in other functions in this file.

Comment on lines +85 to +88
} else if (type == std::string("survival_cox_nll")) {
Log::Warning("Metric survival_cox_nll is not implemented in cuda version. Fall back to evaluation on CPU.");
return new CoxNLLMetric(config);
} else if (type == std::string("concordance_index") || type == std::string("c_index")) {
Member

Please update this mapping in the R package as well:

.METRICS_HIGHER_BETTER <- function() {

If you're comfortable writing R code we'd welcome new tests in the R package too, but at a minimum that mapping should be updated so the R package's early stopping behavior will be correct.


- survival analysis application

- ``survival_cox``, `Cox proportional hazards <https://en.wikipedia.org/wiki/Proportional_hazards_model>`__ partial likelihood with Breslow's method for ties, aliases: ``survival``, ``cox``, ``cox_ph``
Member

Are these aliases used in other projects or research?

If not, let's please not use any aliases for this objective. Aliases add complexity and maintenance burden, and I'd especially like to avoid committing survival like this in case other survival objectives are added in the future.

Author

I agree about removing the name survival.

  • In XGBoost the relevant objective and metric are survival:cox and cox-nloglik.
  • In scikit-survival the relevant class is CoxPHSurvivalAnalysis.
  • In the R survival package the relevant function is coxph with ties="breslow".
  • In the lifelines package the relevant class is CoxPHFitter.
  • In statistical theory, the metric is sometimes referred to as a "partial likelihood".

Member

Awesome, thanks for those links! That's exactly the type of thing I was looking for. Based on that, I'm happy with dropping survival but keeping cox and cox_ph.

Comment on lines +8 to +23
def load_survival():
"""Generate synthetic survival data with signed-time label convention."""
n = 500
p = 5
censoring_rate = 0.3
rng = np.random.RandomState(seed=42)
X = rng.randn(n, p)
log_hazard = X[:, 0] + 0.1 * X[:, 1]
times = rng.exponential(np.exp(-log_hazard))
censor_times = rng.exponential(np.median(times) / censoring_rate, n)
observed = times <= censor_times
y = np.where(observed, np.minimum(times, censor_times), -censor_times)
return X.astype(np.float64), y.astype(np.float64)


X, y = load_survival()
Member

Suggested change
-def load_survival():
-    """Generate synthetic survival data with signed-time label convention."""
-    n = 500
-    p = 5
-    censoring_rate = 0.3
-    rng = np.random.RandomState(seed=42)
-    X = rng.randn(n, p)
-    log_hazard = X[:, 0] + 0.1 * X[:, 1]
-    times = rng.exponential(np.exp(-log_hazard))
-    censor_times = rng.exponential(np.median(times) / censoring_rate, n)
-    observed = times <= censor_times
-    y = np.where(observed, np.minimum(times, censor_times), -censor_times)
-    return X.astype(np.float64), y.astype(np.float64)
-
-X, y = load_survival()
+def make_survival(*, n_samples, n_features, censoring_rate, random_state):
+    """Generate synthetic survival data with signed-time label convention."""
+    rng = np.random.RandomState(seed=random_state)
+    X = rng.randn(n_samples, n_features)
+    log_hazard = X[:, 0] + 0.1 * X[:, 1]
+    times = rng.exponential(np.exp(-log_hazard))
+    censor_times = rng.exponential(np.median(times) / censoring_rate, n_samples)
+    observed = times <= censor_times
+    y = np.where(observed, np.minimum(times, censor_times), -censor_times)
+    return X.astype(np.float64), y.astype(np.float64)
+
+X, y = make_survival(n_samples=500, n_features=5, censoring_rate=0.3, random_state=42)

Similar to my comments on the test code... let's please use more informative variable names, and let's make some of these things configurable so people can experiment with different configurations.
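For context on the signed-time convention used in the generator above: positive labels encode observed event times and negative labels encode (negated) censoring times, so a consumer can recover both pieces with two vectorized operations. A short sketch (the sample label values are mine):

```python
import numpy as np

# Labels under the signed-time convention: positive = observed event time,
# negative = negated censoring time.
y = np.array([1.5, -2.0, 0.7, -0.3])

event_observed = y > 0    # True where the event (not censoring) occurred
time = np.abs(y)          # the event or censoring time itself
```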

Comment on lines +466 to +467
assert "survival_cox_nll" in evals_result["val"]
assert "concordance_index" in evals_result["val"]
Member

Suggested change
-assert "survival_cox_nll" in evals_result["val"]
-assert "concordance_index" in evals_result["val"]
+assert set(evals_result["val"].keys()) == {"survival_cox_nll", "concordance_index"}

Let's make this stricter and test for exact equivalence. As I think you noticed, LightGBM automatically adds a metric based on the loss function you choose. This stricter test could catch problems like the wrong metric accidentally being added when the survival_cox objective is used.

assert "concordance_index" in evals_result["val"]
assert len(evals_result["val"]["survival_cox_nll"]) == 50
# concordance index should be above random (0.5) for this easy problem
assert evals_result["val"]["concordance_index"][-1] > 0.55
Member

Can you please also add a test on the value of survival_cox_nll? If that metric just returned -1000 for every iteration right now, no test failure would alert us to that.
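One way to sanity-check metric values in such a test (a rough sketch, not this PR's code): compare against a naive O(n²) Harrell's C-index that counts only comparable pairs, where the earlier time is an observed event, and scores tied risks as 0.5. The function name is mine.

```python
import numpy as np

def naive_c_index(times, events, risk):
    """O(n^2) reference Harrell's C-index for right-censored data."""
    concordant = comparable = 0.0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue                  # pairs are anchored on observed events
        for j in range(n):
            if times[i] < times[j]:   # comparable pair: i failed strictly first
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1
                elif risk[i] == risk[j]:
                    concordant += 0.5  # tied risks count half
    return concordant / comparable
```

Perfectly ranked risks (higher risk for earlier failures) give 1.0, reversed ranking gives 0.0, and random scores hover around 0.5.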

