File: methods_MASS.R

package info (click to toggle)
r-cran-marginaleffects 0.32.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 2,784 kB
  • sloc: sh: 13; makefile: 8
file content (120 lines) | stat: -rw-r--r-- 2,881 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#' @include get_coef.R
#' @rdname get_coef
#' @export
get_coef.polr <- function(model, ...) {
    # Pull the parameter table from insight and return its estimates as a
    # named numeric vector.
    params <- insight::get_parameters(model)
    # Strip the literal "Intercept: " label insight puts on some parameter
    # names, so the result uses the bare names.
    coef_names <- gsub("Intercept: ", "", params$Parameter)
    stats::setNames(params$Estimate, coef_names)
}


#' @include set_coef.R
#' @rdname set_coef
#' @export
set_coef.polr <- function(model, coefs, ...) {
    # `model$coefficients` and `model$zeta` are named vectors; each one is
    # rebuilt by looking up its own names in `coefs`. A name absent from
    # `coefs` yields NA at that position.
    refill <- function(current) {
        coefs[match(names(current), names(coefs))]
    }
    model[["coefficients"]] <- refill(model$coefficients)
    model[["zeta"]] <- refill(model$zeta)
    model
}


#' @include get_group_names.R
#' @rdname get_group_names
#' @export
get_group_names.polr <- function(model, ...) {
    # Outcome categories: the factor levels when the response is a factor
    # (so unobserved levels are kept, in level order), otherwise the
    # distinct observed values in order of first appearance.
    response <- insight::get_response(model)
    if (!is.factor(response)) {
        return(unique(response))
    }
    levels(response)
}


#' @include get_predict.R
#' @rdname get_predict
#' @export
get_predict.polr <- function(
    model,
    newdata = insight::get_data(model),
    type = "probs",
    mfx = NULL,
    ...) {
    # Use the "predictions" context unless `mfx` records which function
    # initiated the call. The context is passed on to sanitize_type().
    calling_function <- "predictions"
    if (!is.null(mfx)) {
        calling_function <- mfx@calling_function
    }
    type <- sanitize_type(model, type, calling_function = calling_function)

    # Workaround: a 1-row `newdata` makes the prediction come back as a
    # vector, so get_predict.default() cannot learn about the outcome groups.
    # Duplicate the row, predict, then keep only the odd-numbered rows.
    # NOTE(review): this assumes the duplicate's predictions occupy the even
    # rows of the output -- confirm against get_predict.default()'s ordering.
    duplicated_row <- nrow(newdata) == 1
    if (duplicated_row) {
        newdata <- newdata[c(1, 1), , drop = FALSE]
    }

    preds <- get_predict.default(model, newdata = newdata, type = type, ...)

    if (duplicated_row) {
        keep_preds <- seq_len(nrow(preds)) %% 2 == 1
        preds <- preds[keep_preds, , drop = FALSE]
        keep_data <- seq_len(nrow(newdata)) %% 2 == 1
        newdata <- newdata[keep_data, , drop = FALSE]
    }

    add_rowid(preds, newdata)
}


#' @include set_coef.R
#' @rdname set_coef
#' @export
set_coef.glmmPQL <- function(model, coefs, ...) {
    # Overwrite the fixed-effect estimates by name. Other entries of
    # `model$coefficients` (e.g. the random part) are left as they are.
    fixed_effects <- model[["coefficients"]][["fixed"]]
    fixed_effects[names(coefs)] <- coefs
    model[["coefficients"]][["fixed"]] <- fixed_effects
    model
}


#' @rdname get_predict
#' @export
get_predict.glmmPQL <- function(
    model,
    newdata = insight::get_data(model),
    type = "response",
    mfx = NULL,
    ...) {
    # `mfx` is accepted for interface consistency but not used here.
    # Predictions from stats::predict() go into an `estimate` column, then
    # add_rowid() attaches row identifiers based on `newdata`.
    estimate <- stats::predict(model, newdata = newdata, type = type, ...)
    add_rowid(data.table(estimate = estimate), newdata)
}


#' @rdname get_vcov
#' @export
get_vcov.lda <- function(model, ...) {
    # No variance-covariance matrix is supplied for `lda` models; the model
    # argument is never inspected.
    NULL
}



#' @rdname get_predict
#' @export
get_predict.lda <- function(
    model,
    newdata = insight::get_data(model),
    type = "class",
    ...) {
    # Predictions for `lda` models.
    #   type = "class":     one row per observation, `estimate` holds the
    #                       predicted class.
    #   type = "posterior": long format, one row per observation x class,
    #                       with the class in `group` and the posterior
    #                       probability in `estimate`.
    # `...` is accepted for interface compatibility but not forwarded.

    # Validate before predicting. An unsupported value used to fall through
    # both branches and hand the raw predict() list to add_rowid(), and a
    # non-scalar `type` would make the `if` below error.
    supported <- c("class", "posterior")
    if (!is.character(type) || length(type) != 1L || !type %in% supported) {
        stop('`type` must be "class" or "posterior" for `lda` models.',
            call. = FALSE)
    }

    pred <- stats::predict(model, newdata = newdata)
    if (type == "class") {
        out <- data.table(estimate = pred$class)
    } else {
        posterior <- data.table::data.table(pred$posterior)
        # Declare every posterior column as a measure variable. With both
        # id.vars and measure.vars left NULL, melt() has to guess them and
        # emits a warning on every call; the resulting table is the same.
        out <- data.table::melt(
            posterior,
            measure.vars = names(posterior),
            variable.name = "group",
            value.name = "estimate")
    }
    add_rowid(out, newdata)
}