File: hypothesis_function.R

package info (click to toggle)
r-cran-marginaleffects 0.32.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 2,784 kB
  • sloc: sh: 13; makefile: 8
file content (58 lines) | stat: -rw-r--r-- 2,015 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
hypothesis_function <- function(x, newdata, hypothesis, by) {
    draws <- attr(x, "posterior_draws")
    if (!is.null(draws)) {
        msg <- "The `hypothesis` argument does not support function for models with draws. You can use `get_draws()` to extract draws and manipulate them directly instead."
        stop_sprintf(msg)
    }

    if ("rowid" %in% colnames(x) && "rowid" %in% colnames(newdata)) {
        x <- merge(
            x,
            newdata,
            all.x = TRUE,
            by = intersect(colnames(x), colnames(newdata))
        )
    } else if (isTRUE(nrow(x) == nrow(newdata))) {
        x <- cbind(x, newdata)
    }

    attr(x, "variables_datagrid") <- attr(newdata, "variables_datagrid")

    argnames <- names(formals(hypothesis))
    if (!"x" %in% argnames) {
        stop_sprintf(
            "The `hypothesis` function must accept an `x` argument."
        )
    }
    if (!all(argnames %in% c("x", "draws"))) {
        msg <- "The allowable arguments for the `hypothesis` function are: `x` and `draws`"
        stop_sprintf(msg)
    }
    args <- list(x = x, newdata = newdata, by = by, draws = draws)
    args <- args[names(args) %in% argnames]
    out <- do.call(hypothesis, args)

    # sanity
    msg <- "The `hypothesis` argument function must return a data frame with `term` (or `hypothesis`) and `estimate` columns."
    if (inherits(out, "data.frame")) {
        if (
            !all(c("term", "estimate") %in% colnames(out)) &&
                !all(c("hypothesis", "estimate") %in% colnames(out))
        ) {
            stop_sprintf(msg)
        }
    } else if (isTRUE(checkmate::check_numeric(out))) {
        if (
            isTRUE(checkmate::check_data_frame(x, nrows = length(out))) &&
                "term" %in% colnames(out)
        ) {
            out <- data.frame(term = out$term, estimate = out)
        } else {
            out <- data.frame(term = seq_along(out), estimate = out)
        }
    } else {
        stop_sprintf(msg)
    }

    return(out)
}