200200# ' [insight::get_datagrid()] (used when `data = "grid"`) and
201201# ' [insight::get_predicted()]. Furthermore, for count regression models that use
202202# ' an offset term, use `offset = <value>` to fix the offset at a specific value.
203+ # ' For models of class `nestedLogit`, use the `submodel` argument to specify
204+ # ' the component for which predictions should be returned (see
205+ # ' `?insight::get_predicted` for details).
203206# '
204207# ' @return A data frame of predicted values and uncertainty intervals, with
205208# ' class `"estimate_predicted"`. Methods for [`visualisation_recipe()`][visualisation_recipe.estimate_predicted]
255258# ' estimate_relation(model)
256259# ' }
257260# ' @export
258- estimate_expectation <- function (model ,
259- data = NULL ,
260- by = NULL ,
261- predict = " expectation" ,
262- ci = 0.95 ,
263- transform = NULL ,
264- iterations = NULL ,
265- keep_iterations = FALSE ,
266- ... ) {
261+ estimate_expectation <- function (
262+ model ,
263+ data = NULL ,
264+ by = NULL ,
265+ predict = " expectation" ,
266+ ci = 0.95 ,
267+ transform = NULL ,
268+ iterations = NULL ,
269+ keep_iterations = FALSE ,
270+ ...
271+ ) {
267272 .estimate_predicted(
268273 model ,
269274 data = data ,
@@ -280,15 +285,17 @@ estimate_expectation <- function(model,
280285
281286# ' @rdname estimate_expectation
282287# ' @export
283- estimate_link <- function (model ,
284- data = " grid" ,
285- by = NULL ,
286- predict = " link" ,
287- ci = 0.95 ,
288- transform = NULL ,
289- iterations = NULL ,
290- keep_iterations = FALSE ,
291- ... ) {
288+ estimate_link <- function (
289+ model ,
290+ data = " grid" ,
291+ by = NULL ,
292+ predict = " link" ,
293+ ci = 0.95 ,
294+ transform = NULL ,
295+ iterations = NULL ,
296+ keep_iterations = FALSE ,
297+ ...
298+ ) {
292299 # reset to NULL if only "by" was specified
293300 if (missing(data ) && ! missing(by )) {
294301 data <- NULL
@@ -309,15 +316,17 @@ estimate_link <- function(model,
309316
310317# ' @rdname estimate_expectation
311318# ' @export
312- estimate_prediction <- function (model ,
313- data = NULL ,
314- by = NULL ,
315- predict = " prediction" ,
316- ci = 0.95 ,
317- transform = NULL ,
318- iterations = NULL ,
319- keep_iterations = FALSE ,
320- ... ) {
319+ estimate_prediction <- function (
320+ model ,
321+ data = NULL ,
322+ by = NULL ,
323+ predict = " prediction" ,
324+ ci = 0.95 ,
325+ transform = NULL ,
326+ iterations = NULL ,
327+ keep_iterations = FALSE ,
328+ ...
329+ ) {
321330 .estimate_predicted(
322331 model ,
323332 data = data ,
@@ -333,15 +342,17 @@ estimate_prediction <- function(model,
333342
334343# ' @rdname estimate_expectation
335344# ' @export
336- estimate_relation <- function (model ,
337- data = " grid" ,
338- by = NULL ,
339- predict = " expectation" ,
340- ci = 0.95 ,
341- transform = NULL ,
342- iterations = NULL ,
343- keep_iterations = FALSE ,
344- ... ) {
345+ estimate_relation <- function (
346+ model ,
347+ data = " grid" ,
348+ by = NULL ,
349+ predict = " expectation" ,
350+ ci = 0.95 ,
351+ transform = NULL ,
352+ iterations = NULL ,
353+ keep_iterations = FALSE ,
354+ ...
355+ ) {
345356 # reset to NULL if only "by" was specified
346357 if (missing(data ) && ! missing(by )) {
347358 data <- NULL
@@ -364,15 +375,17 @@ estimate_relation <- function(model,
364375# Internal ----------------------------------------------------------------
365376
366377# ' @keywords internal
367- .estimate_predicted <- function (model ,
368- data = " grid" ,
369- by = NULL ,
370- predict = " expectation" ,
371- ci = 0.95 ,
372- transform = NULL ,
373- iterations = NULL ,
374- keep_iterations = FALSE ,
375- ... ) {
378+ .estimate_predicted <- function (
379+ model ,
380+ data = " grid" ,
381+ by = NULL ,
382+ predict = " expectation" ,
383+ ci = 0.95 ,
384+ transform = NULL ,
385+ iterations = NULL ,
386+ keep_iterations = FALSE ,
387+ ...
388+ ) {
376389 # return early for htest
377390 if (inherits(model , " htest" )) {
378391 return (insight :: get_predicted(model , ... ))
@@ -384,7 +397,14 @@ estimate_relation <- function(model,
384397 }
385398
386399 # keep_iterations cannot be larger than interations
387- if (! is.null(keep_iterations ) && ! is.null(iterations ) && is.numeric(keep_iterations ) && is.numeric(iterations ) && keep_iterations > iterations ) { # nolint
400+ if (
401+ ! is.null(keep_iterations ) &&
402+ ! is.null(iterations ) &&
403+ is.numeric(keep_iterations ) &&
404+ is.numeric(iterations ) &&
405+ keep_iterations > iterations
406+ ) {
407+ # nolint
388408 insight :: format_error(" `keep_iterations` cannot be larger than `iterations`." )
389409 }
390410
@@ -450,7 +470,12 @@ estimate_relation <- function(model,
450470 data <- model_data
451471 } else if (! is.data.frame(data )) {
452472 if (is_grid ) {
453- data <- insight :: get_datagrid(model , reference = model_data , include_response = is_nullmodel , ... )
473+ data <- insight :: get_datagrid(
474+ model ,
475+ reference = model_data ,
476+ include_response = is_nullmodel ,
477+ ...
478+ )
454479 } else {
455480 insight :: format_error(
456481 " The `data` argument must either NULL, \" grid\" or another data frame."
@@ -462,7 +487,12 @@ estimate_relation <- function(model,
462487 grid_specs <- attributes(data )
463488
464489 # Get response for later residuals -------------
465- if (! is.null(model_response ) && length(model_response ) == 1 && model_response %in% names(data )) { # nolint
490+ if (
491+ ! is.null(model_response ) &&
492+ length(model_response ) == 1 &&
493+ model_response %in% names(data )
494+ ) {
495+ # nolint
466496 response <- data [[model_response ]]
467497 } else {
468498 response <- NULL
@@ -492,7 +522,10 @@ estimate_relation <- function(model,
492522 )
493523
494524 # for predicting grouplevel random effects, add "allow.new.levels"
495- if (! is.null(grouplevel_effects ) && any(grouplevel_effects %in% grid_specs $ at_spec $ varname )) {
525+ if (
526+ ! is.null(grouplevel_effects ) &&
527+ any(grouplevel_effects %in% grid_specs $ at_spec $ varname )
528+ ) {
496529 prediction_args $ allow.new.levels <- TRUE
497530 dots $ allow.new.levels <- NULL
498531 }
@@ -511,13 +544,22 @@ estimate_relation <- function(model,
511544 }
512545
513546 # remove response variable from data frame, as this variable is predicted
514- if (! is.null(model_response ) && length(model_response ) == 1 && model_response %in% colnames(out )) { # nolint
547+ if (
548+ ! is.null(model_response ) &&
549+ length(model_response ) == 1 &&
550+ model_response %in% colnames(out )
551+ ) {
552+ # nolint
515553 out [[model_response ]] <- NULL
516554 }
517555
518556 # keep row-column, but make sure it's integer
519557 if (" Row" %in% colnames(out )) {
520- out [[" Row" ]] <- insight :: format_value(out [[" Row" ]], protect_integers = TRUE )
558+ if (inherits(model , " nestedLogit" )) {
559+ out [[" Row" ]] <- NULL
560+ } else {
561+ out [[" Row" ]] <- insight :: format_value(out [[" Row" ]], protect_integers = TRUE )
562+ }
521563 }
522564
523565 # Add residuals
@@ -557,17 +599,21 @@ estimate_relation <- function(model,
557599 by = grid_specs $ at ,
558600 type = " predictions" ,
559601 model = model ,
560- info = c(
561- grid_specs ,
562- list (predict = predict ),
563- transform = ! is.null(transform )
564- )
602+ info = c(grid_specs , list (predict = predict ), transform = ! is.null(transform ))
565603 )
566604
567- attributes(out ) <- c(attributes(out ), grid_specs [! names(grid_specs ) %in% names(attributes(out ))])
605+ attributes(out ) <- c(
606+ attributes(out ),
607+ grid_specs [! names(grid_specs ) %in% names(attributes(out ))]
608+ )
568609
569610 # Class
570- class(out ) <- c(paste0(" estimate_" , predict ), " estimate_predicted" , " see_estimate_predicted" , class(out ))
611+ class(out ) <- c(
612+ paste0(" estimate_" , predict ),
613+ " estimate_predicted" ,
614+ " see_estimate_predicted" ,
615+ class(out )
616+ )
571617
572618 out
573619}
0 commit comments