File: methods_brms.R

package info (click to toggle)
r-cran-marginaleffects 0.32.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 2,784 kB
  • sloc: sh: 13; makefile: 8
file content (155 lines) | stat: -rw-r--r-- 4,673 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Model-specific sanity checks for `brmsfit` objects.
#
# Requires `collapse` >= 1.9.0 (used elsewhere for posterior medians) and
# rejects formulas whose term labels start with `factor(`. The model is
# returned unchanged when the checks pass.
#' @include sanity_model.R
#' @rdname sanitize_model_specific
#' @export
sanitize_model_specific.brmsfit <- function(model, ...) {
    insight::check_if_installed("collapse", minimum_version = "1.9.0")

    # brmsfit objects do not expose terms directly, so derive the term labels
    # from the stored formula; any failure along the way skips the check.
    term_labels <- tryCatch(
        {
            fml <- stats::formula(model)$formula
            attr(stats::terms(fml), "term.labels")
        },
        error = function(e) NULL
    )

    # grepl() on NULL yields logical(0), so any() is FALSE and we pass through.
    if (!any(grepl("^factor\\(", term_labels))) {
        return(model)
    }

    stop(
        "The `factor()` function cannot be used in the model formula of a `brmsfit` model. Please convert your variable to a factor before fitting the model, or use the `mo()` function to specify monotonic variables (see the `brms` vignette on monotonic variables).",
        call. = FALSE
    )
}


# Point estimates for a `brmsfit`: the column-wise median of the parameter
# table returned by insight::get_parameters() (presumably one column per
# parameter and one row per posterior draw -- confirm against insight docs).
#' @rdname get_coef
#' @export
get_coef.brmsfit <- function(model, ...) {
    posterior <- insight::get_parameters(model)
    collapse::dapply(posterior, FUN = collapse::fmedian, MARGIN = 2)
}


# Posterior predictions for `brmsfit` models.
#
# `type` selects the extractor:
#   "link"       -> rstantools::posterior_linpred()
#   "response"   -> rstantools::posterior_epred()
#   "prediction" -> rstantools::posterior_predict()
#   "average"    -> brms::pp_average(summary = FALSE)
# Extra arguments in `...` are forwarded to the extractor.
#
# Returns a data.table with `group` and `estimate` columns (plus the row id
# added by add_rowid()). `estimate` is the posterior median per observation,
# and per outcome level when the draws are 3-dimensional. The full draws are
# attached as attribute "posterior_draws", transposed so that its rows line up
# with the rows of the returned table.
#' @include get_predict.R
#' @rdname get_predict
#' @export
get_predict.brmsfit <- function(
    model,
    newdata = insight::get_data(model),
    type = "response",
    ...) {
    checkmate::assert_choice(
        type,
        choices = c("response", "link", "prediction", "average")
    )

    if (type == "link") {
        insight::check_if_installed("rstantools")
        draws <- rstantools::posterior_linpred(
            model,
            newdata = newdata,
            ...
        )
    } else if (type == "response") {
        insight::check_if_installed("rstantools")
        draws <- rstantools::posterior_epred(
            model,
            newdata = newdata,
            ...
        )
    } else if (type == "prediction") {
        insight::check_if_installed("rstantools")
        draws <- rstantools::posterior_predict(
            model,
            newdata = newdata,
            ...
        )
    } else if (type == "average") {
        insight::check_if_installed("brms")
        draws <- brms::pp_average(
            model,
            newdata = newdata,
            summary = FALSE,
            ...
        )
    }

    # resp_subset sometimes causes dimension mismatch
    if (length(dim(draws)) == 2 && nrow(newdata) != ncol(draws)) {
        msg <- sprintf(
            "Dimension mismatch: There are %s parameters in the posterior draws but %s observations in `newdata` (or the original dataset).",
            ncol(draws),
            nrow(newdata)
        )
        stop_sprintf(msg)
    }

    # 1d outcome: draws is a (draws x observations) matrix
    if (length(dim(draws)) == 2) {
        med <- collapse::dapply(draws, MARGIN = 2, FUN = collapse::fmedian)
        out <- data.table(
            group = "main_marginaleffect",
            estimate = med
        )

        # multi-dimensional outcome: (draws x observations x levels) array
    } else if (length(dim(draws)) == 3) {
        out <- apply(draws, c(2, 3), stats::median)
        levnames <- dimnames(draws)[[3]]
        if (is.null(levnames)) {
            # No level names on the draws: try the response's factor levels.
            # (Previously this referenced an undefined `m`, so the lookup
            # always failed or silently used an unrelated global object.)
            levnames <- tryCatch(
                levels(insight::get_response(model)),
                error = function(e) NULL
            )
        }
        # Fall back to 1..K when no usable names are available, including
        # when the number of response levels does not match the draws.
        if (length(levnames) != ncol(out)) {
            levnames <- seq_len(ncol(out))
        }
        colnames(out) <- levnames
        out <- data.table(
            group = rep(colnames(out), each = nrow(out)),
            estimate = c(out)
        )
        out$group <- group_to_factor(out$group, model)
    } else {
        stop(
            "marginaleffects cannot extract posterior draws from this model. Please report this problem to the Bug tracker with a reproducible example: https://github.com/vincentarelbundock/marginaleffects/issues",
            call. = FALSE
        )
    }

    out <- add_rowid(out, newdata)

    # group for multi-valued outcome: stack the per-level (draws x obs) slices
    # side by side so columns follow the level-major row order of `out`.
    if (length(dim(draws)) == 3) {
        n_draws <- dim(draws)[1]
        slice_dimnames <- dimnames(draws)[1:2]
        draws <- lapply(seq_len(dim(draws)[3]), function(i) {
            # drop = FALSE + matrix() keeps a (draws x obs) matrix even with a
            # single draw, where draws[, , i] would collapse to a vector.
            matrix(
                draws[, , i, drop = FALSE],
                nrow = n_draws,
                dimnames = slice_dimnames
            )
        })
        draws <- do.call("cbind", draws)
    }
    attr(out, "posterior_draws") <- t(draws)

    return(out)
}


# Outcome groups for a `brmsfit`: the unique observed response values when
# "cumulative" appears in `model$family`, otherwise the single default group
# "main_marginaleffect".
#' @include get_group_names.R
#' @rdname get_group_names
#' @export
get_group_names.brmsfit <- function(model, ...) {
    is_cumulative <- !is.null(model$family) && "cumulative" %in% model$family
    if (!is_cumulative) {
        return("main_marginaleffect")
    }
    unique(insight::get_response(model))
}


# `brmsfit` models have no variance-covariance matrix for marginaleffects to
# use, so this always returns NULL. A non-NULL, non-logical `vcov` triggers a
# warning, and the argument is still passed through sanitize_vcov() (its return
# value is discarded).
#' @rdname get_vcov
#' @export
get_vcov.brmsfit <- function(model, vcov = NULL, ...) {
    unsupported <- !is.null(vcov) && !is.logical(vcov)
    if (unsupported) {
        warn_sprintf(
            "The `vcov` argument is not supported for models of this class."
        )
    }
    sanitize_vcov(model, vcov)
    NULL
}