diff --git a/DESCRIPTION b/DESCRIPTION index 0d91231..3f1c093 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -10,7 +10,7 @@ Description: It parses a fitted 'R' model object, and returns a formula in 'Tidy Eval' code that calculates the predictions. It works with several databases back-ends because it leverages 'dplyr' and 'dbplyr' for the final 'SQL' translation of the algorithm. It currently - supports lm(), glm(), randomForest(), ranger(), earth(), + supports lm(), glm(), randomForest(), ranger(), rpart(), earth(), xgb.Booster.complete(), lgb.Booster(), catboost.Model(), cubist(), and ctree() models. License: MIT + file LICENSE @@ -46,6 +46,7 @@ Suggests: partykit, randomForest, ranger, + rpart, rmarkdown, RSQLite, survival, diff --git a/NAMESPACE b/NAMESPACE index 4ce712e..e78eda9 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -15,6 +15,7 @@ S3method(parse_model,model_fit) S3method(parse_model,party) S3method(parse_model,randomForest) S3method(parse_model,ranger) +S3method(parse_model,rpart) S3method(parse_model,xgb.Booster) S3method(print,tidypredict_test) S3method(tidy,pm_regression) @@ -35,6 +36,7 @@ S3method(tidypredict_fit,pm_tree) S3method(tidypredict_fit,pm_xgb) S3method(tidypredict_fit,randomForest) S3method(tidypredict_fit,ranger) +S3method(tidypredict_fit,rpart) S3method(tidypredict_fit,xgb.Booster) S3method(tidypredict_interval,data.frame) S3method(tidypredict_interval,glm) @@ -47,6 +49,7 @@ S3method(tidypredict_test,glmnet) S3method(tidypredict_test,lgb.Booster) S3method(tidypredict_test,model_fit) S3method(tidypredict_test,party) +S3method(tidypredict_test,rpart) S3method(tidypredict_test,xgb.Booster) export(.build_case_when_tree) export(.build_linear_pred) @@ -57,6 +60,7 @@ export(.extract_lgb_trees) export(.extract_partykit_classprob) export(.extract_ranger_classprob) export(.extract_rf_classprob) +export(.extract_rpart_classprob) export(.extract_xgb_trees) export(acceptable_formula) export(as_parsed_model) diff --git a/NEWS.md b/NEWS.md index 
9f8b768..b5dfc53 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,8 @@ ## New Model Supports +- Added support for rpart decision tree models (`rpart`). (#226) + - Added support for CatBoost models (`catboost.Model`). (#TBD, #187, #188) - Objectives: RMSE, MAE, Quantile, MAPE, Poisson, Huber, LogCosh, Expectile, Tweedie, Logloss, CrossEntropy, MultiClass, and MultiClassOneVsAll. - Tree types: oblivious (default `SymmetricTree`) and non-oblivious (`Depthwise` or `Lossguide` grow policy). diff --git a/R/model-rpart.R b/R/model-rpart.R new file mode 100644 index 0000000..cb846fd --- /dev/null +++ b/R/model-rpart.R @@ -0,0 +1,536 @@ +# Extract comprehensive tree info including surrogate splits +rpart_tree_info_full <- function(model) { + frame <- model$frame + splits <- model$splits + orig_node_ids <- as.integer(rownames(frame)) + is_terminal <- frame$var == "" + + # Check if surrogates are used during prediction + # usesurrogate: 0 = don't use (NAs stop at internal nodes) + # 1 = use surrogates, if all NA stop at internal node + # 2 = use surrogates, if all NA go to majority (default) + usesurr <- model$control$usesurrogate + use_surrogates <- !identical(usesurr, 0L) && !identical(usesurr, 0) + + # Create mapping from original rpart node IDs to sequential 0-indexed IDs + id_map <- setNames(seq_along(orig_node_ids) - 1L, orig_node_ids) + + # Build child relationships + left_candidates <- 2L * orig_node_ids + right_candidates <- 2L * orig_node_ids + 1L + + rpart_left <- ifelse( + left_candidates %in% orig_node_ids, + id_map[as.character(left_candidates)], + NA_integer_ + ) + rpart_right <- ifelse( + right_candidates %in% orig_node_ids, + id_map[as.character(right_candidates)], + NA_integer_ + ) + + # Get predictions for ALL nodes (needed for usesurrogate=0) + if (model$method == "class") { + ylevels <- attr(model, "ylevels") + prediction <- ylevels[frame$yval] + } else { + prediction <- frame$yval + } + + # Extract split info per node, including surrogates + n_nodes <- 
nrow(frame) + node_splits <- vector("list", n_nodes) + needs_swap <- rep(FALSE, n_nodes) + majority_left <- rep(NA, n_nodes) + + if (!is.null(splits) && nrow(splits) > 0) { + split_idx <- 1 + for (i in seq_len(n_nodes)) { + if (!is_terminal[i]) { + n_compete <- frame$ncompete[i] + n_surr <- frame$nsurrogate[i] + + # Primary split + primary <- extract_one_split( + splits, + split_idx, + model, + as.character(frame$var[i]) + ) + needs_swap[i] <- primary$needs_swap + + # Surrogate splits (skip competing splits) + # Only extract if usesurrogate != 0 + surrogates <- list() + if (use_surrogates) { + surr_start <- split_idx + 1 + n_compete + for (j in seq_len(n_surr)) { + surr_idx <- surr_start + j - 1 + surr_var <- rownames(splits)[surr_idx] + surr_info <- extract_one_split(splits, surr_idx, model, surr_var) + # Adjust surrogate direction based on primary direction + if (primary$needs_swap) { + surr_info$needs_swap <- !surr_info$needs_swap + } + surrogates[[j]] <- surr_info + } + } + + # Majority direction: left child has more observations + left_id <- 2L * orig_node_ids[i] + right_id <- 2L * orig_node_ids[i] + 1L + left_n <- frame$n[orig_node_ids == left_id] + right_n <- frame$n[orig_node_ids == right_id] + majority_left[i] <- left_n >= right_n + + node_splits[[i]] <- list( + primary = primary, + surrogates = surrogates + ) + + n_splits <- 1 + n_compete + n_surr + split_idx <- split_idx + n_splits + } + } + } + + # Swap left/right based on ncat direction + # This aligns with tidypredict convention: "left" means "< split value" + left_child <- ifelse(needs_swap, rpart_right, rpart_left) + right_child <- ifelse(needs_swap, rpart_left, rpart_right) + + # Also swap majority_left to stay consistent with the new left/right + majority_left_adjusted <- ifelse(needs_swap, !majority_left, majority_left) + + list( + nodeID = seq_along(orig_node_ids) - 1L, + leftChild = left_child, + rightChild = right_child, + splitvarName = as.character(frame$var), + terminal = is_terminal, + 
prediction = prediction, + node_splits = node_splits, + majority_left = majority_left_adjusted, + use_surrogates = use_surrogates + ) +} + +# Extract info for a single split from the splits matrix +extract_one_split <- function(splits, idx, model, var_name) { + ncat <- splits[idx, "ncat"] + index <- splits[idx, "index"] + + if (abs(ncat) == 1) { + # Continuous split + list( + col = var_name, + val = index, + is_categorical = FALSE, + needs_swap = ncat == 1 + ) + } else { + # Categorical split + csplit_row <- model$csplit[index, , drop = TRUE] + xlevels <- attr(model, "xlevels")[[var_name]] + if (ncat > 0) { + left_levels <- xlevels[csplit_row == 1] + } else { + left_levels <- xlevels[csplit_row == 3] + } + list( + col = var_name, + vals = as.list(left_levels), + is_categorical = TRUE, + needs_swap = FALSE + ) + } +} + +# Simplified tree info for tests (returns data frame like other models) +rpart_tree_info <- function(model) { + info <- rpart_tree_info_full(model) + + split_val <- rep(NA_real_, length(info$nodeID)) + split_class <- rep(NA_character_, length(info$nodeID)) + + for (i in seq_along(info$node_splits)) { + if (!is.null(info$node_splits[[i]])) { + primary <- info$node_splits[[i]]$primary + if (primary$is_categorical) { + split_class[i] <- paste0(primary$vals, collapse = ", ") + } else { + split_val[i] <- primary$val + } + } + } + + data.frame( + nodeID = info$nodeID, + leftChild = info$leftChild, + rightChild = info$rightChild, + splitvarName = info$splitvarName, + splitval = split_val, + splitclass = split_class, + terminal = info$terminal, + prediction = info$prediction + ) +} + +get_rpart_tree <- function(model) { + tree_info <- rpart_tree_info_full(model) + terminal_ids <- tree_info$nodeID[tree_info$terminal] + internal_ids <- tree_info$nodeID[!tree_info$terminal] + + # Build parent mapping (use -1 for no parent since 0 is a valid node ID) + parent_map <- build_parent_map(tree_info) + + # Generate paths for terminal nodes (leaves) + leaf_paths <- 
map(terminal_ids, function(leaf_id) { + prediction <- tree_info$prediction[tree_info$nodeID == leaf_id] + if (is.factor(prediction)) { + prediction <- as.character(prediction) + } + list( + prediction = prediction, + path = build_rpart_path( + leaf_id, + tree_info, + parent_map, + use_surrogates = TRUE + ) + ) + }) + + # For usesurrogate=0, also generate paths for internal nodes where NAs stop + if (!tree_info$use_surrogates && length(internal_ids) > 0) { + na_stop_paths <- map(internal_ids, function(node_id) { + prediction <- tree_info$prediction[tree_info$nodeID == node_id] + if (is.factor(prediction)) { + prediction <- as.character(prediction) + } + # Get path to this node with explicit !is.na checks, plus final is.na check + path <- build_rpart_path( + node_id, + tree_info, + parent_map, + use_surrogates = FALSE + ) + # Add the "is.na(split_var)" condition for this node + node_idx <- which(tree_info$nodeID == node_id) + split_info <- tree_info$node_splits[[node_idx]] + if (!is.null(split_info)) { + na_cond <- list(type = "na_check", col = split_info$primary$col) + path <- c(path, list(na_cond)) + } + list(prediction = prediction, path = path) + }) + # NA stop paths come BEFORE leaf paths (more specific conditions first) + c(na_stop_paths, leaf_paths) + } else { + leaf_paths + } +} + +# Build parent mapping for path tracing +build_parent_map <- function(tree_info) { + parent_map <- rep(-1L, max(tree_info$nodeID) + 1) + for (i in seq_along(tree_info$nodeID)) { + lc <- tree_info$leftChild[i] + rc <- tree_info$rightChild[i] + if (!is.na(lc)) { + parent_map[lc + 1] <- tree_info$nodeID[i] + } + if (!is.na(rc)) parent_map[rc + 1] <- tree_info$nodeID[i] + } + parent_map +} + +# Trace path from node to root and return list of node IDs +trace_path_to_root <- function(node_id, parent_map) { + path_nodes <- node_id + current <- node_id + while (current >= 0 && (current + 1) <= length(parent_map)) { + parent <- parent_map[current + 1] + if (parent < 0) { + break + } + 
path_nodes <- c(path_nodes, parent) + current <- parent + } + path_nodes +} + +# Build path conditions from node to root +build_rpart_path <- function(node_id, tree_info, parent_map, use_surrogates) { + path_nodes <- trace_path_to_root(node_id, parent_map) + + if (length(path_nodes) <= 1) { + return(list()) + } + + # Build conditions for each step (child -> parent) + conditions <- list() + for (i in seq_len(length(path_nodes) - 1)) { + child_node <- path_nodes[i] + parent_node <- path_nodes[i + 1] + + parent_idx <- which(tree_info$nodeID == parent_node) + is_left_child <- tree_info$leftChild[parent_idx] == child_node + + split_info <- tree_info$node_splits[[parent_idx]] + if (is.null(split_info)) { + next + } + + primary <- split_info$primary + surrogates <- split_info$surrogates + majority_left <- tree_info$majority_left[parent_idx] + + # Build condition based on surrogate usage + if (use_surrogates) { + cond <- build_condition_with_surrogates( + primary, + surrogates, + is_left_child, + majority_left + ) + } else { + cond <- build_condition_simple(primary, is_left_child) + } + conditions <- c(conditions, list(cond)) + } + + rev(conditions) +} + +# Build condition with surrogate fallbacks (for usesurrogate=2) +build_condition_with_surrogates <- function( + primary, + surrogates, + go_left, + majority_left +) { + primary_cond <- build_split_condition(primary, go_left) + + surr_conds <- map(surrogates, function(s) { + # If surrogate needs_swap, it goes opposite of primary + surr_go_left <- if (s$needs_swap) !go_left else go_left + build_split_condition(s, surr_go_left) + }) + + # Does this direction match where majority goes? 
+ majority_match <- (go_left && majority_left) || (!go_left && !majority_left) + + type <- if (primary$is_categorical) { + "set_with_surrogates" + } else { + "conditional_with_surrogates" + } + + list( + type = type, + primary = primary_cond, + surrogates = surr_conds, + majority_match = majority_match + ) +} + +# Build simple condition with !is.na check (for usesurrogate=0) +build_condition_simple <- function(primary, go_left) { + if (primary$is_categorical) { + list( + type = "set_not_na", + col = primary$col, + vals = primary$vals, + op = if (go_left) "in" else "not-in" + ) + } else { + list( + type = "conditional_not_na", + col = primary$col, + val = primary$val, + op = if (go_left) "less-equal" else "more" + ) + } +} + +# Build a split condition structure (used by both surrogate and simple paths) +build_split_condition <- function(split, go_left) { + if (split$is_categorical) { + list( + col = split$col, + vals = split$vals, + op = if (go_left) "in" else "not-in" + ) + } else { + list( + col = split$col, + val = split$val, + op = if (go_left) "less-equal" else "more" + ) + } +} + +#' @export +parse_model.rpart <- function(model) { + pm <- list() + pm$general$model <- "rpart" + pm$general$type <- "tree" + pm$general$version <- 2 + pm$trees <- list(get_rpart_tree(model)) + as_parsed_model(pm) +} + +#' @export +tidypredict_fit.rpart <- function(model) { + parsedmodel <- parse_model(model) + tree <- parsedmodel$trees[[1]] + mode <- parsedmodel$general$mode + generate_case_when_tree(tree, mode) +} + +#' @export +tidypredict_test.rpart <- function( + model, + df = model$model, + threshold = 0.000000000001, + include_intervals = FALSE, + + max_rows = NULL, + xg_df = NULL +) { + if (is.numeric(max_rows)) { + df <- head(df, max_rows) + } + + # rpart uses type = "vector" for regression, type = "class" for classification + pred_type <- if (model$method == "class") "class" else "vector" + base <- predict(model, df, type = pred_type) + + # For classification, threshold 
should be 0 (exact match) + if (model$method == "class") { + threshold <- 0 + base <- as.character(base) + } + + te <- tidypredict_to_column( + df, + model, + add_interval = FALSE, + vars = c("fit_te", "upr_te", "lwr_te") + ) + + raw_results <- data.frame(fit = base, fit_te = te$fit_te) + raw_results$fit_diff <- if (model$method == "class") { + as.numeric(raw_results$fit != raw_results$fit_te) + } else { + raw_results$fit - raw_results$fit_te + } + raw_results$fit_threshold <- abs(raw_results$fit_diff) > threshold + + rowid <- seq_len(nrow(raw_results)) + raw_results <- cbind(data.frame(rowid), raw_results) + + threshold_df <- data.frame(fit_threshold = sum(raw_results$fit_threshold)) + alert <- any(threshold_df > 0) + + message <- paste0( + "tidypredict test results\n", + "Difference threshold: ", + threshold, + "\n" + ) + + if (alert) { + difference <- max(abs(raw_results$fit_diff)) + message <- paste0( + message, + "\nFitted records above the threshold: ", + threshold_df$fit_threshold, + "\n\nMax difference: ", + difference + ) + } else { + message <- paste0( + message, + "\n All results are within the difference threshold" + ) + } + + results <- list() + results$model_call <- model$call + results$raw_results <- raw_results + results$message <- message + results$alert <- alert + structure(results, class = c("tidypredict_test", "list")) +} + +# For {orbital} +#' Extract classprob trees for rpart models +#' +#' For use in orbital package. +#' @param model An rpart model object +#' @keywords internal +#' @export +.extract_rpart_classprob <- function(model) { + if (!inherits(model, "rpart")) { + cli::cli_abort( + "{.arg model} must be {.cls rpart}, not {.obj_type_friendly {model}}." + ) + } + + if (model$method != "class") { + cli::cli_abort( + "{.arg model} must be a classification model (method = 'class')." 
+ ) + } + + # Extract class probabilities from yval2 + # yval2 structure: [yval, count_class1, ..., count_classN, prob_class1, ..., prob_classN, nodeprob] + yval2 <- model$frame$yval2 + ylevels <- attr(model, "ylevels") + n_classes <- length(ylevels) + + # Probability columns are at positions (n_classes + 2) to (2 * n_classes + 1) + prob_cols <- seq(n_classes + 2, 2 * n_classes + 1) + probs <- yval2[, prob_cols, drop = FALSE] + colnames(probs) <- ylevels + + # Get tree structure with surrogate handling + tree_info <- rpart_tree_info_full(model) + parent_map <- build_parent_map(tree_info) + terminal_ids <- tree_info$nodeID[tree_info$terminal] + + generate_one_tree <- function(predictions) { + paths <- map(terminal_ids, function(node_id) { + node_idx <- which(tree_info$nodeID == node_id) + list( + prediction = predictions[node_idx], + path = build_rpart_path( + node_id, + tree_info, + parent_map, + use_surrogates = TRUE + ) + ) + }) + + pm <- list() + pm$general$model <- "rpart" + pm$general$type <- "tree" + pm$general$version <- 2 + pm$trees <- list(paths) + parsedmodel <- as_parsed_model(pm) + + tree <- parsedmodel$trees[[1]] + mode <- parsedmodel$general$mode + generate_case_when_tree(tree, mode) + } + + res <- list() + for (i in seq_len(ncol(probs))) { + res[[i]] <- generate_one_tree(probs[, i]) + } + res +} diff --git a/R/tree.R b/R/tree.R index 77b7f24..3611473 100644 --- a/R/tree.R +++ b/R/tree.R @@ -210,33 +210,26 @@ path_formulas <- function(path) { #' Can be character or numeric. 
#' @keywords internal path_formula <- function(x) { - if (x$type == "conditional") { - if (x$op == "more") { - i <- expr(!!as.name(x$col) > !!x$val) - } else if (x$op == "more-equal") { - i <- expr(!!as.name(x$col) >= !!x$val) - } else if (x$op == "less") { - i <- expr(!!as.name(x$col) < !!x$val) - } else if (x$op == "less-equal") { - i <- expr(!!as.name(x$col) <= !!x$val) - } else { - cli::cli_abort( - "{.field op} has unsupported value of {.value {x$op}}.", - .internal = TRUE - ) - } - } else if (x$type == "set") { - sets <- reduce(x$vals, c) - if (x$op == "in") { - i <- expr(!!as.name(x$col) %in% !!sets) - } else if (x$op == "not-in") { - i <- expr((!!as.name(x$col) %in% !!sets) == FALSE) - } else { - cli::cli_abort( - "{.field op} has unsupported value of {.value {x$op}}.", - .internal = TRUE - ) - } + type <- x$type + + if (type == "conditional") { + i <- build_comparison_expr(x$col, x$val, x$op) + } else if (type == "set") { + i <- build_set_expr(x$col, x$vals, x$op) + } else if ( + type == "conditional_with_surrogates" || type == "set_with_surrogates" + ) { + i <- build_surrogate_condition(x) + } else if (type == "na_check") { + i <- expr(is.na(!!as.name(x$col))) + } else if (type == "conditional_not_na") { + col <- as.name(x$col) + cond <- build_comparison_expr(x$col, x$val, x$op) + i <- expr(!is.na(!!col) & !!cond) + } else if (type == "set_not_na") { + col <- as.name(x$col) + cond <- build_set_expr(x$col, x$vals, x$op) + i <- expr(!is.na(!!col) & !!cond) } else { cli::cli_abort( "{.field type} has unsupported value of {.value {x$type}}.", @@ -246,6 +239,102 @@ path_formula <- function(x) { i } +# Build a comparison expression (col val) +build_comparison_expr <- function(col, val, op) { + col <- as.name(col) + if (op == "more") { + expr(!!col > !!val) + } else if (op == "more-equal") { + expr(!!col >= !!val) + } else if (op == "less") { + expr(!!col < !!val) + } else if (op == "less-equal") { + expr(!!col <= !!val) + } else { + cli::cli_abort( + "{.field 
op} has unsupported value of {.value {op}}.", + .internal = TRUE + ) + } +} + +# Build a set membership expression (col %in% vals or not) +build_set_expr <- function(col, vals, op) { + col <- as.name(col) + sets <- reduce(vals, c) + if (op == "in") { + expr(!!col %in% !!sets) + } else if (op == "not-in") { + expr((!!col %in% !!sets) == FALSE) + } else { + cli::cli_abort( + "{.field op} has unsupported value of {.value {op}}.", + .internal = TRUE + ) + } +} + +# Build condition with surrogate fallbacks for rpart +# Structure of x: +# - primary: list(col, val, op) or list(col, vals, op) for sets +# - surrogates: list of lists, each with (col, val, op) or (col, vals, op) +# - majority_match: logical, TRUE if all-NA case should match this direction +build_surrogate_condition <- function(x) { + primary <- x$primary + surrogates <- x$surrogates + majority_match <- x$majority_match + + # Build the primary condition with NOT NULL check + primary_col <- as.name(primary$col) + primary_cond <- build_single_condition(primary) + primary_expr <- expr(!is.na(!!primary_col) & !!primary_cond) + + # Start collecting all the OR terms + all_terms <- list(primary_expr) + + # Track NA checks for each level + na_checks <- list(expr(is.na(!!primary_col))) + + # Add surrogate conditions + for (surr in surrogates) { + surr_col <- as.name(surr$col) + surr_cond <- build_single_condition(surr) + + # This surrogate is used when all previous vars are NA + prev_na <- reduce_and(na_checks) + surr_expr <- expr(!!prev_na & !is.na(!!surr_col) & !!surr_cond) + + all_terms <- c(all_terms, list(surr_expr)) + na_checks <- c(na_checks, list(expr(is.na(!!surr_col)))) + } + + # Add majority case if this direction matches majority + if (isTRUE(majority_match)) { + all_na <- reduce_and(na_checks) + all_terms <- c(all_terms, list(all_na)) + } + + # Combine with OR + reduce_or(all_terms) +} + +# Build a single condition expression (without NA check) +build_single_condition <- function(cond) { + if 
(!is.null(cond$vals)) { + build_set_expr(cond$col, cond$vals, cond$op) + } else { + build_comparison_expr(cond$col, cond$val, cond$op) + } +} + +# Reduce expressions with OR +reduce_or <- function(exprs) { + if (length(exprs) == 1) { + return(exprs[[1]]) + } + reduce(exprs, function(a, b) expr(!!a | !!b)) +} + # For {orbital} #' Build case_when expression from nodes with predictions and paths #' diff --git a/_pkgdown.yml b/_pkgdown.yml index 3b9465a..9f542ad 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -28,6 +28,8 @@ navbar: href: articles/ranger.html - text: Random Forest - randomForest() href: articles/rf.html + - text: Decision Tree - rpart() + href: articles/rpart.html - text: MARS - earth() href: articles/mars.html - text: Cubist - cubist() diff --git a/man/dot-extract_rpart_classprob.Rd b/man/dot-extract_rpart_classprob.Rd new file mode 100644 index 0000000..408c740 --- /dev/null +++ b/man/dot-extract_rpart_classprob.Rd @@ -0,0 +1,15 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/model-rpart.R +\name{.extract_rpart_classprob} +\alias{.extract_rpart_classprob} +\title{Extract classprob trees for rpart models} +\usage{ +.extract_rpart_classprob(model) +} +\arguments{ +\item{model}{An rpart model object} +} +\description{ +For use in orbital package. +} +\keyword{internal} diff --git a/man/tidypredict-package.Rd b/man/tidypredict-package.Rd index d6feea7..49e1bac 100644 --- a/man/tidypredict-package.Rd +++ b/man/tidypredict-package.Rd @@ -8,7 +8,7 @@ \description{ \if{html}{\figure{logo.png}{options: style='float: right' alt='logo' width='120'}} -It parses a fitted 'R' model object, and returns a formula in 'Tidy Eval' code that calculates the predictions. It works with several databases back-ends because it leverages 'dplyr' and 'dbplyr' for the final 'SQL' translation of the algorithm. 
It currently supports lm(), glm(), randomForest(), ranger(), earth(), xgb.Booster.complete(), lgb.Booster(), catboost.Model(), cubist(), and ctree() models. +It parses a fitted 'R' model object, and returns a formula in 'Tidy Eval' code that calculates the predictions. It works with several databases back-ends because it leverages 'dplyr' and 'dbplyr' for the final 'SQL' translation of the algorithm. It currently supports lm(), glm(), randomForest(), ranger(), rpart(), earth(), xgb.Booster.complete(), lgb.Booster(), catboost.Model(), cubist(), and ctree() models. } \seealso{ Useful links: diff --git a/tests/testthat/_snaps/model-rpart.md b/tests/testthat/_snaps/model-rpart.md new file mode 100644 index 0000000..4ebe89b --- /dev/null +++ b/tests/testthat/_snaps/model-rpart.md @@ -0,0 +1,53 @@ +# returns the right output + + Code + rlang::expr_text(tf) + Output + [1] "case_when((!is.na(cyl) & cyl > 5 | is.na(cyl) & !is.na(am) & \n am <= 0.5 | is.na(cyl) & is.na(am)) & (!is.na(cyl) & cyl > \n 7 | is.na(cyl) & !is.na(am) & am <= 0.5 | is.na(cyl) & is.na(am)) ~ \n 15.1, (!is.na(cyl) & cyl > 5 | is.na(cyl) & !is.na(am) & \n am <= 0.5 | is.na(cyl) & is.na(am)) & (!is.na(cyl) & cyl <= \n 7 | is.na(cyl) & !is.na(am) & am > 0.5) ~ 19.7428571428571, \n .default = 26.6636363636364)" + +# formulas produce correct predictions - regression + + Code + tidypredict_test(rpart::rpart(mpg ~ am + cyl + wt, data = mtcars), mtcars) + Output + tidypredict test results + Difference threshold: 1e-12 + + All results are within the difference threshold + +# formulas produce correct predictions - classification + + Code + tidypredict_test(rpart::rpart(Species ~ ., data = iris), iris) + Output + tidypredict test results + Difference threshold: 0 + + All results are within the difference threshold + +# categorical predictors work correctly + + Code + tidypredict_test(rpart::rpart(mpg ~ cyl + wt, data = mtcars2), mtcars2) + Output + tidypredict test results + Difference threshold: 1e-12 + + All 
results are within the difference threshold + +# .extract_rpart_classprob errors on non-rpart model + + Code + .extract_rpart_classprob(list()) + Condition + Error in `.extract_rpart_classprob()`: + ! `model` must be <rpart>, not an empty list. + +# .extract_rpart_classprob errors on regression model + + Code + .extract_rpart_classprob(model) + Condition + Error in `.extract_rpart_classprob()`: + ! `model` must be a classification model (method = 'class'). + diff --git a/tests/testthat/_snaps/tree.md b/tests/testthat/_snaps/tree.md index 0e1d29a..ef55ebf 100644 --- a/tests/testthat/_snaps/tree.md +++ b/tests/testthat/_snaps/tree.md @@ -13,7 +13,7 @@ Code path_formula(list(type = "conditional", op = "unknown", col = "x", val = 0)) Condition - Error in `path_formula()`: + Error in `build_comparison_expr()`: ! op has unsupported value of unknown. i This is an internal error that was detected in the tidypredict package. Please report it at with a reprex () and the full backtrace. @@ -23,7 +23,7 @@ Code path_formula(list(type = "set", op = "unknown", col = "x", vals = 0)) Condition - Error in `path_formula()`: + Error in `build_set_expr()`: ! op has unsupported value of unknown. i This is an internal error that was detected in the tidypredict package. Please report it at with a reprex () and the full backtrace.
diff --git a/tests/testthat/test-model-rpart.R b/tests/testthat/test-model-rpart.R new file mode 100644 index 0000000..0b88efe --- /dev/null +++ b/tests/testthat/test-model-rpart.R @@ -0,0 +1,125 @@ +test_that("rpart_tree_info returns correct structure", { + model <- rpart::rpart(mpg ~ cyl + wt, data = mtcars) + tree_info <- rpart_tree_info(model) + + expect_s3_class(tree_info, "data.frame") + expect_named( + tree_info, + c( + "nodeID", + "leftChild", + "rightChild", + "splitvarName", + "splitval", + "splitclass", + "terminal", + "prediction" + ) + ) +}) + +test_that("returns the right output", { + model <- rpart::rpart(mpg ~ am + cyl, data = mtcars) + tf <- tidypredict_fit(model) + pm <- parse_model(model) + + expect_type(tf, "language") + expect_s3_class(pm, "list") + expect_equal(pm$general$model, "rpart") + expect_equal(pm$general$version, 2) + + expect_snapshot(rlang::expr_text(tf)) +}) + +test_that("Model can be saved and re-loaded", { + model <- rpart::rpart(mpg ~ am + cyl, data = mtcars) + pm <- parse_model(model) + mp <- tempfile(fileext = ".yml") + yaml::write_yaml(pm, mp) + l <- yaml::read_yaml(mp) + pm <- as_parsed_model(l) + + expect_identical( + round_print(tidypredict_fit(model)), + round_print(tidypredict_fit(pm)) + ) +}) + +test_that("formulas produce correct predictions - regression", { + expect_snapshot( + tidypredict_test( + rpart::rpart(mpg ~ am + cyl + wt, data = mtcars), + mtcars + ) + ) +}) + +test_that("formulas produce correct predictions - classification", { + expect_snapshot( + tidypredict_test( + rpart::rpart(Species ~ ., data = iris), + iris + ) + ) +}) + +test_that("categorical predictors work correctly", { + mtcars2 <- mtcars + mtcars2$cyl <- factor(mtcars2$cyl) + + expect_snapshot( + tidypredict_test( + rpart::rpart(mpg ~ cyl + wt, data = mtcars2), + mtcars2 + ) + ) +}) + +test_that("stump trees work correctly", { + ctrl <- rpart::rpart.control(minsplit = 100, cp = 1) + model <- rpart::rpart(mpg ~ cyl + disp, data = mtcars, control 
= ctrl) + + fit <- tidypredict_fit(model) + + expect_type(fit, "double") + expect_equal(fit, mean(mtcars$mpg)) +}) + +# .extract_rpart_classprob tests ------------------------------------------ + +test_that(".extract_rpart_classprob returns list of expressions", { + model <- rpart::rpart(Species ~ Sepal.Length + Sepal.Width, data = iris) + + exprs <- .extract_rpart_classprob(model) + + expect_type(exprs, "list") + expect_length(exprs, 3) + for (expr in exprs) { + expect_type(expr, "language") + } +}) + +test_that(".extract_rpart_classprob results match predict probabilities", { + model <- rpart::rpart(Species ~ Sepal.Length + Sepal.Width, data = iris) + + exprs <- .extract_rpart_classprob(model) + eval_env <- rlang::new_environment( + data = as.list(iris), + parent = asNamespace("dplyr") + ) + probs <- lapply(exprs, rlang::eval_tidy, env = eval_env) + combined <- do.call(cbind, probs) + + native <- predict(model, type = "prob") + + expect_equal(unname(combined), unname(native)) +}) + +test_that(".extract_rpart_classprob errors on non-rpart model", { + expect_snapshot(.extract_rpart_classprob(list()), error = TRUE) +}) + +test_that(".extract_rpart_classprob errors on regression model", { + model <- rpart::rpart(mpg ~ cyl + wt, data = mtcars) + expect_snapshot(.extract_rpart_classprob(model), error = TRUE) +}) diff --git a/vignettes/rpart.Rmd b/vignettes/rpart.Rmd new file mode 100644 index 0000000..943f25c --- /dev/null +++ b/vignettes/rpart.Rmd @@ -0,0 +1,100 @@ +--- +title: "Decision trees, using rpart" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{Decision trees, using rpart} + %\VignetteEngine{knitr::rmarkdown} + %\VignetteEncoding{UTF-8} +--- + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>" +) + +library(dplyr) +library(tidypredict) +library(rpart) +set.seed(100) +``` + +| Function |Works| +|---------------------------------------------------------------|-----| +|`tidypredict_fit()`, 
`tidypredict_sql()`, `parse_model()` | +|`tidypredict_to_column()` | +|`tidypredict_test()` | +|`tidypredict_interval()`, `tidypredict_sql_interval()` | +|`parsnip` | + +## How it works + +Here is a simple `rpart()` model using the `mtcars` dataset: + +```{r} +library(dplyr) +library(tidypredict) +library(rpart) + +model <- rpart(mpg ~ ., data = mtcars) +``` + +## Under the hood + +The parser extracts the tree structure from the model's `frame` and `splits` components. It handles both numeric and categorical splits, as well as surrogate splits for missing value handling. + +```{r} +model$frame |> + head() +``` + +The output from `parse_model()` is transformed into a `dplyr`, a.k.a. Tidy Eval, formula. The decision tree becomes a `dplyr::case_when()` statement. + +```{r} +tidypredict_fit(model) +``` + +From there, the Tidy Eval formula can be used anywhere it can be evaluated. `tidypredict` provides three paths: + +- Use directly inside `dplyr`, + `mutate(mtcars, !! tidypredict_fit(model))` +- Add `tidypredict_to_column(model)` to a piped command set +- Use `tidypredict_sql()` to retrieve the SQL statement + +## Classification + +`rpart` classification models are also supported: + +```{r} +model_class <- rpart(Species ~ ., data = iris) +tidypredict_fit(model_class) +``` + +## parsnip + +`tidypredict` also supports `rpart` model objects fitted via the `parsnip` package. + +```{r} +library(parsnip) + +parsnip_model <- decision_tree(mode = "regression") |> + set_engine("rpart") |> + fit(mpg ~ ., data = mtcars) + +tidypredict_fit(parsnip_model) +``` + +## Categorical predictors + +`rpart` handles categorical predictors natively. The generated formula uses `%in%` for categorical splits: +```{r} +mtcars2 <- mtcars +mtcars2$cyl <- factor(mtcars2$cyl) + +model_cat <- rpart(mpg ~ cyl + wt + hp, data = mtcars2) +tidypredict_fit(model_cat) +``` + +## Surrogate splits + +`rpart` uses surrogate splits to handle missing values during prediction.
When the primary split variable is missing, the model uses surrogate variables (other variables that produce similar splits) to route the observation. This behavior is controlled by the `usesurrogate` parameter in `rpart.control()`.