diff --git a/DESCRIPTION b/DESCRIPTION
index 010d895..227b44b 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -23,4 +23,4 @@ URL: https://r-nimble.org
 BugReports: https://github.com/nimble-dev/nimbleMacros/issues
 Encoding: UTF-8
 VignetteBuilder: knitr
-RoxygenNote: 7.3.2
+RoxygenNote: 7.3.3
diff --git a/NAMESPACE b/NAMESPACE
index 715aafa..5a9bd19 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -6,5 +6,6 @@ export(LINPRED_PRIORS)
 export(LM)
 export(matchPrior)
 export(setPriors)
+export(simplifyForLoops)
 export(uppertri_mult_diag)
 importFrom(nimble,nimMatrix)
diff --git a/R/utilities.R b/R/utilities.R
index d08d2ff..3b99a25 100644
--- a/R/utilities.R
+++ b/R/utilities.R
@@ -106,3 +106,162 @@ removeSquareBrackets <- function(code){
   }
   out
 }
+
+# Check if a chunk of code is a call to `for` (i.e., a for loop)
+isForLoopCall <- function(x){
+  is.call(x) && identical(x[[1]], quote(`for`))
+}
+
+# Get the lines of code inside a for loop body as a list
+# Works whether or not the body is wrapped in curly brackets
+forLoopBodyLines <- function(body){
+  if(is.call(body) && identical(body[[1]], quote(`{`))){
+    return(as.list(body)[-1])
+  }
+  list(body)
+}
+
+# Get for loop index range from a chunk of code
+# Returns the components of the range expression as a list,
+# or NULL if the code is not a for loop
+forInfo <- function(x){
+  if(!isForLoopCall(x)) return(NULL)
+  as.list(x[[3]])
+}
+
+# Attempt to combine for loops that share the same index range
+# Only loops that are direct components of 'code' are combined here;
+# collapseLoops() applies this recursively to nested loops
+collapseLoopsInternal <- function(code){
+  # Nothing to combine unless code is a call (e.g. a block in curly brackets)
+  if(!is.call(code)) return(code)
+  n <- length(code)
+  # Track which components have been merged into an earlier loop
+  remove <- rep(FALSE, n)
+  # Iterate over call components (except last one)
+  # seq_len() avoids counting backwards (1:0) when there is one component
+  for (i in seq_len(n - 1)){
+    # Skip if not a for loop, or if already merged into an earlier loop
+    if(remove[i] || !isForLoopCall(code[[i]])) next
+    # Get index range info and loop index from this "parent" loop
+    i_info <- forInfo(code[[i]])
+    idx_i <- code[[i]][[2]]
+    # Iterate over all subsequent call components after this one
+    # looking for matching index ranges
+    for (j in (i+1):n){
+      if(remove[j] || !isForLoopCall(code[[j]])) next
+      # Skip if the two index ranges do not match
+      if(!identical(i_info, forInfo(code[[j]]))) next
+      # Get loop index and code inside the "child" loop
+      idx_j <- code[[j]][[2]]
+      child_body <- code[[j]][[4]]
+      # Skip if renaming the child index to the parent index would clash
+      # with a different use of the parent index name inside the child loop
+      if(!identical(idx_i, idx_j) &&
+         as.character(idx_i) %in% all.names(child_body)) next
+      # Replace the child loop index with the index from the parent loop
+      add_code <- lapply(forLoopBodyLines(child_body),
+                         recursiveReplaceIndex, idx_j, idx_i)
+      # Combine the parent and child loop code and insert it back into
+      # the parent loop
+      newloop <- code[[i]]
+      internal <- forLoopBodyLines(newloop[[4]])
+      newloop[[4]] <- embedLinesInCurlyBrackets(c(internal, add_code))
+      code[[i]] <- newloop
+      # Mark the now-duplicated "child" loop for removal later
+      remove[j] <- TRUE
+    }
+  }
+
+  # Return all code parts except the loops that were merged
+  code[!remove]
+}
+
+# Recursively combine loops that share common index ranges
+collapseLoops <- function(code){
+  # Run the internal loop collapsing code once
+  code <- collapseLoopsInternal(code)
+  if(!is.call(code)) return(code)
+  # Iterate over the result, looking for internal for loops and collapsing those
+  out <- lapply(code, function(x){
+    if(isForLoopCall(x)){
+      x[[4]] <- collapseLoops(x[[4]])
+    }
+    x
+  })
+  as.call(out)
+}
+
+# Replace complex indices (i_1, i_2, etc.) with a smaller number of
+# simpler indices (i, j, etc.) if possible
+simplifyIndices <- function(code, new_indices){
+  if(!is.call(code)) return(code)
+  out <- lapply(code, function(x){
+    if(!isForLoopCall(x)) return(x)
+    unique_idx <- unique(extractAllIndices(x))
+    n_idx <- length(unique_idx)
+    if(n_idx > length(new_indices)){
+      stop("Not enough new indices provided", call.=FALSE)
+    }
+    # Replace in two passes via temporary placeholder indices, so that a new
+    # index which matches a not-yet-replaced old index (e.g. old j, i to
+    # new i, j) cannot cause two different indices to be merged together
+    tmp_names <- paste0("simplifyIndices_tmp_", seq_len(n_idx))
+    while(any(tmp_names %in% all.names(x))){
+      tmp_names <- paste0(tmp_names, "_")
+    }
+    tmp_idx <- lapply(tmp_names, as.name)
+    for (i in seq_len(n_idx)){
+      x <- replaceForLoopIndex(x, unique_idx[[i]], tmp_idx[[i]])
+    }
+    for (i in seq_len(n_idx)){
+      x <- replaceForLoopIndex(x, tmp_idx[[i]], new_indices[[i]])
+    }
+    x
+  })
+  as.call(out)
+}
+
+# Replace the index in a loop recursively
+# Code that is not a for loop is returned unchanged
+replaceForLoopIndex <- function(code, idx, newidx){
+  if(!isForLoopCall(code)) return(code)
+  if(code[[2]] == idx) code[[2]] <- newidx
+  code[[4]] <- recursiveReplaceIndex(code[[4]], idx, newidx)
+  if(is.call(code[[4]])){
+    code[[4]] <- as.call(lapply(code[[4]], replaceForLoopIndex, idx, newidx))
+  }
+  code
+}
+
+#' Simplify for loop structure in NIMBLE model code
+#'
+#' Takes the code for a NIMBLE model and attempts to combine for loops
+#' that share the same index range, in order to simplify the code
+#' structure. Optionally, can also replace existing for loop indices
+#' with a (potentially smaller, simpler) set of new indices.
+#' This function is particularly useful for simplifying code generated by
+#' macros, which often creates many for loops with the same indices and
+#' uses complex indices like 'i_1', 'i_2', etc. which could be simplified to
+#' 'i', 'j', etc.
+#'
+#' @author Ken Kellner
+#'
+#' @param code NIMBLE code for a model, such as from the output of model$getCode()
+#' @param new_indices A list of new for loop indices that will replace the existing
+#' indices. The new indices must be quoted values (i.e., "names"/symbols).
+#' If NULL, letters starting with 'i' will be used. If FALSE, no indices will
+#' be replaced.
+#'
+#' @return The simplified NIMBLE model code.
+#'
+#' @export
+simplifyForLoops <- function(code, new_indices = NULL){
+  out <- collapseLoops(code)
+  # FALSE means the existing indices are kept
+  if(isFALSE(new_indices)) return(out)
+  if(is.null(new_indices)){
+    new_indices <- lapply(letters[9:26], as.name)
+  }
+  simplifyIndices(out, new_indices)
+}
diff --git a/man/simplifyForLoops.Rd b/man/simplifyForLoops.Rd
new file mode 100644
index 0000000..d5f7a19
--- /dev/null
+++ b/man/simplifyForLoops.Rd
@@ -0,0 +1,32 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/utilities.R
+\name{simplifyForLoops}
+\alias{simplifyForLoops}
+\title{Simplify for loop structure in NIMBLE model code}
+\usage{
+simplifyForLoops(code, new_indices = NULL)
+}
+\arguments{
+\item{code}{NIMBLE code for a model, such as from the output of model$getCode()}
+
+\item{new_indices}{A list of new for loop indices that will replace the existing
+indices. The new indices must be quoted values (i.e., "names"/symbols).
+If NULL, letters starting with 'i' will be used. If FALSE, no indices will
+be replaced.}
+}
+\value{
+The simplified NIMBLE model code.
+}
+\description{
+Takes the code for a NIMBLE model and attempts to combine for loops
+that share the same index range, in order to simplify the code
+structure. Optionally, can also replace existing for loop indices
+with a (potentially smaller, simpler) set of new indices.
+This function is particularly useful for simplifying code generated by
+macros, which often creates many for loops with the same indices and
+uses complex indices like 'i_1', 'i_2', etc. which could be simplified to
+'i', 'j', etc.
+}
+\author{
+Ken Kellner
+}
diff --git a/tests/testthat/test_simplifyForLoops.R b/tests/testthat/test_simplifyForLoops.R
new file mode 100644
index 0000000..abeeda7
--- /dev/null
+++ b/tests/testthat/test_simplifyForLoops.R
@@ -0,0 +1,129 @@
+context("simplifyForLoops")
+
+skip_on_cran()
+
+test_that("Nothing is done when there are no for loops", {
+  test <- nimbleCode({
+    x <- 1
+    y <- 1
+  })
+
+  out <- simplifyForLoops(test)
+
+  expect_equal(out, test)
+})
+
+test_that("Empty code block is returned unchanged", {
+  expect_equal(simplifyForLoops(quote({})), quote({}))
+})
+
+test_that("Basic for loop collapsing", {
+
+  test <- nimbleCode({
+    for (i in 1:3){
+      x[i] <- y[i] + 1
+    }
+
+    for (j in 1:3){
+      x2[j] <- y2[j] + 1
+      x3[j] <- y3[j] + 1
+    }
+  })
+
+  out <- simplifyForLoops(test)
+
+  expect_equal(out,
+    quote({
+      for (i in 1:3){
+        x[i] <- y[i] + 1
+        x2[i] <- y2[i] + 1
+        x3[i] <- y3[i] + 1
+      }
+    })
+  )
+})
+
+test_that("More complex case of an occupancy model", {
+
+  nimbleOptions(enableMacroComments = FALSE)
+  occ <- nimbleCode({
+    psi[1:nsites] <- LINPRED(~scale(x[1:nsites]), link=logit, coefPrefix=state_)
+    p[1:nsites, 1:noccs] <- LINPRED(~x[1:nsites] + x2[1:nsites, 1:noccs], link=logit, coefPrefix=det_)
+
+    z[1:nsites] ~ FORLOOP(dbern(psi[1:nsites]))
+    y[1:nsites, 1:noccs] ~ FORLOOP(dbern(p[1:nsites, 1:noccs]*z[1:nsites]))
+  })
+
+  const <- list(nsites=10, noccs=3, y=matrix(0, 10, 3), x=rnorm(10),
+                x2=matrix(rnorm(30),10,3))
+
+  mod <- nimbleModel(occ, constants=const)
+
+  out <- simplifyForLoops(mod$getCode())
+
+  expect_equal(out,
+    quote({
+      for (i in 1:nsites) {
+        logit(psi[i]) <- state_Intercept + state_x_scaled * x_scaled[i]
+        for (j in 1:noccs) {
+          logit(p[i, j]) <- det_Intercept + det_x * x[i] + det_x2 * x2[i, j]
+          y[i, j] ~ dbern(p[i, j] * z[i])
+        }
+        z[i] ~ dbern(psi[i])
+      }
+      state_Intercept ~ dnorm(0, sd = 1000)
+      state_x_scaled ~ dnorm(0, sd = 1000)
+      det_Intercept ~ dnorm(0, sd = 1000)
+      det_x ~ dnorm(0, sd = 1000)
+      det_x2 ~ dnorm(0, sd = 1000)
+    })
+  )
+
+  # Change indices
+  out <- simplifyForLoops(mod$getCode(), new_indices=list(quote(f), quote(g)))
+
+  expect_equal(out,
+    quote({
+      for (f in 1:nsites) {
+        logit(psi[f]) <- state_Intercept + state_x_scaled * x_scaled[f]
+        for (g in 1:noccs) {
+          logit(p[f, g]) <- det_Intercept + det_x * x[f] + det_x2 * x2[f, g]
+          y[f, g] ~ dbern(p[f, g] * z[f])
+        }
+        z[f] ~ dbern(psi[f])
+      }
+      state_Intercept ~ dnorm(0, sd = 1000)
+      state_x_scaled ~ dnorm(0, sd = 1000)
+      det_Intercept ~ dnorm(0, sd = 1000)
+      det_x ~ dnorm(0, sd = 1000)
+      det_x2 ~ dnorm(0, sd = 1000)
+    })
+  )
+
+  # Don't change indices
+  out <- simplifyForLoops(mod$getCode(), new_indices=FALSE)
+
+  expect_equal(out,
+    quote({
+      for (i_1 in 1:nsites) {
+        logit(psi[i_1]) <- state_Intercept + state_x_scaled *
+          x_scaled[i_1]
+        for (i_3 in 1:noccs) {
+          logit(p[i_1, i_3]) <- det_Intercept + det_x * x[i_1] +
+            det_x2 * x2[i_1, i_3]
+          y[i_1, i_3] ~ dbern(p[i_1, i_3] * z[i_1])
+        }
+        z[i_1] ~ dbern(psi[i_1])
+      }
+      state_Intercept ~ dnorm(0, sd = 1000)
+      state_x_scaled ~ dnorm(0, sd = 1000)
+      det_Intercept ~ dnorm(0, sd = 1000)
+      det_x ~ dnorm(0, sd = 1000)
+      det_x2 ~ dnorm(0, sd = 1000)
+    })
+  )
+
+  # Error when not enough indices are provided
+  expect_error(simplifyForLoops(mod$getCode(), new_indices=list(quote(i))),
+               "Not enough new indices provided")
+})