Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ URL: https://r-nimble.org
BugReports: https://github.com/nimble-dev/nimbleMacros/issues
Encoding: UTF-8
VignetteBuilder: knitr
RoxygenNote: 7.3.2
RoxygenNote: 7.3.3
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ export(LINPRED_PRIORS)
export(LM)
export(matchPrior)
export(setPriors)
export(simplifyForLoops)
export(uppertri_mult_diag)
importFrom(nimble,nimMatrix)
145 changes: 145 additions & 0 deletions R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,148 @@ removeSquareBrackets <- function(code){
}
out
}

# Get for loop index range from a chunk of code
forInfo <- function(x){
if(is.symbol(x)) return(NULL)
if(x[[1]] != "for") return(NULL)
as.list(x[[3]])
}

# Attempt to combine for loops that share the same index range
collapseLoopsInternal <- function(code){
if(is.symbol(code)) return(code)
# Iterate over call components (except last one)
for (i in 1:(length(code)-1)){
# Skip if just a symbol
if(is.symbol(code[[i]])) next
# Skip if not a for loop
if(code[[i]][[1]] != "for") next
# Get index range info from for loop
i_info <- forInfo(code[[i]])
# Iterate over all subsequent call components after this one
# looking for matching index ranges
for (j in (i+1):length(code)){
# Skip if just a symbol
if(is.symbol(code[[j]])) next
# Skip if not a for loop
if(code[[j]][[1]] != "for") next
# Get index range info from the new for loop
j_info <- forInfo(code[[j]])
# Check if the two index ranges match
if(identical(i_info, j_info)){
# Save existing for loop code for "parent" loop into new variable
newloop <- code[[i]]
# Get loop index for this loop
idx_i <- newloop[[2]]
# Separate code inside loop
internal <- newloop[[4]]
# Drop the containing bracket from it
internal[[1]] <- NULL
# Get the code inside the "child" loop which will be added
# to the parent loop
# Note: could be a list if there is more than one line
add_code <- code[[j]][[4]]
# Remove bracket
add_code[[1]] <- NULL
# Get loop index for the "child" loop
idx_j <- code[[j]][[2]]
# Replace the existing index with the index from the parent loop
add_code <- lapply(add_code, recursiveReplaceIndex, idx_j, idx_i)
# Combine the parent and child loop code and insert it back into the for loop
newloop[[4]] <- embedLinesInCurlyBrackets(c(internal, add_code))
# Insert the for loop into the full code
code[[i]] <- newloop
# Mark the now-duplicated "child" code for removal later
code[[j]] <- "_REMOVE_"
}
}
}

# Return all code parts except the stuff to be removed
code[!sapply(code, function(x) x == "_REMOVE_")]
}

# Recursively combine loops that share common index ranges
collapseLoops <- function(code){
# Run the internal loop collapsing code once
code <- collapseLoopsInternal(code)
# Iterate over the result, looking for internal for loops and collapsing those
if(is.call(code)){
out <- lapply(code, function(x){
if(is.symbol(x)) return(x)
if(x[[1]] == "for"){
x[[4]] <- collapseLoops(x[[4]])
}
x
})
out <- as.call(out)
} else {
out <- code
}
out
}

# Replace complex indices (i_1, i_2, etc.) with a smaller number of
# simplier indices (i, j, etc.) if possible
simplifyIndices <- function(code, new_indices){
out <- lapply(code, function(x){
if(is.name(x) | is.symbol(x)) return(x)
if(x[[1]] == "for"){
unique_idx <- unique(extractAllIndices(x))
if(length(unique_idx) > length(new_indices)){
stop("Not enough new indices provided", call.=FALSE)
}
for (i in 1:length(unique_idx)){
x <- replaceForLoopIndex(x, unique_idx[[i]], new_indices[[i]])
}
}
x
})
as.call(out)
}

# Replace the index in a loop recursively
replaceForLoopIndex <- function(code, idx, newidx){
if(is.name(code) | is.symbol(code)) return(code)
if(code[[1]] == "for"){
if(code[[2]] == idx) code[[2]] <- newidx
code[[4]] <- recursiveReplaceIndex(code[[4]], idx, newidx)
if(is.call(code[[4]])){
code[[4]] <- as.call(lapply(code[[4]], function(x)
replaceForLoopIndex(x, idx, newidx)))
}
}
code
}

