File: methods_tidymodels.R

package info (click to toggle)
r-cran-marginaleffects 0.32.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 2,784 kB
  • sloc: sh: 13; makefile: 8
file content (111 lines) | stat: -rw-r--r-- 2,600 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
supported_engine <- function(x) {
    # Report whether the engine-level fit inside a parsnip `model_fit` or a
    # `workflow` can be handed to marginaleffects' generic methods.
    #
    # x: a parsnip `model_fit` or tidymodels `workflow` object.
    # Returns TRUE; any failure surfaces as an error from the calls below
    # rather than as FALSE.
    insight::check_if_installed("parsnip")
    # Only the validation effect is needed here: if the engine object cannot
    # be extracted, this call errors (presumably for untrained workflows --
    # confirm against parsnip). The extracted object itself is discarded.
    parsnip::extract_fit_engine(x)
    TRUE
}


#' @include set_coef.R
#' @rdname set_coef
#' @export
set_coef.model_fit <- function(model, coefs, ...) {
    # A parsnip fit stores the engine-level model in its `fit` element.
    # Replace the coefficients there; objects without that element are
    # returned untouched.
    has_engine_fit <- "fit" %in% names(model)
    if (has_engine_fit) {
        model$fit <- set_coef(model$fit, coefs, ...)
    }
    model
}


#' @include set_coef.R
#' @rdname set_coef
#' @export
set_coef.workflow <- function(model, coefs, ...) {
    # A fitted workflow nests the engine model two levels deep
    # (`model$fit$fit`). Bail out unchanged unless both levels are present.
    if (!("fit" %in% names(model))) {
        return(model)
    }
    if (!("fit" %in% names(model$fit))) {
        return(model)
    }
    model$fit$fit <- set_coef(model$fit$fit, coefs, ...)
    model
}


#' @include get_predict.R
#' @rdname get_predict
#' @keywords internal
#' @export
get_predict.model_fit <- function(model, newdata, type = NULL, ...) {
    # Predict with a parsnip fit (or workflow) and reshape the returned
    # tibble into a table with an `estimate` column (plus `group` for
    # probabilities), then attach row ids via add_rowid().
    out <- stats::predict(model, new_data = newdata, type = type)

    # With `type = NULL`, predict() chooses its own default type. A bare
    # `type == "..."` would then be zero-length and make `if ()` error, so
    # infer the type from the prediction columns that came back. Unknown
    # column layouts leave `type` NULL and fall through unchanged.
    if (is.null(type)) {
        if (".pred_class" %in% colnames(out)) {
            type <- "class"
        } else if (any(c(".pred", ".pred_res") %in% colnames(out))) {
            type <- "numeric"
        }
    }

    # isTRUE() keeps each test a single TRUE/FALSE even for NULL/NA `type`.
    if (isTRUE(type == "numeric")) {
        # Numeric predictions may be named `.pred` or `.pred_res`.
        v <- intersect(c(".pred", ".pred_res"), colnames(out))[1]
        out <- data.table(estimate = out[[v]])
    } else if (isTRUE(type == "class")) {
        out <- data.table(estimate = out[[".pred_class"]])
    } else if (isTRUE(type == "prob")) {
        # Strip the first 6 characters (assumed to be the ".pred_" prefix)
        # so that `group` holds the outcome level name.
        colnames(out) <- substr(colnames(out), 7, nchar(colnames(out)))
        # Temporary id so melt() yields one row per (observation, level).
        out$marginaleffects_internal_id <- seq_len(nrow(out))
        out <- data.table::melt(
            as.data.table(out),
            id.vars = "marginaleffects_internal_id",
            variable.name = "group",
            value.name = "estimate"
        )
        out$marginaleffects_internal_id <- NULL
    }
    out <- add_rowid(out, newdata)

    return(out)
}


#' @include get_predict.R
#' @rdname get_predict
#' @keywords internal
#' @export
get_predict.workflow <- function(model, newdata, type = NULL, ...) {
    # Workflows are predicted exactly like bare parsnip fits; delegate.
    get_predict.model_fit(model, newdata = newdata, type = type, ...)
}


#' @include get_vcov.R
#' @rdname get_vcov
#' @keywords internal
#' @export
get_vcov.model_fit <- function(model, vcov, type = NULL, ...) {
    # Return the variance-covariance matrix of the engine-level model inside
    # a parsnip fit (or workflow), or FALSE when none should be computed.

    # `vcov = FALSE` explicitly disables uncertainty estimates.
    if (isFALSE(vcov)) {
        return(FALSE)
    }

    # Class predictions are discrete labels, so no vcov is returned for them.
    if (isTRUE(type == "class")) {
        return(FALSE)
    }
    vcov <- sanitize_vcov(model, vcov)
    if (isTRUE(supported_engine(model))) {
        tmp <- parsnip::extract_fit_engine(model)
        # Forward the sanitized `vcov` so that user requests (e.g. robust or
        # clustered standard errors) reach the engine-level method instead of
        # being discarded after sanitization.
        out <- get_vcov(tmp, vcov = vcov)
    } else {
        out <- FALSE
    }
    return(out)
}


#' @include get_vcov.R
#' @rdname get_vcov
#' @keywords internal
#' @export
get_vcov.workflow <- function(model, vcov, type = NULL, ...) {
    # Workflows share the parsnip implementation; delegate with all arguments.
    get_vcov.model_fit(model, vcov = vcov, type = type, ...)
}


#' @include get_coef.R
#' @rdname get_coef
#' @keywords internal
#' @export
get_coef.workflow <- function(model, ...) {
    # Coefficients come from the engine-level model wrapped by the workflow.
    # Without a supported engine there is nothing to report, so return NULL.
    if (!isTRUE(supported_engine(model))) {
        return(NULL)
    }
    engine_fit <- parsnip::extract_fit_engine(model)
    get_coef(engine_fit)
}