speed up the model code #229
Merged
Results
ranger
Issue 1: Data frame row subsetting in recursive function
The original code in `build_nested_ranger_node()` used `tree[tree$nodeID == node_id, ]` on every recursive call, which is O(n) per node.
Fix: Pre-extract all columns as vectors before recursion and use direct vector indexing.
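A minimal sketch of this change, assuming a small hypothetical tree table loosely shaped like the output of `ranger::treeInfo()`. The helper names `collect_slow()` and `collect_fast()` are illustrative and do not come from the PR. The fast version still uses the named `id_to_idx` lookup from this first round of optimization.

```r
# Hypothetical tree: node 0 splits into 1 and 2, node 1 splits into 3 and 4.
tree <- data.frame(
  nodeID     = 0:4,
  leftChild  = c(1L, 3L, NA, NA, NA),
  rightChild = c(2L, 4L, NA, NA, NA),
  terminal   = c(FALSE, FALSE, TRUE, TRUE, TRUE),
  prediction = c(NA, NA, 10, 20, 30)
)

# Before: every recursive call scans the whole data frame for one row.
collect_slow <- function(tree, node_id) {
  row <- tree[tree$nodeID == node_id, ]
  if (row$terminal) return(row$prediction)
  c(collect_slow(tree, row$leftChild), collect_slow(tree, row$rightChild))
}

# After: columns are pulled out once, then indexed as plain vectors.
collect_fast <- function(tree) {
  left       <- tree$leftChild
  right      <- tree$rightChild
  terminal   <- tree$terminal
  prediction <- tree$prediction
  id_to_idx  <- setNames(seq_along(tree$nodeID), tree$nodeID)

  recurse <- function(node_id) {
    i <- id_to_idx[[as.character(node_id)]]
    if (terminal[[i]]) return(prediction[[i]])
    c(recurse(left[[i]]), recurse(right[[i]]))
  }
  recurse(0L)
}

leaves_slow <- collect_slow(tree, 0L)
leaves_fast <- collect_fast(tree)
```

Both functions visit the leaves in the same order, so they can be compared directly when checking that a refactor like this preserves behavior.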
Issue 2: Named vector lookup with `as.character()` conversion
After round 1, profiling showed 55% of time spent on `id_to_idx[[as.character(node_id)]]`; the `as.character()` call on every recursive iteration was expensive.
Fix: Since ranger nodeIDs are 0-indexed and sequential, use direct integer indexing `node_id + 1L` instead of named lookup.
Total improvement:
Final profile breakdown (500 trees, depth=10):
No further obvious optimizations - remaining time is inherent to rlang expression building.
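The Issue 2 fix for ranger can be illustrated with a short sketch. The vectors and the helpers `lookup_named()` and `lookup_direct()` are hypothetical; the only assumption carried over from the text is that node IDs run 0, 1, 2, ... in row order.

```r
# Hypothetical node IDs: 0-indexed and sequential, as described for ranger.
node_ids  <- 0:999
id_to_idx <- setNames(seq_along(node_ids), node_ids)

# The shortcut is only valid while this invariant holds, so check it once up front.
stopifnot(identical(node_ids, seq_along(node_ids) - 1L))

# Round 1: integer -> character conversion plus a name match on every call.
lookup_named <- function(node_id) id_to_idx[[as.character(node_id)]]

# Round 2: the row index is just the ID shifted by one.
lookup_direct <- function(node_id) node_id + 1L

same <- vapply(
  node_ids,
  function(id) lookup_named(id) == lookup_direct(id),
  logical(1)
)
```

The up-front `stopifnot()` is a design choice of this sketch, not something the PR is stated to do; it guards against silently wrong results if the IDs were ever not sequential.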
lightgbm
Issue 1: Duplicate JSON parsing
The original code parsed `model$dump_model()` via `jsonlite::fromJSON()` twice: once in `parse_model.lgb.Booster()` and again in `extract_lgb_trees_nested()`. JSON parsing took 57% of total time.
Fix: Pass `feature_names` from `parsedmodel` to `extract_lgb_trees_nested()` to avoid redundant JSON parsing.
Issue 2: Slow `grepl()` for string matching
The `parse_lgb_linear_trees()` function used `grepl("^prefix", line)` in a loop over every line of the model string.
Fix: Replace `grepl("^prefix", ...)` with `startsWith(line, "prefix")`, which is faster for prefix matching.
Issue 3: Data frame column access in path building
The `get_lgb_path()` function used `which(tree_df$split_index == current_parent_split)` (O(n) per iteration) and repeated `tree_df$column[[parent_row]]` access.
Fix: Pre-extract columns as vectors and build a `split_index` to row lookup array for O(1) access.
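The lightgbm fixes for Issues 2 and 3 can be sketched in base R as follows. The model lines, the `tree_df` contents, and the helpers `row_slow()` and `row_fast()` are hypothetical stand-ins; only `split_index` is a name taken from the text.

```r
# Issue 2: prefix matching on hypothetical lines of a model string.
lines <- c("Tree=0", "num_leaves=3", "leaf_value=0.1 0.2 0.3", "Tree=1")
is_tree_regex  <- grepl("^Tree=", lines)       # compiles and runs a regex
is_tree_prefix <- startsWith(lines, "Tree=")   # plain prefix comparison

# Issue 3: hypothetical tree table. Internal nodes carry a 0-based
# split_index, leaves have NA. Rows are deliberately out of order.
tree_df <- data.frame(
  split_index   = c(1L, NA, 0L, NA, NA),
  split_feature = c("x2", NA, "x1", NA, NA),
  stringsAsFactors = FALSE
)

# Pre-extract the columns once.
split_index   <- tree_df$split_index
split_feature <- tree_df$split_feature

# Lookup array: position k + 1 holds the row number of split_index k.
internal_rows <- which(!is.na(split_index))
split_to_row  <- integer(max(split_index, na.rm = TRUE) + 1L)
split_to_row[split_index[internal_rows] + 1L] <- internal_rows

row_slow <- function(k) which(tree_df$split_index == k)  # O(n) scan
row_fast <- function(k) split_to_row[[k + 1L]]           # O(1) index
```

With the lookup array in place, walking from a node up to the root costs one index operation per step instead of one scan of the table per step.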
randomForest
Issue: Matrix row access and repeated `unname()` calls
The original code used `tree[node_id, ]` to get a row, then called `unname()` on each column value (4-5 times per internal node). This caused many `unname()` calls, with a notable share of the time spent in `unname()` alone.
Fix: Pre-extract all matrix columns as vectors with `unname()` called once per column, then use direct integer indexing.
Final profile breakdown (500 trees):
No further obvious optimizations - remaining time is inherent to rlang expression building.
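The randomForest fix can be sketched as follows. The matrix is a hypothetical three-node tree using the column names of `randomForest::getTree()` output; `leaves_slow()` and `leaves_fast()` are illustrative names, not functions from the PR.

```r
# Hypothetical tree matrix: node 1 splits into leaves 2 and 3.
# status == -1 marks a terminal node.
tree <- matrix(
  c(2, 3, 1, 0.5,  1,  0,
    0, 0, 0, 0,   -1, 10,
    0, 0, 0, 0,   -1, 20),
  ncol = 6, byrow = TRUE,
  dimnames = list(
    as.character(1:3),
    c("left daughter", "right daughter", "split var",
      "split point", "status", "prediction")
  )
)

# Before: extract a named row, then unname() each value on every call.
leaves_slow <- function(tree, node_id = 1) {
  row <- tree[node_id, ]
  if (unname(row["status"]) == -1) return(unname(row["prediction"]))
  c(leaves_slow(tree, unname(row["left daughter"])),
    leaves_slow(tree, unname(row["right daughter"])))
}

# After: unname() once per column, then index plain vectors.
leaves_fast <- function(tree) {
  left       <- unname(tree[, "left daughter"])
  right      <- unname(tree[, "right daughter"])
  status     <- unname(tree[, "status"])
  prediction <- unname(tree[, "prediction"])

  recurse <- function(i) {
    if (status[[i]] == -1) return(prediction[[i]])
    c(recurse(left[[i]]), recurse(right[[i]]))
  }
  recurse(1L)
}

slow_result <- leaves_slow(tree)
fast_result <- leaves_fast(tree)
```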
catboost
Issue: Repeated feature info lookups in oblivious tree parsing
For oblivious trees (same splits for all leaves), the original code parsed each split separately for every leaf. With depth=6 (64 leaves), each of the 6 splits was looked up 64 times per tree instead of once.
Fix: Pre-extract split info once per tree, then for each leaf only determine the direction (op) based on bit value.
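A rough sketch of the idea, assuming a hypothetical depth-3 oblivious tree. The `splits` table, the bit ordering (bit 1 belongs to the first split), and the mapping of a set bit to `>` are all assumptions made for illustration; the actual catboost encoding may differ.

```r
# Split info is extracted once per tree (hypothetical features and borders).
splits <- data.frame(
  feature = c("x1", "x2", "x3"),
  border  = c(0.5, 1.5, 2.5),
  stringsAsFactors = FALSE
)
depth    <- nrow(splits)
n_leaves <- 2^depth

# For each 0-based leaf index, only the direction per split is computed:
# bit i of the leaf index picks the side taken at split i.
leaf_conditions <- lapply(seq_len(n_leaves) - 1L, function(leaf) {
  bits <- (leaf %/% 2^(seq_len(depth) - 1L)) %% 2
  op   <- ifelse(bits == 1, ">", "<=")
  paste(splits$feature, op, splits$border)
})
```

Here the feature names and borders are read from `splits` without any per-leaf parsing, so the per-leaf work reduces to a few arithmetic operations.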
Note: catboost was already quite fast. Most time is spent in JSON parsing (saving model to JSON file, then reading it), which is unavoidable.
cubist
Issue: Repeated data frame subsetting with `==` in nested loops
The original code used `coefs[coefs$rule == .x & coefs$committee == comm, ]` inside nested loops, causing O(n) scans for every rule in every committee. With 100 committees and many rules, this became O(n * committees * rules_per_committee).
Fix: Pre-split data frames by committee and rule using `split()` once, then use direct hash lookup by key.
The speedup scales with model complexity - larger models see more benefit.
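A minimal sketch of the pre-split approach, assuming a hypothetical `coefs` table with `committee` and `rule` columns as in the text. The `var` and `value` columns, the `"committee_rule"` key format, and the helpers `get_slow()` and `get_fast()` are illustrative choices, not the PR's code.

```r
# Hypothetical coefficient table: two committees, up to two rules each.
coefs <- data.frame(
  committee = c(1, 1, 1, 2, 2),
  rule      = c(1, 1, 2, 1, 1),
  var       = c("(Intercept)", "x1", "(Intercept)", "(Intercept)", "x2"),
  value     = c(0.5, 2, 1.5, -1, 3),
  stringsAsFactors = FALSE
)

# Before: a full scan of the table for every (committee, rule) pair.
get_slow <- function(comm, rule) {
  coefs[coefs$rule == rule & coefs$committee == comm, ]
}

# After: split once by a combined key, then look each group up by name.
key          <- paste(coefs$committee, coefs$rule, sep = "_")
coefs_by_key <- split(coefs, key)
get_fast <- function(comm, rule) {
  coefs_by_key[[paste(comm, rule, sep = "_")]]
}

slow_rows <- get_slow(2, 1)
fast_rows <- get_fast(2, 1)
```

Lookup by name in a list is used here for brevity; for very many groups, converting the list with `list2env()` would give a true hashed lookup.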
rpart
Already fast - no optimization needed. <10ms for 617 nodes.
partykit
Issue: Repeated tree traversal via `nodeapply()` and `model[[.x]]`
The original code called `partykit::nodeapply(model, .x)` and `model[[.x]]` for every single node. Each call traverses the entire tree to find that node, resulting in O(n²) complexity. 73.7% of time was spent in `rid/nodeids.partynode` (partykit's internal recursive tree traversal).
Fix: Extract all nodes at once using `partykit::nodeapply(model, ids = all_node_ids, FUN = identity)` and compute predictions using `tapply()` instead of per-node iteration.
The speedup is dramatic because we eliminated O(n²) tree traversals.
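The `tapply()` half of this fix can be sketched in base R without loading partykit. The vectors `node_of_row` and `y` are hypothetical, and using the mean as the per-node prediction is an assumption that fits a regression tree only.

```r
# Hypothetical data: the terminal node each training row falls into,
# and the response value for that row.
node_of_row <- c(2L, 2L, 4L, 5L, 4L, 5L, 5L)
y           <- c(1, 3, 10, 20, 14, 22, 24)

# Per-node iteration: one pass over the data for every node.
preds_loop <- vapply(
  sort(unique(node_of_row)),
  function(id) mean(y[node_of_row == id]),
  numeric(1)
)

# Grouped computation: all node predictions from a single call.
preds_tapply <- tapply(y, node_of_row, mean)
```

The result of `tapply()` is named by node ID, so a prediction can be read with, for example, `preds_tapply[["4"]]`.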
lm, glm, glmnet, earth
All regression-based models are already fast - no optimization needed.
These models are simple coefficient-based formulas with minimal computation required.
Summary
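All 11 model types have been profiled. Optimizations were made to 7 models.
Common optimization patterns: pre-extracting data frame or matrix columns as vectors (with `unname()` called once per column), replacing lookups that need `as.character()` with direct integer indexing, and avoiding repeated tree traversal (e.g. `nodeapply`) per-node.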