Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 49 additions & 7 deletions R/model-catboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -573,16 +573,58 @@ get_catboost_tree <- function(
return(make_catboost_stump(leaf_values, class_idx, num_class))
}

# Pre-extract split info once per tree (avoids repeated lookups per leaf)
split_info <- lapply(splits, function(split) {
split_type <- split$split_type %||% "FloatFeature"
if (split_type == "OneHotFeature") {
cat_feature_index <- split$cat_feature_index + 1L
cat_feature_info <- cat_features[[cat_feature_index]]
list(
type = "categorical",
col = cat_feature_info$feature_id %||%
paste0("cat_feature_", cat_feature_info$flat_feature_index),
hash_value = split$value,
is_categorical = TRUE
)
} else {
feature_index <- split$float_feature_index + 1L
feature_info <- float_features[[feature_index]]
list(
type = "conditional",
col = feature_info$feature_id %||%
paste0("feature_", feature_info$flat_feature_index),
val = split$border,
nan_treatment = feature_info$nan_value_treatment %||% "AsIs",
is_categorical = FALSE
)
}
})

n_leaves <- 2^n_splits
map(seq_len(n_leaves) - 1L, function(leaf_idx) {
path <- lapply(seq_len(n_splits), function(split_idx) {
parse_catboost_split(
splits[[split_idx]],
leaf_idx,
split_idx,
float_features,
cat_features
)
info <- split_info[[split_idx]]
bit_val <- get_catboost_bit_value(leaf_idx, split_idx)

if (info$is_categorical) {
op <- if (bit_val == 0L) "not-equal" else "equal"
list(
type = "categorical",
col = info$col,
hash_value = info$hash_value,
op = op,
missing = FALSE
)
} else {
op <- if (bit_val == 1L) "more" else "less-equal"
list(
type = "conditional",
col = info$col,
val = info$val,
op = op,
missing = get_catboost_missing(info$nan_treatment, op)
)
}
})

list(
Expand Down
39 changes: 26 additions & 13 deletions R/model-cubist.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,16 @@
parse_model.cubist <- function(model) {
coefs <- model$coefficients
splits <- model$splits
splits$variable <- as.character(splits$variable)
splits$dir <- as.character(splits$dir)
if (!is.null(splits)) {
splits$variable <- as.character(splits$variable)
splits$dir <- as.character(splits$dir)
}

# Pre-split data by committee and rule to avoid O(n) scans in nested loops
coefs_by_comm_rule <- split(coefs, list(coefs$committee, coefs$rule))
if (!is.null(splits)) {
splits_by_comm_rule <- split(splits, list(splits$committee, splits$rule))
}

committees2 <- map(
unique(coefs$committee),
Expand All @@ -12,19 +20,24 @@ parse_model.cubist <- function(model) {
rules <- map(
coefs$rule[coefs$committee == comm],
~ {
cc <- coefs[coefs$rule == .x & coefs$committee == comm, ]
key <- paste(comm, .x, sep = ".")
cc <- coefs_by_comm_rule[[key]]
if (!is.null(model$splits)) {
cs <- splits[splits$rule == .x & splits$committee == comm, ]
tcs <- transpose(cs)
mcs <- map(
tcs,
~ list(
type = "conditional",
col = .x$variable,
val = .x$value,
op = ifelse(.x$dir == ">", "more", "less-equal")
cs <- splits_by_comm_rule[[key]]
if (!is.null(cs) && nrow(cs) > 0) {
tcs <- transpose(cs)
mcs <- map(
tcs,
~ list(
type = "conditional",
col = .x$variable,
val = .x$value,
op = ifelse(.x$dir == ">", "more", "less-equal")
)
)
)
} else {
mcs <- list(list(type = "all"))
}
} else {
mcs <- list(list(type = "all"))
}
Expand Down
147 changes: 105 additions & 42 deletions R/model-lightgbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ parse_lgb_linear_trees <- function(model, feature_names) {
}

for (line in lines) {
if (grepl("^Tree=", line)) {
if (startsWith(line, "Tree=")) {
save_tree_linear_info()
# Start new tree
current_tree <- as.integer(sub("^Tree=", "", line))
Expand All @@ -86,17 +86,17 @@ parse_lgb_linear_trees <- function(model, feature_names) {
num_features <- NULL
leaf_features <- NULL
leaf_coeff <- NULL
} else if (grepl("^is_linear=1", line)) {
} else if (startsWith(line, "is_linear=1")) {
is_linear <- TRUE
} else if (grepl("^leaf_const=", line)) {
} else if (startsWith(line, "leaf_const=")) {
leaf_const <- sub("^leaf_const=", "", line)
} else if (grepl("^num_features=", line)) {
} else if (startsWith(line, "num_features=")) {
num_features <- sub("^num_features=", "", line)
} else if (grepl("^leaf_features=", line)) {
} else if (startsWith(line, "leaf_features=")) {
leaf_features <- sub("^leaf_features=", "", line)
} else if (grepl("^leaf_coeff=", line)) {
} else if (startsWith(line, "leaf_coeff=")) {
leaf_coeff <- sub("^leaf_coeff=", "", line)
} else if (grepl("^end of trees", line)) {
} else if (startsWith(line, "end of trees")) {
save_tree_linear_info()
break
}
Expand Down Expand Up @@ -208,87 +208,141 @@ get_lgb_tree <- function(tree_df, linear_info = NULL) {
# Build children map for direction detection
children_map <- get_lgb_children_map(tree_df)

# Pre-extract columns as vectors for fast indexing
leaf_index <- tree_df$leaf_index
leaf_value <- tree_df$leaf_value
leaf_parent <- tree_df$leaf_parent
split_index <- tree_df$split_index
node_parent <- tree_df$node_parent
decision_type <- tree_df$decision_type
default_left <- tree_df$default_left == "TRUE"
split_feature <- tree_df$split_feature
threshold <- tree_df$threshold

# Build split_index to row lookup (avoid repeated which() calls)
max_split_idx <- suppressWarnings(max(split_index, na.rm = TRUE))
if (is.finite(max_split_idx)) {
split_idx_to_row <- integer(max_split_idx + 1)
for (i in seq_along(split_index)) {
si <- split_index[i]
if (!is.na(si)) {
split_idx_to_row[si + 1L] <- i
}
}
} else {
# No splits (stump tree) - empty lookup
split_idx_to_row <- integer(0)
}

# Find leaf rows
leaf_rows <- which(!is.na(tree_df$leaf_index))
leaf_rows <- which(!is.na(leaf_index))

# For each leaf, trace path to root
map(leaf_rows, function(leaf_row) {
leaf_idx <- tree_df$leaf_index[[leaf_row]]
leaf_idx <- leaf_index[leaf_row]
leaf_idx_str <- as.character(leaf_idx)
leaf_value <- tree_df$leaf_value[[leaf_row]]
leaf_val <- leaf_value[leaf_row]

# Check if this tree has linear info for this leaf
if (!is.null(linear_info) && leaf_idx_str %in% names(linear_info)) {
leaf_linear <- linear_info[[leaf_idx_str]]
# Store both linear info and fallback value (used when features are NA)
leaf_linear$fallback <- leaf_value
leaf_linear$fallback <- leaf_val
list(
prediction = NULL,
linear = leaf_linear,
path = get_lgb_path(leaf_row, tree_df, children_map)
path = get_lgb_path_fast(
leaf_row,
leaf_parent,
split_idx_to_row,
node_parent,
decision_type,
default_left,
split_feature,
threshold,
children_map
)
)
} else {
list(
prediction = leaf_value,
prediction = leaf_val,
linear = NULL,
path = get_lgb_path(leaf_row, tree_df, children_map)
path = get_lgb_path_fast(
leaf_row,
leaf_parent,
split_idx_to_row,
node_parent,
decision_type,
default_left,
split_feature,
threshold,
children_map
)
)
}
})
}

get_lgb_path <- function(leaf_row, tree_df, children_map) {
# Fast path extraction using pre-extracted vectors
get_lgb_path_fast <- function(
leaf_row,
leaf_parent,
split_idx_to_row,
node_parent,
decision_type,
default_left,
split_feature,
threshold,
children_map
) {
path <- list()
current_row <- leaf_row
current_parent_split <- tree_df$leaf_parent[[leaf_row]]
current_parent_split <- leaf_parent[leaf_row]

while (!is.na(current_parent_split)) {
# Find the parent's row (split_index should be unique within a tree)
parent_row <- which(tree_df$split_index == current_parent_split)[[1]]
# Look up parent row directly (O(1) instead of O(n))
parent_row <- split_idx_to_row[current_parent_split + 1L]

# Determine direction: is current_row the LEFT or RIGHT child?
children <- children_map[[as.character(current_parent_split)]]
is_left_child <- (current_row == children[1])

# Build condition based on decision type
decision_type <- tree_df$decision_type[[parent_row]]
default_left <- tree_df$default_left[[parent_row]] == "TRUE"
dec_type <- decision_type[parent_row]
def_left <- default_left[parent_row]

if (decision_type == "<=") {
if (dec_type == "<=") {
# Numerical split
if (is_left_child) {
op <- "less-equal"
missing_with_us <- default_left
missing_with_us <- def_left
} else {
op <- "more"
missing_with_us <- !default_left
missing_with_us <- !def_left
}

condition <- list(
type = "conditional",
col = tree_df$split_feature[[parent_row]],
val = tree_df$threshold[[parent_row]],
col = split_feature[parent_row],
val = threshold[parent_row],
op = op,
missing = missing_with_us
)
} else if (decision_type == "==") {
} else if (dec_type == "==") {
# Categorical split: threshold is "0||1||3" format
# LEFT = category IN set, RIGHT = category NOT IN set
category_set <- parse_lgb_categorical_threshold(
tree_df$threshold[[parent_row]]
)
category_set <- parse_lgb_categorical_threshold(threshold[parent_row])

if (is_left_child) {
op <- "in"
missing_with_us <- default_left
missing_with_us <- def_left
} else {
op <- "not-in"
missing_with_us <- !default_left
missing_with_us <- !def_left
}

condition <- list(
type = "set",
col = tree_df$split_feature[[parent_row]],
col = split_feature[parent_row],
vals = category_set,
op = op,
missing = missing_with_us
Expand All @@ -299,7 +353,7 @@ get_lgb_path <- function(leaf_row, tree_df, children_map) {

# Move up the tree
current_row <- parent_row
current_parent_split <- tree_df$node_parent[[parent_row]]
current_parent_split <- node_parent[parent_row]
}

rev(path) # Reverse to get root-to-leaf order
Expand Down Expand Up @@ -442,8 +496,8 @@ build_fit_formula_lgb_nested <- function(parsedmodel, model) {
))
}

# Extract nested trees
trees <- extract_lgb_trees_nested(model)
# Extract nested trees (pass feature_names to avoid redundant JSON parsing)
trees <- extract_lgb_trees_nested(model, parsedmodel$general$feature_names)

# RF boosting averages trees instead of summing
boosting <- parsedmodel$general$params$boosting
Expand All @@ -467,7 +521,7 @@ build_fit_formula_lgb_multiclass_nested <- function(
cli::cli_abort("Multiclass model must have num_class >= 2.")
}

trees <- extract_lgb_trees_nested(model)
trees <- extract_lgb_trees_nested(model, parsedmodel$general$feature_names)
apply_lgb_multiclass_transformation(trees, num_class, objective)
}

Expand Down Expand Up @@ -553,14 +607,23 @@ build_lgb_nested_condition <- function(path_elem) {
}

# Extract trees in nested format
extract_lgb_trees_nested <- function(model) {
# feature_names and linear_info can be passed to avoid redundant JSON/string parsing
extract_lgb_trees_nested <- function(
model,
feature_names = NULL,
linear_info = NULL
) {
trees_df <- lightgbm::lgb.model.dt.tree(model)
trees_df <- as.data.frame(trees_df)

# Extract linear tree info
model_json <- jsonlite::fromJSON(model$dump_model())
feature_names <- model_json$feature_names
linear_info <- parse_lgb_linear_trees(model, feature_names)
# Extract linear tree info (only if not provided)
if (is.null(feature_names)) {
model_json <- jsonlite::fromJSON(model$dump_model())
feature_names <- model_json$feature_names
}
if (is.null(linear_info)) {
linear_info <- parse_lgb_linear_trees(model, feature_names)
}

trees_split <- split(trees_df, trees_df$tree_index)

Expand Down
Loading