Skip to content

Commit dc486da

Browse files
authored
Support for nestedLogit (#606)
* Support for nestedLogit * add test
1 parent 323596b commit dc486da

File tree

5 files changed

+232
-97
lines changed

5 files changed

+232
-97
lines changed

DESCRIPTION

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Type: Package
22
Package: modelbased
33
Title: Estimation of Model-Based Predictions, Contrasts and Means
4-
Version: 0.14.0
4+
Version: 0.14.0.1
55
Authors@R:
66
c(person(given = "Dominique",
77
family = "Makowski",
@@ -55,6 +55,9 @@ Suggests:
5555
bootES,
5656
brglm2,
5757
brms,
58+
broom,
59+
car,
60+
carData,
5861
coda,
5962
collapse,
6063
correlation,
@@ -82,6 +85,7 @@ Suggests:
8285
mgcv,
8386
mvtnorm,
8487
nanoparquet,
88+
nestedLogit,
8589
nnet,
8690
ordinal,
8791
palmerpenguins,

NEWS.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# modelbased (devel)
2+
3+
## Changes
4+
5+
* Support for models of class `nestedLogit`.
6+
17
# modelbased 0.14.0
28

39
## Changes

R/estimate_predicted.R

Lines changed: 104 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,9 @@
200200
#' [insight::get_datagrid()] (used when `data = "grid"`) and
201201
#' [insight::get_predicted()]. Furthermore, for count regression models that use
202202
#' an offset term, use `offset = <value>` to fix the offset at a specific value.
203+
#' For models of class `nestedLogit`, use the `submodel` argument to specify
204+
#' the component for which predictions should be returned (see
205+
#' `?insight::get_predicted` for details).
203206
#'
204207
#' @return A data frame of predicted values and uncertainty intervals, with
205208
#' class `"estimate_predicted"`. Methods for [`visualisation_recipe()`][visualisation_recipe.estimate_predicted]
@@ -255,15 +258,17 @@
255258
#' estimate_relation(model)
256259
#' }
257260
#' @export
258-
estimate_expectation <- function(model,
259-
data = NULL,
260-
by = NULL,
261-
predict = "expectation",
262-
ci = 0.95,
263-
transform = NULL,
264-
iterations = NULL,
265-
keep_iterations = FALSE,
266-
...) {
261+
estimate_expectation <- function(
262+
model,
263+
data = NULL,
264+
by = NULL,
265+
predict = "expectation",
266+
ci = 0.95,
267+
transform = NULL,
268+
iterations = NULL,
269+
keep_iterations = FALSE,
270+
...
271+
) {
267272
.estimate_predicted(
268273
model,
269274
data = data,
@@ -280,15 +285,17 @@ estimate_expectation <- function(model,
280285

281286
#' @rdname estimate_expectation
282287
#' @export
283-
estimate_link <- function(model,
284-
data = "grid",
285-
by = NULL,
286-
predict = "link",
287-
ci = 0.95,
288-
transform = NULL,
289-
iterations = NULL,
290-
keep_iterations = FALSE,
291-
...) {
288+
estimate_link <- function(
289+
model,
290+
data = "grid",
291+
by = NULL,
292+
predict = "link",
293+
ci = 0.95,
294+
transform = NULL,
295+
iterations = NULL,
296+
keep_iterations = FALSE,
297+
...
298+
) {
292299
# reset to NULL if only "by" was specified
293300
if (missing(data) && !missing(by)) {
294301
data <- NULL
@@ -309,15 +316,17 @@ estimate_link <- function(model,
309316

310317
#' @rdname estimate_expectation
311318
#' @export
312-
estimate_prediction <- function(model,
313-
data = NULL,
314-
by = NULL,
315-
predict = "prediction",
316-
ci = 0.95,
317-
transform = NULL,
318-
iterations = NULL,
319-
keep_iterations = FALSE,
320-
...) {
319+
estimate_prediction <- function(
320+
model,
321+
data = NULL,
322+
by = NULL,
323+
predict = "prediction",
324+
ci = 0.95,
325+
transform = NULL,
326+
iterations = NULL,
327+
keep_iterations = FALSE,
328+
...
329+
) {
321330
.estimate_predicted(
322331
model,
323332
data = data,
@@ -333,15 +342,17 @@ estimate_prediction <- function(model,
333342

334343
#' @rdname estimate_expectation
335344
#' @export
336-
estimate_relation <- function(model,
337-
data = "grid",
338-
by = NULL,
339-
predict = "expectation",
340-
ci = 0.95,
341-
transform = NULL,
342-
iterations = NULL,
343-
keep_iterations = FALSE,
344-
...) {
345+
estimate_relation <- function(
346+
model,
347+
data = "grid",
348+
by = NULL,
349+
predict = "expectation",
350+
ci = 0.95,
351+
transform = NULL,
352+
iterations = NULL,
353+
keep_iterations = FALSE,
354+
...
355+
) {
345356
# reset to NULL if only "by" was specified
346357
if (missing(data) && !missing(by)) {
347358
data <- NULL
@@ -364,15 +375,17 @@ estimate_relation <- function(model,
364375
# Internal ----------------------------------------------------------------
365376

366377
#' @keywords internal
367-
.estimate_predicted <- function(model,
368-
data = "grid",
369-
by = NULL,
370-
predict = "expectation",
371-
ci = 0.95,
372-
transform = NULL,
373-
iterations = NULL,
374-
keep_iterations = FALSE,
375-
...) {
378+
.estimate_predicted <- function(
379+
model,
380+
data = "grid",
381+
by = NULL,
382+
predict = "expectation",
383+
ci = 0.95,
384+
transform = NULL,
385+
iterations = NULL,
386+
keep_iterations = FALSE,
387+
...
388+
) {
376389
# return early for htest
377390
if (inherits(model, "htest")) {
378391
return(insight::get_predicted(model, ...))
@@ -384,7 +397,14 @@ estimate_relation <- function(model,
384397
}
385398

386399
# keep_iterations cannot be larger than interations
387-
if (!is.null(keep_iterations) && !is.null(iterations) && is.numeric(keep_iterations) && is.numeric(iterations) && keep_iterations > iterations) { # nolint
400+
if (
401+
!is.null(keep_iterations) &&
402+
!is.null(iterations) &&
403+
is.numeric(keep_iterations) &&
404+
is.numeric(iterations) &&
405+
keep_iterations > iterations
406+
) {
407+
# nolint
388408
insight::format_error("`keep_iterations` cannot be larger than `iterations`.")
389409
}
390410

@@ -450,7 +470,12 @@ estimate_relation <- function(model,
450470
data <- model_data
451471
} else if (!is.data.frame(data)) {
452472
if (is_grid) {
453-
data <- insight::get_datagrid(model, reference = model_data, include_response = is_nullmodel, ...)
473+
data <- insight::get_datagrid(
474+
model,
475+
reference = model_data,
476+
include_response = is_nullmodel,
477+
...
478+
)
454479
} else {
455480
insight::format_error(
456481
"The `data` argument must either NULL, \"grid\" or another data frame."
@@ -462,7 +487,12 @@ estimate_relation <- function(model,
462487
grid_specs <- attributes(data)
463488

464489
# Get response for later residuals -------------
465-
if (!is.null(model_response) && length(model_response) == 1 && model_response %in% names(data)) { # nolint
490+
if (
491+
!is.null(model_response) &&
492+
length(model_response) == 1 &&
493+
model_response %in% names(data)
494+
) {
495+
# nolint
466496
response <- data[[model_response]]
467497
} else {
468498
response <- NULL
@@ -492,7 +522,10 @@ estimate_relation <- function(model,
492522
)
493523

494524
# for predicting grouplevel random effects, add "allow.new.levels"
495-
if (!is.null(grouplevel_effects) && any(grouplevel_effects %in% grid_specs$at_spec$varname)) {
525+
if (
526+
!is.null(grouplevel_effects) &&
527+
any(grouplevel_effects %in% grid_specs$at_spec$varname)
528+
) {
496529
prediction_args$allow.new.levels <- TRUE
497530
dots$allow.new.levels <- NULL
498531
}
@@ -511,13 +544,22 @@ estimate_relation <- function(model,
511544
}
512545

513546
# remove response variable from data frame, as this variable is predicted
514-
if (!is.null(model_response) && length(model_response) == 1 && model_response %in% colnames(out)) { # nolint
547+
if (
548+
!is.null(model_response) &&
549+
length(model_response) == 1 &&
550+
model_response %in% colnames(out)
551+
) {
552+
# nolint
515553
out[[model_response]] <- NULL
516554
}
517555

518556
# keep row-column, but make sure it's integer
519557
if ("Row" %in% colnames(out)) {
520-
out[["Row"]] <- insight::format_value(out[["Row"]], protect_integers = TRUE)
558+
if (inherits(model, "nestedLogit")) {
559+
out[["Row"]] <- NULL
560+
} else {
561+
out[["Row"]] <- insight::format_value(out[["Row"]], protect_integers = TRUE)
562+
}
521563
}
522564

523565
# Add residuals
@@ -557,17 +599,21 @@ estimate_relation <- function(model,
557599
by = grid_specs$at,
558600
type = "predictions",
559601
model = model,
560-
info = c(
561-
grid_specs,
562-
list(predict = predict),
563-
transform = !is.null(transform)
564-
)
602+
info = c(grid_specs, list(predict = predict), transform = !is.null(transform))
565603
)
566604

567-
attributes(out) <- c(attributes(out), grid_specs[!names(grid_specs) %in% names(attributes(out))])
605+
attributes(out) <- c(
606+
attributes(out),
607+
grid_specs[!names(grid_specs) %in% names(attributes(out))]
608+
)
568609

569610
# Class
570-
class(out) <- c(paste0("estimate_", predict), "estimate_predicted", "see_estimate_predicted", class(out))
611+
class(out) <- c(
612+
paste0("estimate_", predict),
613+
"estimate_predicted",
614+
"see_estimate_predicted",
615+
class(out)
616+
)
571617

572618
out
573619
}

man/estimate_expectation.Rd

Lines changed: 4 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)