1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
|
#' @include sanity_model.R
#' @rdname sanitize_model_specific
#' @export
sanitize_model_specific.brmsfit <- function(model, ...) {
  # collapse is required by the brmsfit get_coef() and get_predict()
  # methods, which summarize posterior draws with collapse::fmedian.
  insight::check_if_installed("collapse", minimum_version = "1.9.0")

  # brmsfit objects do not expose terms() directly, so the term labels are
  # read from the formula stored inside the brms formula object. Any failure
  # along the way yields NULL, which is treated as "no factor() terms".
  term_labels <- tryCatch(
    attr(stats::terms(stats::formula(model)$formula), "term.labels"),
    error = function(e) NULL
  )

  uses_inline_factor <- any(grepl("^factor\\(", term_labels))
  if (!uses_inline_factor) {
    return(model)
  }

  stop(
    "The `factor()` function cannot be used in the model formula of a `brmsfit` model. Please convert your variable to a factor before fitting the model, or use the `mo()` function to specify monotonic variables (see the `brms` vignette on monotonic variables).",
    call. = FALSE
  )
}
#' @rdname get_coef
#' @export
get_coef.brmsfit <- function(model, ...) {
  # Point estimates for a brmsfit: the posterior median of each parameter,
  # computed column-wise over the draws returned by insight.
  param_draws <- insight::get_parameters(model)
  collapse::dapply(param_draws, FUN = collapse::fmedian, MARGIN = 2)
}
#' @include get_predict.R
#' @rdname get_predict
#' @export
get_predict.brmsfit <- function(
    model,
    newdata = insight::get_data(model),
    type = "response",
    ...) {
  # Posterior-median predictions for a brmsfit model.
  #
  # `type` selects which draws are computed:
  #   "link"       -> rstantools::posterior_linpred()
  #   "response"   -> rstantools::posterior_epred()
  #   "prediction" -> rstantools::posterior_predict()
  #   "average"    -> brms::pp_average(summary = FALSE)
  # Extra arguments in `...` are forwarded to that function.
  #
  # Returns a data.table with `group` and `estimate` columns (row ids are
  # added by add_rowid()). The full draws matrix, with one row per estimate
  # and one column per draw, is attached as the "posterior_draws" attribute.
  # Errors when the draws have an unexpected number of dimensions or do not
  # line up with the rows of `newdata`.
  checkmate::assert_choice(
    type,
    choices = c("response", "link", "prediction", "average")
  )

  if (type == "link") {
    insight::check_if_installed("rstantools")
    draws <- rstantools::posterior_linpred(model, newdata = newdata, ...)
  } else if (type == "response") {
    insight::check_if_installed("rstantools")
    draws <- rstantools::posterior_epred(model, newdata = newdata, ...)
  } else if (type == "prediction") {
    insight::check_if_installed("rstantools")
    draws <- rstantools::posterior_predict(model, newdata = newdata, ...)
  } else if (type == "average") {
    insight::check_if_installed("brms")
    draws <- brms::pp_average(
      model,
      newdata = newdata,
      summary = FALSE,
      ...
    )
  }

  n_dims <- length(dim(draws))

  # resp_subset sometimes causes dimension mismatch
  if (n_dims == 2 && nrow(newdata) != ncol(draws)) {
    msg <- sprintf(
      "Dimension mismatch: There are %s parameters in the posterior draws but %s observations in `newdata` (or the original dataset).",
      ncol(draws),
      nrow(newdata)
    )
    stop_sprintf(msg)
  }

  if (n_dims == 2) {
    # 1d outcome: one column per observation; summarize each column.
    med <- collapse::dapply(draws, MARGIN = 2, FUN = collapse::fmedian)
    out <- data.table(
      group = "main_marginaleffect",
      estimate = med
    )
  } else if (n_dims == 3) {
    # Multi-dimensional outcome: median over draws (dim 1) for every
    # (observation, outcome level) cell.
    out <- apply(draws, c(2, 3), stats::median)
    levnames <- dimnames(draws)[[3]]
    if (is.null(levnames)) {
      # Fall back to the response levels of *this* model, then to integers.
      levnames <- tryCatch(
        levels(insight::get_response(model)),
        error = function(e) NULL
      )
    }
    if (is.null(levnames)) {
      levnames <- seq_len(ncol(out))
    }
    colnames(out) <- levnames
    # c(out) is column-major: observations vary fastest within each level.
    out <- data.table(
      group = rep(colnames(out), each = nrow(out)),
      estimate = c(out)
    )
    out$group <- group_to_factor(out$group, model)
  } else {
    stop(
      "marginaleffects cannot extract posterior draws from this model. Please report this problem to the Bug tracker with a reproducible example: https://github.com/vincentarelbundock/marginaleffects/issues",
      call. = FALSE
    )
  }

  out <- add_rowid(out, newdata)

  # group for multi-valued outcome
  if (n_dims == 3) {
    # Place the outcome levels side by side: (draws x obs x levels) becomes
    # (draws x (obs * levels)), in the same level-major order as `estimate`.
    # A column-major reshape does this directly, and unlike `draws[, , i]`
    # it keeps the draws dimension with a single draw or observation.
    draws <- matrix(draws, nrow = dim(draws)[1])
  }
  attr(out, "posterior_draws") <- t(draws)
  out
}
#' @include get_group_names.R
#' @rdname get_group_names
#' @export
get_group_names.brmsfit <- function(model, ...) {
  # Models whose family matches "cumulative" get one group per distinct
  # observed response value. All other models use the single placeholder
  # group "main_marginaleffect".
  is_cumulative <- !is.null(model$family) && "cumulative" %in% model$family
  if (!is_cumulative) {
    return("main_marginaleffect")
  }
  unique(insight::get_response(model))
}
#' @rdname get_vcov
#' @export
get_vcov.brmsfit <- function(model, vcov = NULL, ...) {
  # brmsfit models never return a variance-covariance matrix here (always
  # NULL). Uncertainty presumably comes from the posterior draws attached
  # by get_predict() instead.
  user_supplied_vcov <- !is.null(vcov) && !is.logical(vcov)
  if (user_supplied_vcov) {
    warn_sprintf(
      "The `vcov` argument is not supported for models of this class."
    )
  }
  # The sanitized value is not needed. The call is kept so the shared
  # vcov sanitizer still runs for this class.
  sanitize_vcov(model, vcov)
  NULL
}
|