Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions R-package/R/metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
, "cross_entropy" = FALSE
, "cross_entropy_lambda" = FALSE
, "kullback_leibler" = FALSE
, "survival_cox_nll" = FALSE
, "concordance_index" = TRUE
)
)
}
75 changes: 75 additions & 0 deletions R-package/tests/testthat/test_survival.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Simulates a right-censored survival dataset for testing.
#
# Event times are exponential with rate exp(X1 + 0.5 * X2); censoring times
# are exponential with a rate chosen to censor roughly 30% of the samples.
# The returned label follows the signed-time convention: a positive value is
# an observed event time, a negative value is a censored time.
.make_survival <- function(n_samples = 500L, n_features = 5L, random_state = 0L) {
  set.seed(random_state)
  feature_matrix <- matrix(
    rnorm(n_samples * n_features)
    , nrow = n_samples
    , ncol = n_features
  )
  # true risk depends only on the first two features
  risk_score <- feature_matrix[, 1L] + 0.5 * feature_matrix[, 2L]
  event_times <- rexp(n_samples, rate = exp(risk_score))
  target_censoring_rate <- 0.3
  censoring_times <- rexp(n_samples, rate = target_censoring_rate / median(event_times))
  # encode censoring via the label's sign (negative = censored)
  is_censored <- censoring_times < event_times
  signed_times <- ifelse(is_censored, -censoring_times, event_times)
  list(X = feature_matrix, y = signed_times)
}

test_that("survival_cox with lgb.train() works as expected", {
  data <- .make_survival()
  n_obs <- nrow(data$X)
  n_train <- 375L
  train_rows <- seq_len(n_train)
  valid_rows <- seq.int(from = n_train + 1L, to = n_obs)

  dtrain <- lgb.Dataset(data$X[train_rows, ], label = data$y[train_rows])
  dval <- lgb.Dataset(
    data$X[valid_rows, ]
    , label = data$y[valid_rows]
    , reference = dtrain
  )

  params <- list(
    objective = "survival_cox"
    , metric = list("survival_cox_nll", "concordance_index")
    , num_leaves = 8L
    , seed = 708L
    , num_threads = .LGB_MAX_THREADS
    , deterministic = TRUE
    , force_row_wise = TRUE
    , verbose = .LGB_VERBOSITY
  )
  model <- lgb.train(
    params = params
    , data = dtrain
    , nrounds = 10L
    , valids = list(val = dval)
    , record = TRUE
  )

  # both metrics should be reported, in the order they were requested
  results <- model$eval_valid()
  expect_length(results, 2L)
  expect_equal(results[[1L]]$name, "survival_cox_nll")
  expect_equal(results[[2L]]$name, "concordance_index")

  # the NLL is minimized while the concordance index is maximized
  expect_false(results[[1L]]$higher_better)
  expect_true(results[[2L]]$higher_better)

  # per-round metric values recorded on the validation set
  nll_by_round <- unlist(model$record_evals[["val"]][["survival_cox_nll"]][["eval"]])
  cindex_by_round <- unlist(model$record_evals[["val"]][["concordance_index"]][["eval"]])
  expect_length(nll_by_round, 10L)
  expect_length(cindex_by_round, 10L)

  # every recorded value should be finite
  expect_true(all(is.finite(nll_by_round)))
  expect_true(all(is.finite(cindex_by_round)))

  # values should fall in a reasonable range for this synthetic problem
  expect_true(all(nll_by_round > 3.7 & nll_by_round < 4.1))
  expect_true(all(cindex_by_round > 0.6 & cindex_by_round < 0.8))

  # validation loss should generally improve (last round better than first)
  expect_lt(nll_by_round[10L], nll_by_round[1L])

  # both metrics should improve on at least half of the boosting rounds
  expect_gte(sum(diff(nll_by_round) < 0), 5L)
  expect_gte(sum(diff(cindex_by_round) > 0), 5L)
})
12 changes: 11 additions & 1 deletion docs/Parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ Core Parameters

- **Note**: can be used only in CLI version; for language-specific packages you can use the corresponding functions

- ``objective`` :raw-html:`<a id="objective" title="Permalink to this parameter" href="#objective">&#x1F517;&#xFE0E;</a>`, default = ``regression``, type = enum, options: ``regression``, ``regression_l1``, ``huber``, ``fair``, ``poisson``, ``quantile``, ``mape``, ``gamma``, ``tweedie``, ``binary``, ``multiclass``, ``multiclassova``, ``cross_entropy``, ``cross_entropy_lambda``, ``lambdarank``, ``rank_xendcg``, aliases: ``objective_type``, ``app``, ``application``, ``loss``
- ``objective`` :raw-html:`<a id="objective" title="Permalink to this parameter" href="#objective">&#x1F517;&#xFE0E;</a>`, default = ``regression``, type = enum, options: ``regression``, ``regression_l1``, ``huber``, ``fair``, ``poisson``, ``quantile``, ``mape``, ``gamma``, ``tweedie``, ``binary``, ``multiclass``, ``multiclassova``, ``cross_entropy``, ``cross_entropy_lambda``, ``lambdarank``, ``rank_xendcg``, ``survival_cox``, aliases: ``objective_type``, ``app``, ``application``, ``loss``