#' Simplify for loop structure in NIMBLE model code
#'
#' Takes the code for a NIMBLE model and attempts to combine for loops
#' that share the same index range, in order to simplify the code
#' structure. Optionally, can also replace existing for loop indices
#' with a (potentially smaller, simpler) set of new indices.
#' This function is particularly useful for simplifying code generated by
#' macros, which often creates many for loops with the same indices and
#' uses complex indices like 'i_1', 'i_2', etc. which could be simplified to
#' 'i', 'j', etc.
#'
#' @author Ken Kellner
#'
#' @param code NIMBLE code for a model, such as from the output of model$getCode()
#' @param new_indices A list of new for loop indices that will replace the existing
#' indices. The new indices must be quoted values (i.e., "names"/symbols).
#' If NULL, letters starting with 'i' will be used. If FALSE, no indices will
#' be replaced.
#'
#' @export
simplifyForLoops <- function(code, new_indices = NULL){
out <- collapseLoops(code)
if(is.null(new_indices)){
new_indices <- lapply(letters[9:26], str2lang)
} else if(is.logical(new_indices) && !new_indices){
return(out)
}
out <- simplifyIndices(out, new_indices)
out
}
29 changes: 29 additions & 0 deletions man/simplifyForLoops.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

125 changes: 125 additions & 0 deletions tests/testthat/test_simplifyForLoops.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
context("simplifyForLoops")

skip_on_cran()

test_that("Nothing is done when there are no for loops", {
test <- nimbleCode({
x <- 1
y <- 1
})

out <- simplifyForLoops(test)

expect_equal(out, test)
})

test_that("Basic for loop collapsing", {

test <- nimbleCode({
for (i in 1:3){
x[i] <- y[i] + 1
}

for (j in 1:3){
x2[j] <- y2[j] + 1
x3[j] <- y3[j] + 1
}
})

out <- simplifyForLoops(test)

expect_equal(out,
quote({
for (i in 1:3){
x[i] <- y[i] + 1
x2[i] <- y2[i] + 1
x3[i] <- y3[i] + 1
}
})
)
})

test_that("More complex case of an occupancy model", {

nimbleOptions(enableMacroComments = FALSE)
occ <- nimbleCode({
psi[1:nsites] <- LINPRED(~scale(x[1:nsites]), link=logit, coefPrefix=state_)
p[1:nsites, 1:noccs] <- LINPRED(~x[1:nsites] + x2[1:nsites, 1:noccs], link=logit, coefPrefix=det_)

z[1:nsites] ~ FORLOOP(dbern(psi[1:nsites]))
y[1:nsites, 1:noccs] ~ FORLOOP(dbern(p[1:nsites, 1:noccs]*z[1:nsites]))
})

const <- list(nsites=10, noccs=3, y=matrix(0, 10, 3), x=rnorm(10),
x2=matrix(rnorm(30),10,3))

mod <- nimbleModel(occ, constants=const)

out <- simplifyForLoops(mod$getCode())

expect_equal(out,
quote({
for (i in 1:nsites) {
logit(psi[i]) <- state_Intercept + state_x_scaled * x_scaled[i]
for (j in 1:noccs) {
logit(p[i, j]) <- det_Intercept + det_x * x[i] + det_x2 * x2[i, j]
y[i, j] ~ dbern(p[i, j] * z[i])
}
z[i] ~ dbern(psi[i])
}
state_Intercept ~ dnorm(0, sd = 1000)
state_x_scaled ~ dnorm(0, sd = 1000)
det_Intercept ~ dnorm(0, sd = 1000)
det_x ~ dnorm(0, sd = 1000)
det_x2 ~ dnorm(0, sd = 1000)
})
)

# Change indices
out <- simplifyForLoops(mod$getCode(), new_indices=list(quote(f), quote(g)))

expect_equal(out,
quote({
for (f in 1:nsites) {
logit(psi[f]) <- state_Intercept + state_x_scaled * x_scaled[f]
for (g in 1:noccs) {
logit(p[f, g]) <- det_Intercept + det_x * x[f] + det_x2 * x2[f, g]
y[f, g] ~ dbern(p[f, g] * z[f])
}
z[f] ~ dbern(psi[f])
}
state_Intercept ~ dnorm(0, sd = 1000)
state_x_scaled ~ dnorm(0, sd = 1000)
det_Intercept ~ dnorm(0, sd = 1000)
det_x ~ dnorm(0, sd = 1000)
det_x2 ~ dnorm(0, sd = 1000)
})
)

# Don't change indices
out <- simplifyForLoops(mod$getCode(), new_indices=FALSE)

expect_equal(out,
quote({
for (i_1 in 1:nsites) {
logit(psi[i_1]) <- state_Intercept + state_x_scaled *
x_scaled[i_1]
for (i_3 in 1:noccs) {
logit(p[i_1, i_3]) <- det_Intercept + det_x * x[i_1] +
det_x2 * x2[i_1, i_3]
y[i_1, i_3] ~ dbern(p[i_1, i_3] * z[i_1])
}
z[i_1] ~ dbern(psi[i_1])
}
state_Intercept ~ dnorm(0, sd = 1000)
state_x_scaled ~ dnorm(0, sd = 1000)
det_Intercept ~ dnorm(0, sd = 1000)
det_x ~ dnorm(0, sd = 1000)
det_x2 ~ dnorm(0, sd = 1000)
})
)

# Error when not enough indices are provided
expect_error(simplifyForLoops(mod$getCode(), new_indices=list(quote(i))),
"Not enough new indices provided")
})