File: prune.R

package info (click to toggle)
r-cran-marginaleffects 0.32.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 2,784 kB
  • sloc: sh: 13; makefile: 8
file content (84 lines) | stat: -rw-r--r-- 2,654 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Re-export the prune() generic from the generics package, so that the
# prune.* methods defined in this file can be called as prune() by users of
# marginaleffects.
#' @importFrom generics prune
#' @export
generics::prune


#' Prune marginaleffects objects to reduce memory usage
#'
#' Remove large attributes from marginaleffects objects to reduce memory usage.
#' Warning: This will disable many useful post-processing features of
#' `marginaleffects`.
#' @param tree A marginaleffects object (predictions, comparisons, slopes, or hypotheses)
#' @param component A character string indicating which component to prune: "all" or "modeldata".
#' @param ... Unused
#' @return A pruned marginaleffects object
#' @details
#' With `component = "all"`, the fitted model, `newdata`, `modeldata`, the
#' original call, and the Jacobian are dropped from the internal
#' `marginaleffects` object. Every attribute of `tree` other than `names`,
#' `row.names`, `class`, and `mfx` is removed, the pruned internal object is
#' re-attached, and `attr(tree, "lean")` is set to `TRUE`.
#'
#' With `component = "modeldata"`, only the stored model-data columns that
#' are referenced by the model formula or by the variables used in the call
#' (datagrid, response, weights, `by`, and hypothesis `by` variables) are
#' kept. The model data are left untouched for dbarts, mlr3, and tidymodels
#' models, when no model data are stored, or when the object carries no
#' internal `marginaleffects` object.
#' @export
prune.marginaleffects <- function(tree, component = "all", ...) {
    checkmate::assert_choice(component, c("all", "modeldata"))

    mfx <- components(tree, "all")

    if (component == "all") {
        if (!is.null(mfx)) {
            # Drop the heavy slots; the Jacobian is reset to an empty matrix
            # rather than NULL.
            mfx@model <- NULL
            mfx@newdata <- NULL
            mfx@modeldata <- NULL
            mfx@call <- NULL
            mfx@jacobian <- matrix()
        }
        # Strip every non-essential attribute. This also removes the
        # "marginaleffects" attribute, which is re-attached (pruned) below.
        essential_attrs <- c("names", "row.names", "class", "mfx")
        for (nm in setdiff(names(attributes(tree)), essential_attrs)) {
            attr(tree, nm) <- NULL
        }
        attr(tree, "lean") <- TRUE
    } else if (component == "modeldata") {
        # Mirror the NULL guard of the "all" branch: without an internal
        # object there is nothing to prune (and `mfx@model` would error).
        if (is.null(mfx)) {
            return(tree)
        }

        # Do not prune modeldata for dbarts, mlr3, or tidymodels models, nor
        # when no model data are stored (column-subsetting NULL would error).
        keep_all <- is.null(mfx@modeldata) ||
            inherits(mfx@model, c("bart", "Learner", "model_fit", "workflow"))

        if (!keep_all) {
            fml <- hush(insight::find_formula(mfx@model))
            fml <- unlist(lapply(fml, all.vars))

            keepers <- c(
                fml,
                names(mfx@variables),
                mfx@variable_names_datagrid,
                mfx@variable_names_response,
                mfx@variable_names_wts,
                mfx@variable_names_by,
                mfx@variable_names_by_hypothesis
            )

            # fixest-specific syntax: i.groupid and ~weights
            # must be stripped before name lookup in modeldata
            if (inherits(mfx@model, "fixest")) {
                keepers <- gsub("^~|^i\\.", "", keepers)
            }

            # intersect() returns unique values in the order of its first
            # argument, so the original column order is preserved
            keepers <- intersect(names(mfx@modeldata), keepers)

            if (inherits(mfx@modeldata, "data.table")) {
                mfx@modeldata <- mfx@modeldata[, ..keepers, drop = FALSE]
            } else {
                mfx@modeldata <- mfx@modeldata[, keepers, drop = FALSE]
            }
        }
    }

    attr(tree, "marginaleffects") <- mfx
    tree
}

# Reuse prune.marginaleffects() as the prune() method for each result class
# (predictions, hypotheses, slopes, comparisons). Each alias carries its own
# `@export` tag so roxygen registers it as an S3 method of the re-exported
# prune() generic.
#' @export
prune.predictions <- prune.marginaleffects

#' @export
prune.hypotheses <- prune.marginaleffects

#' @export
prune.slopes <- prune.marginaleffects

#' @export
prune.comparisons <- prune.marginaleffects