- regression application

Expand Down Expand Up @@ -170,6 +170,12 @@ Core Parameters

- label should be ``int`` type, and larger number represents the higher relevance (e.g. 0:bad, 1:fair, 2:good, 3:perfect)

- survival analysis application

- ``survival_cox``, `Cox proportional hazards <https://en.wikipedia.org/wiki/Proportional_hazards_model>`__ partial likelihood with Breslow's method for ties, aliases: ``cox``, ``cox_ph``

- label encodes censoring via sign: positive value = event time, negative value = censored time

- custom objective function (gradients and hessians not computed directly by LightGBM)

- ``custom``
Expand Down Expand Up @@ -1279,6 +1285,10 @@ Metric Parameters

- ``kullback_leibler``, `Kullback-Leibler divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`__, aliases: ``kldiv``

- ``survival_cox_nll``, negative partial log-likelihood for `Cox proportional hazards <https://en.wikipedia.org/wiki/Proportional_hazards_model>`__ model, aliases: ``cox_nll``

- ``concordance_index``, `Harrell's concordance index <https://doi.org/10.1002/(SICI)1097-0258(19960229)15:4%3C361::AID-SIM168%3E3.0.CO;2-4>`__ for survival models, aliases: ``c_index``

- support multiple metrics, separated by ``,``

