[c++] Add survival_cox objective for Cox proportional hazards modelling #7212

ohines wants to merge 28 commits into lightgbm-org:master
jameslamb left a comment
Thanks for your interest in LightGBM. Someone will review this when we have time.
Until then, please:
- update this branch to latest `master`
- fix all the linting issues with `pre-commit run --all-files`
```python
import lightgbm as lgb

# Load FLCHAIN dataset (serum free light chain and mortality)
data = fetch_openml("flchain", version=1, as_frame=True, parser="auto")
```
The AppVeyor builds are failing like this:
```
TypeError: fetch_openml() got an unexpected keyword argument 'parser'
```
https://ci.appveyor.com/project/guolinke/lightgbm/builds/53791302/job/oj3cfvbuifsjc7au?fullLog=true
Those jobs use a very old scikit-learn (1.0), which I guess must not have had that. Can you please figure out a more portable pattern? A different dataset, omitting the parser argument, something like that?
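One portable pattern, sketched here with a hypothetical stand-in function rather than scikit-learn itself, is to pass `parser` only when the installed `fetch_openml` actually accepts it:

```python
import inspect

# Hypothetical stand-in for an old sklearn.datasets.fetch_openml that
# predates the parser keyword; the real function would be used instead.
def fetch_openml_old(name, *, version, as_frame):
    return {"name": name, "kwargs": {"version": version, "as_frame": as_frame}}

def fetch_portably(fetch_fn, name):
    """Call an OpenML fetcher, passing parser= only if it is supported."""
    kwargs = {"version": 1, "as_frame": True}
    if "parser" in inspect.signature(fetch_fn).parameters:
        kwargs["parser"] = "auto"
    return fetch_fn(name, **kwargs)

data = fetch_portably(fetch_openml_old, "flchain")
```

With the real `fetch_openml`, the same signature check keeps the example working on both old and new scikit-learn without pinning.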
Thanks for pointing this out. I wasn't sure how to debug the failing AppVeyor tests.
You're right: `parser` is not required, so I removed it.
AppVeyor fails again with
sklearn.datasets._openml.OpenMLError: Dataset flchain with version 1 not found
I will look into it. It seems that `fetch_openml` is marked as experimental in the scikit-learn 1.0 docs. Is there a particular reason that we test this version, which was released in 2021?
In general, we try to support a wide range of lightgbm's main dependencies, for the benefit of users who can't easily upgrade to newer versions (e.g. they're using managed environments like Databricks notebooks or constrained to older operating systems).
We'd prefer to have a compelling reason to bump a runtime floor, and "makes this example in documentation easier to test" isn't that compelling, in my opinion.
That said we do already have a Linux job testing an even older scikit-learn:
So I wouldn't be opposed to updating the pin for Python 3.9 environments like the one on Appveyor. That could be done here:
LightGBM/.ci/conda-envs/ci-core-py39.txt
Line 31 in a7d00a9
I'd support trying to bump that up to a newer scikit-learn if you'd like. But it'll probably require pinning more than just scikit-learn, so might take a bit of trial and error.
I tested this on my local machine with `python=3.10.20` and `sklearn=1.0.2` and it works fine. Is it possible that the runner does not have access to OpenML?
These are the API calls that are made (generated by adding a print statement here):

```
downloading data from https://openml.org/api/v1/json/data/list/data_name/flchain/limit/2/data_version/1
downloading data from https://openml.org/api/v1/json/data/46161
downloading data from https://openml.org/api/v1/json/data/features/46161
downloading data from https://openml.org/api/v1/json/data/qualities/46161
downloading data from https://openml.org/data/v1/download/22120605
```
In any case I can just replace the example to use synthetic data.
> I tested this on my local machine with python=3.10.20 and sklearn=1.0.2 and it works fine. Is it possible that the runner does not have access to open ml?
That job uses Python 3.9, not 3.10. It's the standard Appveyor runner for open source projects and should have full access to the internet.
I suspect that maybe that "not found" error is actually from a broad try-catch and that something else in the environment (like some other dependency version) is causing it to fail.
The approach with synthetic data looks good to me!
Thanks for working through that.
jameslamb left a comment
Thanks, I went through this more thoroughly and left some more suggestions; please take a look at them.
I don't feel qualified to review the objective and metric implementations... once all my other suggestions are addressed, I can try to recruit another maintainer (or an outside reviewer) to look those over.
tests/python_package_test/utils.py (Outdated)

```python
@lru_cache(maxsize=None)
def load_survival():
```
```diff
-def load_survival():
+def make_survival(*, n_samples, random_state):
```
Since this is generating random data, not loading an existing dataset, let's follow the existing conventions here and call it make_* instead of load_*.
And can you please parameterize at least the number of samples and the random seed?
tests/python_package_test/utils.py (Outdated)

```python
n = 500
p = 5
censoring_rate = 0.3
rng = np.random.RandomState(seed=42)
X = rng.randn(n, p)
log_hazard = X[:, 0] + 0.5 * X[:, 1]
times = rng.exponential(np.exp(-log_hazard))
censor_times = rng.exponential(np.median(times) / censoring_rate, n)
```
```diff
-n = 500
-p = 5
-censoring_rate = 0.3
-rng = np.random.RandomState(seed=42)
-X = rng.randn(n, p)
-log_hazard = X[:, 0] + 0.5 * X[:, 1]
-times = rng.exponential(np.exp(-log_hazard))
-censor_times = rng.exponential(np.median(times) / censoring_rate, n)
+n_features = 5
+censoring_rate = 0.3
+rng = np.random.RandomState(seed=42)
+X = rng.randn(n_samples, n_features)
+log_hazard = X[:, 0] + 0.5 * X[:, 1]
+times = rng.exponential(np.exp(-log_hazard))
+censor_times = rng.exponential(np.median(times) / censoring_rate, n_samples)
```
Let's please use more informative variable names, and match the names used in other functions in this file.
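Applying those renames, a runnable sketch of the parameterized generator (the default values here are illustrative, not from the PR):

```python
import numpy as np

def make_survival(*, n_samples, n_features=5, censoring_rate=0.3, random_state=42):
    """Generate synthetic survival data with the signed-time label convention:
    a positive label is an observed event time, a negative label is a censoring time."""
    rng = np.random.RandomState(seed=random_state)
    X = rng.randn(n_samples, n_features)
    log_hazard = X[:, 0] + 0.5 * X[:, 1]
    times = rng.exponential(np.exp(-log_hazard))
    censor_times = rng.exponential(np.median(times) / censoring_rate, n_samples)
    observed = times <= censor_times
    y = np.where(observed, np.minimum(times, censor_times), -censor_times)
    return X.astype(np.float64), y.astype(np.float64)

# Callers can now vary the size and seed instead of relying on module-level constants.
X, y = make_survival(n_samples=200, random_state=0)
```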
```cpp
} else if (type == std::string("survival_cox_nll")) {
  Log::Warning("Metric survival_cox_nll is not implemented in cuda version. Fall back to evaluation on CPU.");
  return new CoxNLLMetric(config);
} else if (type == std::string("concordance_index") || type == std::string("c_index")) {
```
Please update this mapping in the R package as well:
LightGBM/R-package/R/metrics.R
Line 9 in a7d00a9
If you're comfortable writing R code we'd welcome new tests in the R package too, but at a minimum that mapping should be updated so the R package's early stopping behavior will be correct.
docs/Parameters.rst (Outdated)

```rst
- survival analysis application

  - ``survival_cox``, `Cox proportional hazards <https://en.wikipedia.org/wiki/Proportional_hazards_model>`__ partial likelihood with Breslow's method for ties, aliases: ``survival``, ``cox``, ``cox_ph``
```
Are these aliases used in other projects or research?
If not, let's please not use any aliases for this objective. Aliases add complexity and maintenance burden, and I'd especially like to avoid committing to a generic name like `survival` in case other survival objectives are added in the future.
I agree about removing the name `survival`.
- In XGBoost, the relevant objective and metric are `survival:cox` and `cox-nloglik`.
- In scikit-survival, the relevant function is `CoxPHSurvivalAnalysis`.
- In the R survival package, the relevant function is `coxph` with `ties="breslow"`.
- In the Lifelines package, the relevant class is `CoxPHFitter`.
- In statistical theory, the metric is sometimes referred to as a "partial likelihood".
Awesome, thanks for those links! That's exactly the type of thing I was looking for. Based on that, I'm happy with dropping survival but keeping cox and cox_ph.
```python
def load_survival():
    """Generate synthetic survival data with signed-time label convention."""
    n = 500
    p = 5
    censoring_rate = 0.3
    rng = np.random.RandomState(seed=42)
    X = rng.randn(n, p)
    log_hazard = X[:, 0] + 0.1 * X[:, 1]
    times = rng.exponential(np.exp(-log_hazard))
    censor_times = rng.exponential(np.median(times) / censoring_rate, n)
    observed = times <= censor_times
    y = np.where(observed, np.minimum(times, censor_times), -censor_times)
    return X.astype(np.float64), y.astype(np.float64)

X, y = load_survival()
```
```diff
-def load_survival():
-    """Generate synthetic survival data with signed-time label convention."""
-    n = 500
-    p = 5
-    censoring_rate = 0.3
-    rng = np.random.RandomState(seed=42)
-    X = rng.randn(n, p)
-    log_hazard = X[:, 0] + 0.1 * X[:, 1]
-    times = rng.exponential(np.exp(-log_hazard))
-    censor_times = rng.exponential(np.median(times) / censoring_rate, n)
-    observed = times <= censor_times
-    y = np.where(observed, np.minimum(times, censor_times), -censor_times)
-    return X.astype(np.float64), y.astype(np.float64)
-
-X, y = load_survival()
+def make_survival(*, n_samples, n_features, censoring_rate, random_state):
+    """Generate synthetic survival data with signed-time label convention."""
+    rng = np.random.RandomState(seed=random_state)
+    X = rng.randn(n_samples, n_features)
+    log_hazard = X[:, 0] + 0.1 * X[:, 1]
+    times = rng.exponential(np.exp(-log_hazard))
+    censor_times = rng.exponential(np.median(times) / censoring_rate, n_samples)
+    observed = times <= censor_times
+    y = np.where(observed, np.minimum(times, censor_times), -censor_times)
+    return X.astype(np.float64), y.astype(np.float64)
+
+X, y = make_survival(n_samples=500, n_features=5, censoring_rate=0.3, random_state=42)
```
Similar to my comments on the test code... let's please use more informative variable names, and let's make some of these things configurable so people can experiment with different configurations.
```diff
-assert "survival_cox_nll" in evals_result["val"]
-assert "concordance_index" in evals_result["val"]
+assert set(evals_result["val"].keys()) == {"survival_cox_nll", "concordance_index"}
```
Let's make this stricter and test for exact equivalence. As I think you noticed, LightGBM automatically adds a metric based on the loss function you choose. This stricter test could catch problems like the wrong metric accidentally being added when the survival_cox objective is used.
```python
assert "concordance_index" in evals_result["val"]
assert len(evals_result["val"]["survival_cox_nll"]) == 50
# concordance index should be above random (0.5) for this easy problem
assert evals_result["val"]["concordance_index"][-1] > 0.55
```
Can you please also add a test on the value of survival_cox_nll? If that metric just returned -1000 for every iteration right now, no test failure would alert us to that.
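To sanity-check the metric's values rather than just its presence, one option is to compare against an independent computation. This is my own minimal NumPy sketch of the Breslow negative log partial likelihood, not code from this PR; it assumes the signed-time label convention used in the tests:

```python
import numpy as np

def cox_nll_breslow(scores, y):
    """Negative log partial likelihood with Breslow's method for ties.

    scores: raw model predictions (log relative hazards).
    y: signed-time labels; y > 0 is an observed event at time y,
       y < 0 is censoring at time -y.
    """
    times = np.abs(y)
    events = y > 0
    nll = 0.0
    for i in np.flatnonzero(events):
        # Risk set: everyone still under observation at this event's time.
        # With Breslow's method, tied events all share the same risk set.
        at_risk = times >= times[i]
        nll += np.log(np.sum(np.exp(scores[at_risk]))) - scores[i]
    return nll

# With constant scores the likelihood only counts risk-set sizes: three
# events at distinct times give log(3) + log(2) + log(1) = log(6).
scores = np.zeros(3)
y = np.array([1.0, 2.0, 3.0])  # all events, no censoring
print(cox_nll_breslow(scores, y))  # ≈ 1.7918
```

A test could compare the reported `survival_cox_nll` against such a reference value on a tiny fixed dataset, which would catch a metric that silently returns a constant.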
Co-authored-by: James Lamb <jaylamb20@gmail.com>
Overview:

Naming:

- `survival_cox` with `cox`, `cox_ph`, and `survival` as aliases.
- `survival_cox_nll` with aliases `cox_nll` and `survival_nll`.

Related: