From 606b46040aba9e6d309554f0191032b00d0aa44d Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 23 Feb 2026 08:49:02 -0800 Subject: [PATCH] speed up the model code --- R/model-catboost.R | 56 ++++++++++++++--- R/model-cubist.R | 39 ++++++++---- R/model-lightgbm.R | 147 ++++++++++++++++++++++++++++++++------------- R/model-partykit.R | 109 +++++++++++++++++++-------------- R/model-ranger.R | 137 ++++++++++++++++++++++-------------------- R/model-rf.R | 127 +++++++++++++++++++-------------------- R/model-xgboost.R | 121 ++++++++++++++++++++++++------------- 7 files changed, 458 insertions(+), 278 deletions(-) diff --git a/R/model-catboost.R b/R/model-catboost.R index 98f981c..d5645ad 100644 --- a/R/model-catboost.R +++ b/R/model-catboost.R @@ -573,16 +573,58 @@ get_catboost_tree <- function( return(make_catboost_stump(leaf_values, class_idx, num_class)) } + # Pre-extract split info once per tree (avoids repeated lookups per leaf) + split_info <- lapply(splits, function(split) { + split_type <- split$split_type %||% "FloatFeature" + if (split_type == "OneHotFeature") { + cat_feature_index <- split$cat_feature_index + 1L + cat_feature_info <- cat_features[[cat_feature_index]] + list( + type = "categorical", + col = cat_feature_info$feature_id %||% + paste0("cat_feature_", cat_feature_info$flat_feature_index), + hash_value = split$value, + is_categorical = TRUE + ) + } else { + feature_index <- split$float_feature_index + 1L + feature_info <- float_features[[feature_index]] + list( + type = "conditional", + col = feature_info$feature_id %||% + paste0("feature_", feature_info$flat_feature_index), + val = split$border, + nan_treatment = feature_info$nan_value_treatment %||% "AsIs", + is_categorical = FALSE + ) + } + }) + n_leaves <- 2^n_splits map(seq_len(n_leaves) - 1L, function(leaf_idx) { path <- lapply(seq_len(n_splits), function(split_idx) { - parse_catboost_split( - splits[[split_idx]], - leaf_idx, - split_idx, - float_features, - cat_features - ) + info <- split_info[[split_idx]] + bit_val <- get_catboost_bit_value(leaf_idx, split_idx) + + if (info$is_categorical) { + op <- if (bit_val == 0L) "not-equal" else "equal" + list( + type = "categorical", + col = info$col, + hash_value = info$hash_value, + op = op, + missing = FALSE + ) + } else { + op <- if (bit_val == 1L) "more" else "less-equal" + list( + type = "conditional", + col = info$col, + val = info$val, + op = op, + missing = get_catboost_missing(info$nan_treatment, op) + ) + } }) list( diff --git a/R/model-cubist.R b/R/model-cubist.R index 4e21a36..3e0eec9 100644 --- a/R/model-cubist.R +++ b/R/model-cubist.R @@ -2,8 +2,16 @@ parse_model.cubist <- function(model) { coefs <- model$coefficients splits <- model$splits - splits$variable <- as.character(splits$variable) - splits$dir <- as.character(splits$dir) + if (!is.null(splits)) { + splits$variable <- as.character(splits$variable) + splits$dir <- as.character(splits$dir) + } + + # Pre-split data by committee and rule to avoid O(n) scans in nested loops + coefs_by_comm_rule <- split(coefs, list(coefs$committee, coefs$rule)) + if (!is.null(splits)) { + splits_by_comm_rule <- split(splits, list(splits$committee, splits$rule)) + } committees2 <- map( unique(coefs$committee), @@ -12,19 +20,24 @@ parse_model.cubist <- function(model) { rules <- map( coefs$rule[coefs$committee == comm], ~ { - cc <- coefs[coefs$rule == .x & coefs$committee == comm, ] + key <- paste(comm, .x, sep = ".") + cc <- coefs_by_comm_rule[[key]] if (!is.null(model$splits)) { - cs <- splits[splits$rule == .x & splits$committee == comm, ] - tcs <- transpose(cs) - mcs <- map( - tcs, - ~ list( - type = "conditional", - col = .x$variable, - val = .x$value, - op = ifelse(.x$dir == ">", "more", "less-equal") + cs <- splits_by_comm_rule[[key]] + if (!is.null(cs) && nrow(cs) > 0) { + tcs <- transpose(cs) + mcs <- map( + tcs, + ~ list( + type = "conditional", + col = .x$variable, + val = .x$value, + op = ifelse(.x$dir == ">", "more", "less-equal") + ) ) - ) + } else { + mcs <- list(list(type = "all")) + } } else { mcs <- list(list(type = "all")) } diff --git a/R/model-lightgbm.R b/R/model-lightgbm.R index 6b277c1..57b49e7 100644 --- a/R/model-lightgbm.R +++ b/R/model-lightgbm.R @@ -77,7 +77,7 @@ parse_lgb_linear_trees <- function(model, feature_names) { } for (line in lines) { - if (grepl("^Tree=", line)) { + if (startsWith(line, "Tree=")) { save_tree_linear_info() # Start new tree current_tree <- as.integer(sub("^Tree=", "", line)) @@ -86,17 +86,17 @@ parse_lgb_linear_trees <- function(model, feature_names) { num_features <- NULL leaf_features <- NULL leaf_coeff <- NULL - } else if (grepl("^is_linear=1", line)) { + } else if (startsWith(line, "is_linear=1")) { is_linear <- TRUE - } else if (grepl("^leaf_const=", line)) { + } else if (startsWith(line, "leaf_const=")) { leaf_const <- sub("^leaf_const=", "", line) - } else if (grepl("^num_features=", line)) { + } else if (startsWith(line, "num_features=")) { num_features <- sub("^num_features=", "", line) - } else if (grepl("^leaf_features=", line)) { + } else if (startsWith(line, "leaf_features=")) { leaf_features <- sub("^leaf_features=", "", line) - } else if (grepl("^leaf_coeff=", line)) { + } else if (startsWith(line, "leaf_coeff=")) { leaf_coeff <- sub("^leaf_coeff=", "", line) - } else if (grepl("^end of trees", line)) { + } else if (startsWith(line, "end of trees")) { save_tree_linear_info() break } @@ -208,87 +208,141 @@ get_lgb_tree <- function(tree_df, linear_info = NULL) { # Build children map for direction detection children_map <- get_lgb_children_map(tree_df) + # Pre-extract columns as vectors for fast indexing + leaf_index <- tree_df$leaf_index + leaf_value <- tree_df$leaf_value + leaf_parent <- tree_df$leaf_parent + split_index <- tree_df$split_index + node_parent <- tree_df$node_parent + decision_type <- tree_df$decision_type + default_left <- tree_df$default_left == "TRUE" + split_feature <- tree_df$split_feature + threshold <- tree_df$threshold + + # Build split_index to row lookup (avoid repeated which() calls) + max_split_idx <- suppressWarnings(max(split_index, na.rm = TRUE)) + if (is.finite(max_split_idx)) { + split_idx_to_row <- integer(max_split_idx + 1) + for (i in seq_along(split_index)) { + si <- split_index[i] + if (!is.na(si)) { + split_idx_to_row[si + 1L] <- i + } + } + } else { + # No splits (stump tree) - empty lookup + split_idx_to_row <- integer(0) + } + # Find leaf rows - leaf_rows <- which(!is.na(tree_df$leaf_index)) + leaf_rows <- which(!is.na(leaf_index)) # For each leaf, trace path to root map(leaf_rows, function(leaf_row) { - leaf_idx <- tree_df$leaf_index[[leaf_row]] + leaf_idx <- leaf_index[leaf_row] leaf_idx_str <- as.character(leaf_idx) - leaf_value <- tree_df$leaf_value[[leaf_row]] + leaf_val <- leaf_value[leaf_row] # Check if this tree has linear info for this leaf if (!is.null(linear_info) && leaf_idx_str %in% names(linear_info)) { leaf_linear <- linear_info[[leaf_idx_str]] # Store both linear info and fallback value (used when features are NA) - leaf_linear$fallback <- leaf_value + leaf_linear$fallback <- leaf_val list( prediction = NULL, linear = leaf_linear, - path = get_lgb_path(leaf_row, tree_df, children_map) + path = get_lgb_path_fast( + leaf_row, + leaf_parent, + split_idx_to_row, + node_parent, + decision_type, + default_left, + split_feature, + threshold, + children_map + ) ) } else { list( - prediction = leaf_value, + prediction = leaf_val, linear = NULL, - path = get_lgb_path(leaf_row, tree_df, children_map) + path = get_lgb_path_fast( + leaf_row, + leaf_parent, + split_idx_to_row, + node_parent, + decision_type, + default_left, + split_feature, + threshold, + children_map + ) ) } }) } -get_lgb_path <- function(leaf_row, tree_df, children_map) { +# Fast path extraction using pre-extracted vectors +get_lgb_path_fast <- function( + leaf_row, + leaf_parent, + split_idx_to_row, + node_parent, + decision_type, + default_left, + split_feature, + threshold, + children_map +) { path <- list() current_row <- leaf_row - current_parent_split <- tree_df$leaf_parent[[leaf_row]] + current_parent_split <- leaf_parent[leaf_row] while (!is.na(current_parent_split)) { - # Find the parent's row (split_index should be unique within a tree) - parent_row <- which(tree_df$split_index == current_parent_split)[[1]] + # Look up parent row directly (O(1) instead of O(n)) + parent_row <- split_idx_to_row[current_parent_split + 1L] # Determine direction: is current_row the LEFT or RIGHT child? children <- children_map[[as.character(current_parent_split)]] is_left_child <- (current_row == children[1]) # Build condition based on decision type - decision_type <- tree_df$decision_type[[parent_row]] - default_left <- tree_df$default_left[[parent_row]] == "TRUE" + dec_type <- decision_type[parent_row] + def_left <- default_left[parent_row] - if (decision_type == "<=") { + if (dec_type == "<=") { # Numerical split if (is_left_child) { op <- "less-equal" - missing_with_us <- default_left + missing_with_us <- def_left } else { op <- "more" - missing_with_us <- !default_left + missing_with_us <- !def_left } condition <- list( type = "conditional", - col = tree_df$split_feature[[parent_row]], - val = tree_df$threshold[[parent_row]], + col = split_feature[parent_row], + val = threshold[parent_row], op = op, missing = missing_with_us ) - } else if (decision_type == "==") { + } else if (dec_type == "==") { # Categorical split: threshold is "0||1||3" format - # LEFT = category IN set, RIGHT = category NOT IN set - category_set <- parse_lgb_categorical_threshold( - tree_df$threshold[[parent_row]] - ) + category_set <- parse_lgb_categorical_threshold(threshold[parent_row]) if (is_left_child) { op <- "in" - missing_with_us <- default_left + missing_with_us <- def_left } else { op <- "not-in" - missing_with_us <- !default_left + missing_with_us <- !def_left } condition <- list( type = "set", - col = tree_df$split_feature[[parent_row]], + col = split_feature[parent_row], vals = category_set, op = op, missing = missing_with_us @@ -299,7 +353,7 @@ get_lgb_path <- function(leaf_row, tree_df, children_map) { # Move up the tree current_row <- parent_row - current_parent_split <- tree_df$node_parent[[parent_row]] + current_parent_split <- node_parent[parent_row] } rev(path) # Reverse to get root-to-leaf order @@ -442,8 +496,8 @@ build_fit_formula_lgb_nested <- function(parsedmodel, model) { )) } - # Extract nested trees - trees <- extract_lgb_trees_nested(model) + # Extract nested trees (pass feature_names to avoid redundant JSON parsing) + trees <- extract_lgb_trees_nested(model, parsedmodel$general$feature_names) # RF boosting averages trees instead of summing boosting <- parsedmodel$general$params$boosting @@ -467,7 +521,7 @@ build_fit_formula_lgb_multiclass_nested <- function( cli::cli_abort("Multiclass model must have num_class >= 2.") } - trees <- extract_lgb_trees_nested(model) + trees <- extract_lgb_trees_nested(model, parsedmodel$general$feature_names) apply_lgb_multiclass_transformation(trees, num_class, objective) } @@ -553,14 +607,23 @@ build_lgb_nested_condition <- function(path_elem) { } # Extract trees in nested format -extract_lgb_trees_nested <- function(model) { +# feature_names and linear_info can be passed to avoid redundant JSON/string parsing +extract_lgb_trees_nested <- function( + model, + feature_names = NULL, + linear_info = NULL +) { trees_df <- lightgbm::lgb.model.dt.tree(model) trees_df <- as.data.frame(trees_df) - # Extract linear tree info - model_json <- jsonlite::fromJSON(model$dump_model()) - feature_names <- model_json$feature_names - linear_info <- parse_lgb_linear_trees(model, feature_names) + # Extract linear tree info (only if not provided) + if (is.null(feature_names)) { + model_json <- jsonlite::fromJSON(model$dump_model()) + feature_names <- model_json$feature_names + } + if (is.null(linear_info)) { + linear_info <- parse_lgb_linear_trees(model, feature_names) + } trees_split <- split(trees_df, trees_df$tree_index) diff --git a/R/model-partykit.R b/R/model-partykit.R index 831e31b..32077ee 100644 --- a/R/model-partykit.R +++ b/R/model-partykit.R @@ -62,34 +62,63 @@ partykit_tree_info_full <- function(model) { } partykit_tree_info <- function(model) { - model_nodes <- map(seq_along(model), ~ model[[.x]]) - is_split <- map_lgl(model_nodes, ~ class(.x$node[1]) == "partynode") - if (is.numeric(model_nodes[[1]]$fitted[["(response)"]])) { - mean_resp <- map_dbl(model_nodes, ~ mean(.x$fitted[, "(response)"])) - prediction <- ifelse(!is_split, mean_resp, NA) + # Get all node IDs at once (avoids repeated tree traversals) + all_node_ids <- partykit::nodeids(model) + n_nodes <- length(all_node_ids) + + # Extract all nodes at once using nodeapply with all IDs + all_nodes <- partykit::nodeapply(model, ids = all_node_ids, FUN = identity) + + # Pre-extract node properties to avoid repeated list access + is_split <- logical(n_nodes) + splitvarID <- integer(n_nodes) + splitval <- numeric(n_nodes) + split_index <- vector("list", n_nodes) + left_child <- integer(n_nodes) + right_child <- integer(n_nodes) + + for (i in seq_len(n_nodes)) { + node <- all_nodes[[i]] + is_split[i] <- !partykit::is.terminal(node) + if (is_split[i]) { + splitvarID[i] <- node$split$varid + splitval[i] <- node$split$breaks %||% NA_real_ + split_index[[i]] <- node$split$index + kids <- partykit::kids_node(node) + left_child[i] <- partykit::id_node(kids[[1]]) + right_child[i] <- partykit::id_node(kids[[2]]) + } else { + splitvarID[i] <- NA_integer_ + splitval[i] <- NA_real_ + left_child[i] <- NA_integer_ + right_child[i] <- NA_integer_ + } + } + + # Extract predictions from fitted data (only need to access once) + fitted_data <- model$fitted + response_col <- fitted_data[["(response)"]] + node_col <- fitted_data[["(fitted)"]] + + if (is.numeric(response_col)) { + # Regression: compute mean per node + node_means <- tapply(response_col, node_col, mean) + prediction <- ifelse(!is_split, node_means[as.character(all_node_ids)], NA) } else { + # Classification: compute mode per node stat_mode <- function(x) { counts <- rev(sort(table(x))) - if (counts[[1]] == counts[[2]]) { + if (length(counts) > 1 && counts[[1]] == counts[[2]]) { ties <- counts[counts[1] == counts] return(names(rev(ties))[1]) } names(counts)[1] } - mode_resp <- map_chr(model_nodes, ~ stat_mode(.x$fitted[, "(response)"])) - prediction <- ifelse(!is_split, mode_resp, NA) + node_modes <- tapply(response_col, node_col, stat_mode) + prediction <- ifelse(!is_split, node_modes[as.character(all_node_ids)], NA) } - party_nodes <- map(seq_along(model), ~ partykit::nodeapply(model, .x)) - - kids <- map( - party_nodes, - ~ { - if (length(.x[[1]]$kids)) { - map(.x[[1]]$kids, ~ .x$id) - } - } - ) + # Get variable info vars <- as.character(attr(model$terms, "variables")) vars <- vars[2:length(vars)] @@ -97,49 +126,41 @@ partykit_tree_info <- function(model) { var_class <- as.character(var_details) var_name <- names(var_details) - splitvarID <- map_int( - model_nodes, - ~ ifelse(is.null(.x$node$split$varid), NA, .x$node$split$varid) - ) - + # Build categorical split strings if (length(var_class) > 0) { - class_splits <- map_chr( - seq_along(splitvarID), - ~ { - if (is.na(splitvarID[.x])) { - return(NA) - } - v <- vars[splitvarID[.x]] + class_splits <- character(n_nodes) + for (i in seq_len(n_nodes)) { + if (is.na(splitvarID[i])) { + class_splits[i] <- NA_character_ + } else { + v <- vars[splitvarID[i]] if (var_class[var_name == v] == "factor") { lvls <- levels(model$data[, colnames(model$data) == v]) - pn <- party_nodes[[.x]][[1]]$split$index + pn <- split_index[[i]] pn <- ifelse(is.na(pn), 0, pn) if (any(pn == 3)) { cli::cli_abort("Three levels are not supported.") } - paste0(lvls[pn == 1], collapse = ", ") + class_splits[i] <- paste0(lvls[pn == 1], collapse = ", ") } else { - NA + class_splits[i] <- NA_character_ } } - ) + } } else { - class_splits <- NA + class_splits <- rep(NA_character_, n_nodes) } data.frame( - nodeID = seq_along(is_split) - 1, - leftChild = map_int(kids, ~ ifelse(is.null(.x[[1]]), NA, .x[[1]])) - 1, - rightChild = map_int(kids, ~ ifelse(is.null(.x[[2]]), NA, .x[[2]])) - 1, - splitvarID, + nodeID = all_node_ids - 1L, + leftChild = left_child - 1L, + rightChild = right_child - 1L, + splitvarID = splitvarID, splitvarName = vars[splitvarID], - splitval = map_dbl( - model_nodes, - ~ ifelse(is.null(.x$node$split$breaks), NA, .x$node$split$breaks) - ), + splitval = splitval, splitclass = class_splits, terminal = !is_split, - prediction + prediction = prediction ) } diff --git a/R/model-ranger.R b/R/model-ranger.R index 95df923..b8c1f53 100644 --- a/R/model-ranger.R +++ b/R/model-ranger.R @@ -106,43 +106,46 @@ tidypredict_fit_ranger_nested <- function(model) { # Build nested case_when for a single ranger tree build_nested_ranger_tree <- function(model, tree_no) { tree <- ranger::treeInfo(model, tree_no) - build_nested_ranger_node(0L, tree) -} -# Recursively build nested case_when for ranger node -build_nested_ranger_node <- function(node_id, tree) { - # node_id is 0-indexed in ranger - row <- tree[tree$nodeID == node_id, ] + # Pre-extract columns as vectors for fast indexing (avoids slow df[i,] access) + nodeID <- tree$nodeID + leftChild <- tree$leftChild + rightChild <- tree$rightChild + splitvarName <- as.character(tree$splitvarName) + splitval <- tree$splitval + terminal <- tree$terminal + prediction <- tree$prediction + splitclass <- tree$splitclass - # Check if terminal (leaf) node - if (row$terminal) { - return(row$prediction) - } + build_node <- function(node_id) { + # node_id is 0-indexed, convert to 1-indexed for vector access + idx <- node_id + 1L + + if (terminal[idx]) { + return(prediction[idx]) + } + + left_id <- leftChild[idx] + right_id <- rightChild[idx] + split_var <- splitvarName[idx] + split_val <- splitval[idx] + + left_subtree <- build_node(left_id) + right_subtree <- build_node(right_id) + + col_sym <- rlang::sym(split_var) + + if (is.na(split_val)) { + cats <- strsplit(as.character(splitclass[idx]), ", ")[[1]] + condition <- expr(!!col_sym %in% !!cats) + } else { + condition <- expr(!!col_sym <= !!split_val) + } - # Internal node - get split info - left_id <- row$leftChild - right_id <- row$rightChild - split_var <- row$splitvarName - split_val <- row$splitval - - # Recurse - left_subtree <- build_nested_ranger_node(left_id, tree) - right_subtree <- build_nested_ranger_node(right_id, tree) - - col_sym <- rlang::sym(as.character(split_var)) - - # Check if categorical split (splitval is NA for categorical) - if (is.na(split_val)) { - # Categorical split - split_class <- row$splitclass - cats <- strsplit(as.character(split_class), ", ")[[1]] - condition <- expr(!!col_sym %in% !!cats) - } else { - # Numeric split: left = <= splitval, right = > splitval - condition <- expr(!!col_sym <= !!split_val) + expr(case_when(!!condition ~ !!left_subtree, .default = !!right_subtree)) } - expr(case_when(!!condition ~ !!left_subtree, .default = !!right_subtree)) + build_node(0L) } # Legacy flat case_when (for v1/v2 parsed model compatibility) ---------------- @@ -364,43 +367,45 @@ get_ra_trees <- function(model) { # Build nested case_when for ranger probability tree build_nested_ranger_prob_tree <- function(model, tree_no, class_level) { tree <- ranger::treeInfo(model, tree_no) - build_nested_ranger_prob_node(0L, tree, class_level) -} -# Recursively build nested case_when for ranger probability node -build_nested_ranger_prob_node <- function(node_id, tree, class_level) { - # node_id is 0-indexed in ranger - row <- tree[tree$nodeID == node_id, ] + # Pre-extract columns as vectors for fast indexing (avoids slow df[i,] access) + nodeID <- tree$nodeID + leftChild <- tree$leftChild + rightChild <- tree$rightChild + splitvarName <- as.character(tree$splitvarName) + splitval <- tree$splitval + terminal <- tree$terminal + splitclass <- tree$splitclass + prob_col <- paste0("pred.", class_level) + prob_vals <- tree[[prob_col]] - # Check if terminal (leaf) node - if (row$terminal) { - # Get probability for the specific class - prob_col <- paste0("pred.", class_level) - return(row[[prob_col]]) - } + build_node <- function(node_id) { + # node_id is 0-indexed, convert to 1-indexed for vector access + idx <- node_id + 1L + + if (terminal[idx]) { + return(prob_vals[idx]) + } + + left_id <- leftChild[idx] + right_id <- rightChild[idx] + split_var <- splitvarName[idx] + split_val <- splitval[idx] + + left_subtree <- build_node(left_id) + right_subtree <- build_node(right_id) + + col_sym <- rlang::sym(split_var) + + if (is.na(split_val)) { + cats <- strsplit(as.character(splitclass[idx]), ", ")[[1]] + condition <- expr(!!col_sym %in% !!cats) + } else { + condition <- expr(!!col_sym <= !!split_val) + } - # Internal node - get split info - left_id <- row$leftChild - right_id <- row$rightChild - split_var <- row$splitvarName - split_val <- row$splitval - - # Recurse - left_subtree <- build_nested_ranger_prob_node(left_id, tree, class_level) - right_subtree <- build_nested_ranger_prob_node(right_id, tree, class_level) - - col_sym <- rlang::sym(as.character(split_var)) - - # Check if categorical split (splitval is NA for categorical) - if (is.na(split_val)) { - # Categorical split - split_class <- row$splitclass - cats <- strsplit(as.character(split_class), ", ")[[1]] - condition <- expr(!!col_sym %in% !!cats) - } else { - # Numeric split: left = <= splitval, right = > splitval - condition <- expr(!!col_sym <= !!split_val) + expr(case_when(!!condition ~ !!left_subtree, .default = !!right_subtree)) } - expr(case_when(!!condition ~ !!left_subtree, .default = !!right_subtree)) + build_node(0L) } diff --git a/R/model-rf.R b/R/model-rf.R index d5f11ad..d2f6bda 100644 --- a/R/model-rf.R +++ b/R/model-rf.R @@ -107,35 +107,42 @@ tidypredict_fit_rf_nested <- function(model) { # Build nested case_when for a single randomForest tree build_nested_rf_tree <- function(model, tree_no, term_labels) { tree <- randomForest::getTree(model, tree_no) - build_nested_rf_node(1L, tree, term_labels) -} -# Recursively build nested case_when for randomForest node -build_nested_rf_node <- function(node_id, tree, term_labels) { - row <- tree[node_id, ] + # Pre-extract columns as vectors for fast indexing (avoids slow row access) + # Use unname() once here instead of on every recursive call + status <- unname(tree[, "status"]) + prediction <- unname(tree[, "prediction"]) + left_daughter <- unname(tree[, "left daughter"]) + right_daughter <- unname(tree[, "right daughter"]) + split_var <- unname(tree[, "split var"]) + split_point <- unname(tree[, "split point"]) + + build_node <- function(node_id) { + # Check if terminal (leaf) node - status == -1 + if (status[node_id] == -1) { + return(prediction[node_id]) + } - # Check if terminal (leaf) node - status == -1 - if (row["status"] == -1) { - return(unname(row["prediction"])) - } + # Internal node - get split info + left_id <- left_daughter[node_id] + right_id <- right_daughter[node_id] + var_idx <- split_var[node_id] + split_val <- split_point[node_id] - # Internal node - get split info - left_id <- unname(row["left daughter"]) - right_id <- unname(row["right daughter"]) - split_var <- unname(row["split var"]) - split_val <- unname(row["split point"]) + # Recurse + left_subtree <- build_node(left_id) + right_subtree <- build_node(right_id) - # Recurse - left_subtree <- build_nested_rf_node(left_id, tree, term_labels) - right_subtree <- build_nested_rf_node(right_id, tree, term_labels) + col_name <- term_labels[var_idx] + col_sym <- rlang::sym(col_name) - col_name <- term_labels[split_var] - col_sym <- rlang::sym(col_name) + # Numeric split: left = <= splitval, right = > splitval + condition <- expr(!!col_sym <= !!split_val) - # Numeric split: left = <= splitval, right = > splitval - condition <- expr(!!col_sym <= !!split_val) + expr(case_when(!!condition ~ !!left_subtree, .default = !!right_subtree)) + } - expr(case_when(!!condition ~ !!left_subtree, .default = !!right_subtree)) + build_node(1L) } # Legacy flat case_when (for v1/v2 parsed model compatibility) ---------------- @@ -285,53 +292,43 @@ build_nested_rf_vote_tree <- function( class_level ) { tree <- randomForest::getTree(model, tree_no) - build_nested_rf_vote_node(1L, tree, term_labels, model$classes, class_level) -} + classes <- model$classes + + # Pre-extract columns as vectors for fast indexing (avoids slow row access) + # Use unname() once here instead of on every recursive call + status <- unname(tree[, "status"]) + prediction <- unname(tree[, "prediction"]) + left_daughter <- unname(tree[, "left daughter"]) + right_daughter <- unname(tree[, "right daughter"]) + split_var <- unname(tree[, "split var"]) + split_point <- unname(tree[, "split point"]) + + build_node <- function(node_id) { + # Check if terminal (leaf) node - status == -1 + if (status[node_id] == -1) { + # Return 1 if prediction matches class_level, 0 otherwise + pred_class <- classes[prediction[node_id]] + return(if (pred_class == class_level) 1L else 0L) + } -# Recursively build nested case_when for randomForest voting node -build_nested_rf_vote_node <- function( - node_id, - tree, - term_labels, - classes, - class_level -) { - row <- tree[node_id, ] + # Internal node - get split info + left_id <- left_daughter[node_id] + right_id <- right_daughter[node_id] + var_idx <- split_var[node_id] + split_val <- split_point[node_id] - # Check if terminal (leaf) node - status == -1 - if (row["status"] == -1) { - # Return 1 if prediction matches class_level, 0 otherwise - pred_class <- classes[unname(row["prediction"])] - return(if (pred_class == class_level) 1 else 0) - } + # Recurse + left_subtree <- build_node(left_id) + right_subtree <- build_node(right_id) - # Internal node - get split info - left_id <- unname(row["left daughter"]) - right_id <- unname(row["right daughter"]) - split_var <- unname(row["split var"]) - split_val <- unname(row["split point"]) - - # Recurse - left_subtree <- build_nested_rf_vote_node( - left_id, - tree, - term_labels, - classes, - class_level - ) - right_subtree <- build_nested_rf_vote_node( - right_id, - tree, - term_labels, - classes, - class_level - ) + col_name <- term_labels[var_idx] + col_sym <- rlang::sym(col_name) - col_name <- term_labels[split_var] - col_sym <- rlang::sym(col_name) + # Numeric split: left = <= splitval, right = > splitval + condition <- expr(!!col_sym <= !!split_val) - # Numeric split: left = <= splitval, right = > splitval - condition <- expr(!!col_sym <= !!split_val) + expr(case_when(!!condition ~ !!left_subtree, .default = !!right_subtree)) + } - expr(case_when(!!condition ~ !!left_subtree, .default = !!right_subtree)) + build_node(1L) } diff --git a/R/model-xgboost.R b/R/model-xgboost.R index 7ec522c..0f012fa 100644 --- a/R/model-xgboost.R +++ b/R/model-xgboost.R @@ -172,16 +172,49 @@ get_xgb_json_params <- function(model) { tmp_file <- tempfile(fileext = ".json") xgboost::xgb.save(model, tmp_file) - json <- jsonlite::fromJSON(tmp_file) + # Use regex extraction instead of full JSON parsing (3-4x faster) + txt <- paste(readLines(tmp_file, warn = FALSE), collapse = "") - base_score <- json$learner$learner_model_param$base_score - base_score <- gsub("\\[", "", base_score) - base_score <- gsub("\\]", "", base_score) - base_score <- strsplit(base_score, ",")[[1]] - base_score <- as.numeric(base_score) + # Extract base_score - format is "base_score":"[5E-1]" + base_score_match <- regmatches( + txt, + regexpr('base_score":"\\[[^]]+\\]', txt, perl = TRUE) + ) + + if (length(base_score_match) > 0 && nchar(base_score_match) > 0) { + base_score_str <- gsub( + 'base_score":"\\[([^]]+)\\]', + "\\1", + base_score_match, + perl = TRUE + ) + base_score <- as.numeric(strsplit(base_score_str, ",")[[1]]) + } else { + base_score <- 0.5 + } - booster_name <- json$learner$gradient_booster$name - weight_drop <- json$learner$gradient_booster$weight_drop + # Extract booster name using fixed string matching + + booster_name <- "gbtree" + if (grepl('"name":"dart"', txt, fixed = TRUE)) { + booster_name <- "dart" + } else if (grepl('"name":"gblinear"', txt, fixed = TRUE)) { + booster_name <- "gblinear" + } + + # Extract weight_drop for DART + + weight_drop <- NULL + if (booster_name == "dart") { + wd_match <- regmatches( + txt, + regexpr('weight_drop":\\[[^]]+\\]', txt, perl = TRUE) + ) + if (length(wd_match) > 0 && nchar(wd_match) > 0) { + wd_str <- gsub('weight_drop":\\[([^]]+)\\]', "\\1", wd_match, perl = TRUE) + weight_drop <- as.numeric(strsplit(wd_str, ",")[[1]]) + } + } list( base_score = base_score, @@ -357,7 +390,10 @@ get_xgb_trees_df <- function(model) { # Convert Yes/No/Missing to integer indices trees[, c("Yes", "No", "Missing")] <- lapply( trees[, c("Yes", "No", "Missing")], - function(x) as.integer(sub("^.*-", "", x)) + 1L + function(x) { + dash_pos <- regexpr("-", x, fixed = TRUE) + as.integer(substring(x, dash_pos + 1L)) + 1L + } ) trees @@ -365,42 +401,45 @@ get_xgb_trees_df <- function(model) { # Build nested case_when for a single xgboost tree build_nested_xgb_tree <- function(tree_df) { - # tree_df has Node (0-indexed), Feature, Split, Yes, No, Missing, Quality/Gain - build_nested_xgb_node(1L, tree_df) -} - -build_nested_xgb_node <- function(node_idx, tree_df) { - row <- tree_df[node_idx, ] + # Pre-extract columns as vectors for fast indexing (avoids slow df[i,] access) + Feature <- tree_df$Feature + Gain <- tree_df$Gain %||% tree_df$Quality + Split <- tree_df$Split + Yes <- tree_df$Yes + No <- tree_df$No + Missing <- tree_df$Missing + feature_name <- tree_df$feature_name + + build_node <- function(node_idx) { + # Leaf node + if (Feature[node_idx] == "Leaf") { + return(Gain[node_idx]) + } - # Leaf node - if (row$Feature == "Leaf") { - return(row$Gain %||% row$Quality) - } + # Internal node + col <- rlang::sym(feature_name[node_idx]) + threshold <- Split[node_idx] + left_idx <- Yes[node_idx] + right_idx <- No[node_idx] + missing_idx <- Missing[node_idx] + + left_subtree <- build_node(left_idx) + right_subtree <- build_node(right_idx) + + # xgboost: Yes = left (< threshold), No = right (>= threshold) + # Missing can go either way + if (missing_idx == left_idx) { + # Missing goes left: (< threshold OR is.na) + condition <- expr(!!col < !!threshold | is.na(!!col)) + } else { + # Missing goes right or no missing: < threshold (no NA) + condition <- expr(!!col < !!threshold) + } - # Internal node - col <- rlang::sym(row$feature_name) - threshold <- row$Split - left_idx <- row$Yes - right_idx <- row$No - missing_idx <- row$Missing - - left_subtree <- build_nested_xgb_node(left_idx, tree_df) - right_subtree <- build_nested_xgb_node(right_idx, tree_df) - - # xgboost: Yes = left (< threshold), No = right (>= threshold) - # Missing can go either way - if (missing_idx == left_idx) { - # Missing goes left: (< threshold OR is.na) - condition <- expr(!!col < !!threshold | is.na(!!col)) - } else if (missing_idx == right_idx) { - # Missing goes right: < threshold (no NA) - condition <- expr(!!col < !!threshold) - } else { - # No missing handling - condition <- expr(!!col < !!threshold) + expr(case_when(!!condition ~ !!left_subtree, .default = !!right_subtree)) } - expr(case_when(!!condition ~ !!left_subtree, .default = !!right_subtree)) + build_node(1L) } # Legacy flat case_when (for v1/v2 parsed model compatibility) ----------------