- ``metric_freq`` :raw-html:`<a id="metric_freq" title="Permalink to this parameter" href="#metric_freq">&#x1F517;&#xFE0E;</a>`, default = ``1``, type = int, aliases: ``output_freq``, constraints: ``metric_freq > 0``
Expand Down
4 changes: 4 additions & 0 deletions examples/python-guide/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ Examples include:
- Plot split value histogram
- Plot one specified tree
- Plot one specified tree with Graphviz
- [survival_example.py](https://github.com/lightgbm-org/LightGBM/blob/master/examples/python-guide/survival_example.py)
- Construct Dataset
- Use objective `survival_cox` for Cox proportional hazards survival analysis
- Evaluate with `survival_cox_nll` and `concordance_index` metrics
- [dataset_from_multi_hdf5.py](https://github.com/lightgbm-org/LightGBM/blob/master/examples/python-guide/dataset_from_multi_hdf5.py)
- Construct Dataset from multiple HDF5 files
- Avoid loading all data into memory
69 changes: 69 additions & 0 deletions examples/python-guide/survival_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# coding: utf-8
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils import check_random_state

import lightgbm as lgb


def make_survival(*, n_samples=500, n_features=5, random_state=0, censoring_rate=0.3):
    """Generate synthetic survival data with signed-time label convention.

    Event times are exponential with rate ``exp(log_hazard)``, where
    ``log_hazard = X[:, 0] + 0.5 * X[:, 1]``; censoring times are drawn so
    that roughly ``censoring_rate`` of the samples are censored.

    Parameters
    ----------
    n_samples : int, optional (default=500)
        Number of samples to generate.
    n_features : int, optional (default=5)
        Number of features to generate.
    random_state : int, optional (default=0)
        Random seed.
    censoring_rate : float, optional (default=0.3)
        Approximate fraction of samples to censor.

    Returns
    -------
    X : 2-d np.ndarray of shape = [n_samples, n_features]
        Input feature matrix.
    y : 1-d np.array of shape = [n_samples]
        Signed survival times: a positive value is an observed event time,
        a negative value is a censored time.
    """
    # np.random.RandomState(seed) is exactly what
    # sklearn.utils.check_random_state(seed) returns for an int seed, so this
    # avoids an unnecessary sklearn dependency (and keeps output identical).
    rnd_generator = np.random.RandomState(random_state)
    X = rnd_generator.randn(n_samples, n_features)
    log_hazard = X[:, 0] + 0.5 * X[:, 1]
    # exponential() takes a scale (mean), so scale = exp(-log_hazard)
    # corresponds to rate = exp(log_hazard)
    times = rnd_generator.exponential(np.exp(-log_hazard))
    censor_times = rnd_generator.exponential(np.median(times) / censoring_rate, n_samples)
    observed = times <= censor_times
    y = np.where(observed, times, -censor_times)
    return X.astype(np.float64), y.astype(np.float64)


X, y = make_survival()

# Hold out 20% of the rows for validation.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

train_set = lgb.Dataset(X_train, label=y_train)
valid_set = lgb.Dataset(X_val, label=y_val, reference=train_set)

# Cox objective with both survival metrics; the first metric drives
# early stopping (first_metric_only=True below).
train_params = {
    "objective": "survival_cox",
    "metric": ["survival_cox_nll", "concordance_index"],
    "num_leaves": 10,
    "learning_rate": 0.05,
    "verbose": 0,
}

history = {}
booster = lgb.train(
    train_params,
    train_set,
    num_boost_round=200,
    valid_sets=[valid_set],
    valid_names=["val"],
    callbacks=[
        lgb.early_stopping(stopping_rounds=5, first_metric_only=True),
        lgb.record_evaluation(history),
    ],
)

# Predictions are log-hazard ratios (higher = more risk).
risk_scores = booster.predict(X_val, num_iteration=booster.best_iteration)
print(f"\nPrediction range: [{risk_scores.min():.3f}, {risk_scores.max():.3f}]")
13 changes: 12 additions & 1 deletion include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ struct Config {
// [no-automatically-extract]
// [no-save]
// type = enum
// options = regression, regression_l1, huber, fair, poisson, quantile, mape, gamma, tweedie, binary, multiclass, multiclassova, cross_entropy, cross_entropy_lambda, lambdarank, rank_xendcg
// options = regression, regression_l1, huber, fair, poisson, quantile, mape, gamma, tweedie, binary, multiclass, multiclassova, cross_entropy, cross_entropy_lambda, lambdarank, rank_xendcg, survival_cox
// alias = objective_type, app, application, loss
// desc = regression application
// descl2 = ``regression``, L2 loss, aliases: ``regression_l2``, ``l2``, ``mean_squared_error``, ``mse``, ``l2_root``, ``root_mean_squared_error``, ``rmse``
Expand Down Expand Up @@ -160,6 +160,9 @@ struct Config {
// descl2 = ``rank_xendcg``, `XE_NDCG_MART <https://arxiv.org/abs/1911.09798>`__ ranking objective function, aliases: ``xendcg``, ``xe_ndcg``, ``xe_ndcg_mart``, ``xendcg_mart``
// descl2 = ``rank_xendcg`` is faster than and achieves the similar performance as ``lambdarank``
// descl2 = label should be ``int`` type, and larger number represents the higher relevance (e.g. 0:bad, 1:fair, 2:good, 3:perfect)
// desc = survival analysis application
// descl2 = ``survival_cox``, `Cox proportional hazards <https://en.wikipedia.org/wiki/Proportional_hazards_model>`__ partial likelihood with Breslow's method for ties, aliases: ``cox``, ``cox_ph``
// descl2 = label encodes censoring via sign: positive value = event time, negative value = censored time
// desc = custom objective function (gradients and hessians not computed directly by LightGBM)
// descl2 = ``custom``
// descl2 = must be passed through parameters explicitly in the C API
Expand Down Expand Up @@ -1039,6 +1042,8 @@ struct Config {
// descl2 = ``cross_entropy``, cross-entropy (with optional linear weights), aliases: ``xentropy``
// descl2 = ``cross_entropy_lambda``, "intensity-weighted" cross-entropy, aliases: ``xentlambda``
// descl2 = ``kullback_leibler``, `Kullback-Leibler divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`__, aliases: ``kldiv``
// descl2 = ``survival_cox_nll``, negative partial log-likelihood for `Cox proportional hazards <https://en.wikipedia.org/wiki/Proportional_hazards_model>`__ model, aliases: ``cox_nll``
// descl2 = ``concordance_index``, `Harrell's concordance index <https://doi.org/10.1002/(SICI)1097-0258(19960229)15:4%3C361::AID-SIM168%3E3.0.CO;2-4>`__ for survival models, aliases: ``c_index``
// desc = support multiple metrics, separated by ``,``
std::vector<std::string> metric;

Expand Down Expand Up @@ -1293,6 +1298,8 @@ inline std::string ParseObjectiveAlias(const std::string& type) {
} else if (type == std::string("rank_xendcg") || type == std::string("xendcg") || type == std::string("xe_ndcg")
|| type == std::string("xe_ndcg_mart") || type == std::string("xendcg_mart")) {
return "rank_xendcg";
} else if (type == std::string("survival_cox") || type == std::string("cox") || type == std::string("cox_ph")) {
return "survival_cox";
} else if (type == std::string("none") || type == std::string("null") || type == std::string("custom") || type == std::string("na")) {
return "custom";
}
Expand Down Expand Up @@ -1323,6 +1330,10 @@ inline std::string ParseMetricAlias(const std::string& type) {
return "kullback_leibler";
} else if (type == std::string("mean_absolute_percentage_error") || type == std::string("mape")) {
return "mape";
} else if (type == std::string("survival_cox") || type == std::string("survival_cox_nll") || type == std::string("cox") || type == std::string("cox_ph") || type == std::string("cox_nll")) {
return "survival_cox_nll";
} else if (type == std::string("c_index") || type == std::string("concordance_index")) {
return "concordance_index";
} else if (type == std::string("none") || type == std::string("null") || type == std::string("custom") || type == std::string("na")) {
return "custom";
}
Expand Down
3 changes: 2 additions & 1 deletion python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5308,5 +5308,6 @@ def __get_eval_info(self) -> None:
)
self.__name_inner_eval = [string_buffers[i].value.decode("utf-8") for i in range(self.__num_inner_eval)]
self.__higher_better_inner_eval = [
name.startswith(("auc", "ndcg@", "map@", "average_precision")) for name in self.__name_inner_eval
name.startswith(("auc", "ndcg@", "map@", "average_precision", "concordance_index"))
for name in self.__name_inner_eval
]
Loading
Loading